"""
CSD Multiplier — Verilog Code Generation
Translates a Canonical Signed Digit (CSD) string into a synthesizable Verilog
module that implements constant multiplication using only shift-and-add/subtract
operations. When the CSD string contains repeated patterns, LCSRe (Longest
Common Substring with Repeated Elements) is used to share hardware via a
sub-expression wire.
"""
from typing import Optional
from csdigit.lcsre import longest_repeated_substring
def _build_range_expr(csd: str, start: int, length: int, max_power: int) -> str:
"""Build a flat Verilog expression for csd[start:start+length].
Returns empty string if the range has no non-zero digits.
"""
expr_parts: list[str] = []
for i in range(start, start + length):
power = max_power - i
ch = csd[i]
if ch == "+":
expr_parts.append(f"x_shift{power}")
elif ch == "-":
expr_parts.append(f"-x_shift{power}")
if not expr_parts:
return ""
# Build signed expression: first term without sign prefix, rest with signs
result = expr_parts[0]
for part in expr_parts[1:]:
if part.startswith("-"):
result += " - " + part[1:]
else:
result += " + " + part
return result
def _find_pattern_occurrences(csd: str, pattern: str) -> list[int]:
"""Find all non-overlapping positions of |pattern| in |csd|."""
positions: list[int] = []
pos = 0
while True:
pos = csd.find(pattern, pos)
if pos == -1:
break
positions.append(pos)
pos += len(pattern)
return positions
[docs]
def generate_csd_multiplier(csd: str, input_width: int, max_power: int) -> str:
"""
Generate Verilog code for a CSD multiplier module with LCSRe optimization.
When the CSD string contains repeated non-overlapping patterns (detected
via the longest_repeated_substring algorithm), the generated Verilog
factors out a shared sub-expression wire — reducing the number of adders.
Args:
csd: CSD string (e.g., "+00-00+0+")
input_width: Bit width of the input signal x
max_power: Highest power of two in the CSD (must be len(csd)-1)
Returns:
Verilog module code as a string
Raises:
ValueError: If csd length doesn't match max_power+1 or the string
contains characters other than '+', '-', '0'
"""
# --- validation ---
if len(csd) != max_power + 1:
raise ValueError(
f"CSD length {len(csd)} doesn't match max_power={max_power} "
f"(should be max_power+1)"
)
if not all(ch in "+-0" for ch in csd):
raise ValueError("CSD string can only contain '+', '-', or '0'")
# Parse CSD and collect non-zero terms
terms: list[tuple[int, str]] = []
for index, ch in enumerate(csd):
power = max_power - index
if ch == "+":
terms.append((power, "add"))
elif ch == "-":
terms.append((power, "sub"))
output_width = input_width + max_power
# --- module header ---
verilog = f"""
module csd_multiplier (
input signed [{input_width - 1}:0] x, // Input value
output signed [{output_width - 1}:0] result // Result of multiplication
);"""
# --- wire declarations (deduplicated powers) ---
if terms:
verilog += "\n\n // Create shifted versions of input"
for p in sorted({p for p, _op in terms}, reverse=True):
verilog += (
f"\n wire signed [{output_width - 1}:0] x_shift{p} = x <<< {p};"
)
# --- detect LCSRe optimization opportunity ---
repeated = longest_repeated_substring(csd)
use_opt = False
pat_positions: list[int] = []
if repeated and len(repeated) > 1:
pat_nnz = repeated.count("+") + repeated.count("-")
if pat_nnz >= 2:
pat_positions = _find_pattern_occurrences(csd, repeated)
use_opt = len(pat_positions) >= 2
# --- combinational logic ---
if not terms:
verilog += "\n\n // CSD implementation"
verilog += "\n assign result = 0;"
elif use_opt:
# --- LCSRe-optimized path: share repeated sub-expression ---
base_pos = pat_positions[0]
pat_expr = _build_range_expr(csd, base_pos, len(repeated), max_power)
verilog += f'\n\n // LCSRe: repeated pattern "{repeated}"'
verilog += f"\n wire signed [{output_width - 1}:0] _pat = {pat_expr};"
# Build full expression from segments
expr_parts: list[str] = []
cur = 0
for pos in pat_positions:
# prefix/gap before this occurrence
if pos > cur:
gap_expr = _build_range_expr(csd, cur, pos - cur, max_power)
if gap_expr:
expr_parts.append(gap_expr)
# the pattern occurrence (shifted if not the first)
shift = pos - base_pos
if shift == 0:
expr_parts.append("_pat")
else:
expr_parts.append(f"(_pat >>> {shift})")
cur = pos + len(repeated)
# suffix after the last occurrence
if cur < len(csd):
suffix_expr = _build_range_expr(csd, cur, len(csd) - cur, max_power)
if suffix_expr:
expr_parts.append(suffix_expr)
# Join all parts with " + ", but handle leading minus terms correctly
# (first part might start with "-" from a negative gap term)
result_expr = " + ".join(expr_parts)
verilog += "\n\n // CSD implementation (LCSRe optimized)"
verilog += f"\n assign result = {result_expr};"
else:
# --- flat path (no repeated pattern) ---
first_power, first_op = terms[0]
if first_op == "sub":
expr = f"-x_shift{first_power}"
else:
expr = f"x_shift{first_power}"
for power, operation in terms[1:]:
if operation == "add":
expr += f" + x_shift{power}"
else:
expr += f" - x_shift{power}"
verilog += "\n\n // CSD implementation"
verilog += f"\n assign result = {expr};"
verilog += "\nendmodule\n"
return verilog
# ---------------------------------------------------------------------------
# Cross-CSE: multiple CSD multipliers sharing sub-expressions
# ---------------------------------------------------------------------------
def _find_cross_patterns(
csd_list: list[str], min_nnz: int = 2
) -> dict[str, list[tuple[int, int]]]:
"""Find substrings (NNZ >= min_nnz) appearing in >= 2 CSD strings.
Returns dict mapping pattern -> [(csd_idx, position), ...].
"""
patterns: dict[str, list[tuple[int, int]]] = {}
for ci, csd in enumerate(csd_list):
n = len(csd)
for i in range(n):
for j in range(i + 2, n + 1):
sub = csd[i:j]
nnz = sub.count("+") + sub.count("-")
if nnz >= min_nnz:
patterns.setdefault(sub, []).append((ci, i))
# Keep only patterns crossing >= 2 different CSD strings
return {
sub: occ for sub, occ in patterns.items() if len({ci for ci, _ in occ}) >= 2
}
def _build_coeff_expr(
csd: str,
max_power: int,
pattern: Optional[str],
base_pos: int,
cse_name: str,
) -> str:
"""Build a single coefficient's expression, using shared CSE wire if applicable.
The expression is: gap_terms + shifted_cse + gap_terms + ...
"""
if pattern is None:
# Flat expression (no CSE)
return _build_range_expr(csd, 0, len(csd), max_power)
parts: list[str] = []
cur = 0
positions = _find_pattern_occurrences(csd, pattern)
for pos in positions:
# gap before this occurrence
if pos > cur:
gap = _build_range_expr(csd, cur, pos - cur, max_power)
if gap:
parts.append(gap)
# CSE reference
shift = pos - base_pos
if shift == 0:
parts.append(cse_name)
else:
parts.append(f"({cse_name} >>> {shift})")
cur = pos + len(pattern)
# suffix
if cur < len(csd):
gap = _build_range_expr(csd, cur, len(csd) - cur, max_power)
if gap:
parts.append(gap)
if not parts:
return ""
return " + ".join(parts)
[docs]
def generate_csd_multipliers(
coeffs: list[tuple[str, str, int, int]],
module_name: str = "csd_filter",
) -> str:
"""Generate a Verilog module with multiple CSD multipliers and cross-CSE.
When the same CSD substring appears in multiple coefficients, a shared
sub-expression wire is created — reducing total adder count across the
entire filter.
All coefficients must share the same ``input_width`` and ``max_power``
so that the same bit position encodes the same power of two across all
multipliers. If a coefficient is narrower, pad its CSD with leading
``'0'`` characters.
Args:
coeffs: List of (output_name, csd_str, input_width, max_power) tuples.
All entries **must** share the same input_width and max_power.
module_name: Name for the generated Verilog module.
Returns:
Verilog module code as a string
"""
if not coeffs:
raise ValueError("At least one coefficient is required")
# --- validation & uniform width enforcement ---
input_width = coeffs[0][2]
max_power = coeffs[0][3]
for name, csd, iw, mp in coeffs:
if iw != input_width or mp != max_power:
raise ValueError(
"All coefficients must share the same input_width and max_power "
f"for cross-CSE. Got ({iw},{mp}) for '{name}', "
f"expected ({input_width},{max_power}). "
"Pad narrower CSDs with leading '0' characters."
)
if len(csd) != max_power + 1:
raise ValueError(
f"CSD '{csd}' length {len(csd)} doesn't match "
f"max_power={max_power} for coefficient '{name}'"
)
if not all(ch in "+-0" for ch in csd):
raise ValueError(
f"CSD string '{csd}' for '{name}' can only contain '+', '-', or '0'"
)
output_width = input_width + max_power
# --- collect all x_shift powers needed ---
all_powers: set[int] = set()
for _, csd, _, _ in coeffs:
for idx, ch in enumerate(csd):
if ch != "0":
all_powers.add(max_power - idx)
all_powers_sorted = sorted(all_powers, reverse=True)
# --- find best cross-CSD pattern ---
csd_strings = [csd for _, csd, _, _ in coeffs]
cross = _find_cross_patterns(csd_strings)
best_pattern: Optional[str] = None
best_occurrences: list[tuple[int, int]] = []
if cross:
def _scores(item):
sub, occ = item
nnz = sub.count("+") + sub.count("-")
return (nnz - 1) * (len(occ) - 1)
best_pattern, best_occurrences = max(cross.items(), key=_scores)
# --- determine base position for the CSE wire ---
cse_base_pos = 0
if best_pattern and best_occurrences:
cse_base_pos = min(pos for _, pos in best_occurrences)
# --- build Verilog ---
verilog = f"\nmodule {module_name} ("
verilog += f"\n input signed [{input_width - 1}:0] x, // Input value"
for name, _csd, _iw, _mp in coeffs:
ow = _iw + _mp
verilog += f"\n output signed [{ow - 1}:0] {name}"
verilog += "\n);"
# x_shift wires (use output_width for all)
if all_powers:
verilog += "\n\n // Create shifted versions of input"
for p in all_powers_sorted:
verilog += (
f"\n wire signed [{output_width - 1}:0] x_shift{p} = x <<< {p};"
)
cse_name = "_cse_0"
if best_pattern:
cse_expr = _build_range_expr(
best_pattern, 0, len(best_pattern), max_power - cse_base_pos
)
verilog += f'\n\n // Cross-CSE: shared pattern "{best_pattern}"'
verilog += f"\n wire signed [{output_width - 1}:0] {cse_name} = {cse_expr};"
# Per-coefficient assignments
for idx, (name, csd_str, _iw, _mp) in enumerate(coeffs):
verilog += f"\n\n // {name}: {csd_str}"
ow = _iw + _mp
if best_pattern and any(ci == idx for ci, _ in best_occurrences):
expr = _build_coeff_expr(
csd_str, max_power, best_pattern, cse_base_pos, cse_name
)
else:
expr = _build_coeff_expr(csd_str, max_power, None, 0, "")
if not expr:
verilog += f"\n wire signed [{ow - 1}:0] {name} = 0;"
else:
verilog += f"\n wire signed [{ow - 1}:0] {name} = {expr};"
verilog += "\nendmodule\n"
return verilog
# Example usage
if __name__ == "__main__":
# No repeated pattern
print("=== +00-00+0 (no repeat) ===")
print(generate_csd_multiplier("+00-00+0", 8, 7))
# Repeated pattern "+0-0" — optimized with _pat wire
print("=== +0-0+0-0 (repeat: +0-0) ===")
print(generate_csd_multiplier("+0-0+0-0", 8, 7))
# Triple repeat
print("=== +0-0+0-0+0-0 (triple repeat) ===")
print(generate_csd_multiplier("+0-0+0-0+0-0", 8, 11))
# Longer pattern
print("=== +00-00+00-00 (repeat: +00-00) ===")
print(generate_csd_multiplier("+00-00+00-00", 8, 11))
# All zeros
print("=== 000 (all zeros) ===")
print(generate_csd_multiplier("000", 8, 2))
# Leading minus — also has sign fix
print("=== -0- (no repeat benefit) ===")
print(generate_csd_multiplier("-0-", 8, 2))
# --- Cross-CSE demo ---
print("\n===== Cross-CSE Demo =====")
verilog = generate_csd_multipliers(
[
("h0", "+0-0+0-0", 8, 7),
("h1", "+0-0+0-0+0-0", 8, 11),
("h2", "+00-00+00-00", 8, 11),
("h3", "-0+0-0+", 8, 5),
],
module_name="fir_taps",
)
print(verilog)