mathy/mathstream/engine.py
2025-11-05 08:35:01 +01:00

330 lines
9.5 KiB
Python

from __future__ import annotations
from typing import Iterable, Tuple
from .exceptions import DivideByZeroError
from .number import StreamNumber, LOG_DIR
from .utils import register_log_file, wipe_log_records
def _ensure_log_dir() -> None:
    """Create ``LOG_DIR`` and any missing parents; do nothing if it exists."""
    LOG_DIR.mkdir(exist_ok=True, parents=True)
def _strip_leading_zeros(digits: str) -> str:
digits = digits.lstrip("0")
return digits or "0"
def _normalize_stream(num: StreamNumber) -> Tuple[int, str]:
"""Return (sign, digits) tuple for the streamed number."""
parts: list[str] = []
for chunk in num.stream():
chunk = chunk.strip()
if not chunk:
continue
parts.append(chunk)
raw = "".join(parts)
if not raw:
raise ValueError(f"Stream for {num.path} is empty")
sign = 1
if raw[0] in "+-":
sign = -1 if raw[0] == "-" else 1
raw = raw[1:]
if not raw.isdigit():
raise ValueError(f"Non-digit characters found in stream for {num.path}")
digits = _strip_leading_zeros(raw)
if digits == "0":
sign = 1
return sign, digits
def _compare_abs(a: str, b: str) -> int:
"""Compare two positive digit strings."""
if len(a) != len(b):
return 1 if len(a) > len(b) else -1
if a == b:
return 0
return 1 if a > b else -1
def _add_abs(a: str, b: str) -> str:
carry = 0
idx_a = len(a) - 1
idx_b = len(b) - 1
out: list[str] = []
while idx_a >= 0 or idx_b >= 0 or carry:
da = ord(a[idx_a]) - 48 if idx_a >= 0 else 0
db = ord(b[idx_b]) - 48 if idx_b >= 0 else 0
total = da + db + carry
carry, digit = divmod(total, 10)
out.append(str(digit))
idx_a -= 1
idx_b -= 1
return "".join(reversed(out))
def _sub_abs(a: str, b: str) -> str:
    """Subtract digit string ``b`` from ``a``; the caller must ensure a >= b.

    The result has leading zeros removed.
    """
    borrow = 0
    reversed_digits: list[str] = []
    # Left-pad b with zeros so every digit of a has a partner, then walk the
    # columns from least to most significant.
    for ch_a, ch_b in zip(reversed(a), reversed(b.rjust(len(a), "0"))):
        delta = ord(ch_a) - ord(ch_b) - borrow
        borrow = 1 if delta < 0 else 0
        reversed_digits.append(str(delta + 10 * borrow))
    return _strip_leading_zeros("".join(reversed_digits[::-1]))
def _multiply_abs(a: str, b: str) -> str:
    """Multiply two digit strings with the schoolbook algorithm."""
    if "0" in (a, b):
        return "0"

    a_vals = [ord(ch) - 48 for ch in a]
    b_vals = [ord(ch) - 48 for ch in b]
    # A product of an m-digit and an n-digit number has at most m + n digits.
    cells = [0] * (len(a_vals) + len(b_vals))

    for i in reversed(range(len(a_vals))):
        carry = 0
        for j in reversed(range(len(b_vals))):
            slot = i + j + 1
            carry, cells[slot] = divmod(cells[slot] + a_vals[i] * b_vals[j] + carry, 10)
        # The row's final carry lands one column left of its last product.
        cells[i] += carry

    return _strip_leading_zeros("".join(map(str, cells)))
def _multiply_digit(num: str, digit: int) -> str:
if digit == 0 or num == "0":
return "0"
carry = 0
out: list[str] = []
for i in range(len(num) - 1, -1, -1):
total = (ord(num[i]) - 48) * digit + carry
carry, d = divmod(total, 10)
out.append(str(d))
if carry:
out.append(str(carry))
return "".join(reversed(out))
def _divide_abs(dividend: str, divisor: str) -> Tuple[str, str]:
    """Long-divide two digit strings.

    Returns:
        ``(quotient, remainder)`` as digit strings without leading zeros.

    Raises:
        DivideByZeroError: If ``divisor`` is ``"0"``.
    """
    if divisor == "0":
        raise DivideByZeroError("division by zero")
    if dividend == "0":
        return "0", "0"

    # The single-digit multiples of the divisor never change, so build the
    # table once instead of re-multiplying for every digit of the dividend.
    multiples = [_multiply_digit(divisor, factor) for factor in range(10)]

    quotient_digits: list[str] = []
    remainder = "0"
    for digit in dividend:
        # Bring down the next digit of the dividend.
        remainder = _strip_leading_zeros(remainder + digit)
        q_digit = 0
        # Largest factor whose multiple still fits; 0 when none of 9..1 fit.
        for guess in range(9, 0, -1):
            if _compare_abs(multiples[guess], remainder) <= 0:
                q_digit = guess
                remainder = _sub_abs(remainder, multiples[guess])
                break
        quotient_digits.append(str(q_digit))

    quotient = _strip_leading_zeros("".join(quotient_digits))
    remainder = _strip_leading_zeros(remainder)
    return quotient, remainder
def _is_zero(digits: str) -> bool:
return digits == "0"
def _is_odd(digits: str) -> bool:
return (ord(digits[-1]) - 48) % 2 == 1
def _halve(digits: str) -> str:
    """Return the digit string divided by two, rounded down."""
    leftover = 0
    pieces: list[str] = []
    # Short division by 2, most significant digit first; the leftover (0 or 1)
    # is carried into the next column as ten units.
    for ch in digits:
        half, leftover = divmod(leftover * 10 + (ord(ch) - 48), 2)
        pieces.append(str(half))
    return _strip_leading_zeros("".join(pieces))
def _write_result(operation: str, operands: Iterable[StreamNumber], digits: str) -> StreamNumber:
    """Write ``digits`` to a result file in ``LOG_DIR`` and wrap it.

    The file is named ``<operation>_<hash>_<hash>....bin`` from the operands'
    ``hash`` attributes, written as UTF-8 text, and passed to
    ``register_log_file`` before being returned as a ``StreamNumber``.
    """
    _ensure_log_dir()
    joined_hashes = "_".join([operand.hash for operand in operands])
    target = LOG_DIR / f"{operation}_{joined_hashes}.bin"
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(digits)
    register_log_file(target)
    return StreamNumber(target)
def clear_logs() -> None:
    """Delete the files in ``LOG_DIR`` and reset the log records.

    Every file or symlink directly inside ``LOG_DIR`` is removed, the
    directory is (re)created if it is missing, and ``wipe_log_records`` is
    called.  Sub-directories are left in place.
    """
    if LOG_DIR.exists():
        for entry in LOG_DIR.glob("*"):
            # unlink() only removes files and symlinks; on a real directory it
            # raises, which would abort the wipe part-way and skip the record
            # reset below.  Skip directories instead.
            if entry.is_dir() and not entry.is_symlink():
                continue
            entry.unlink()
    _ensure_log_dir()
    wipe_log_records()
def add(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return ``num_a + num_b`` computed on sign/digit-string form.

    The sum is written to a result file and returned as a new
    ``StreamNumber``.
    """
    sign_a, mag_a = _normalize_stream(num_a)
    sign_b, mag_b = _normalize_stream(num_b)

    if sign_a == sign_b:
        # Same sign: magnitudes add and the shared sign carries over.
        sign, magnitude = sign_a, _add_abs(mag_a, mag_b)
    else:
        # Opposite signs: subtract the smaller magnitude from the larger and
        # keep the sign of the larger operand.
        order = _compare_abs(mag_a, mag_b)
        if order > 0:
            sign, magnitude = sign_a, _sub_abs(mag_a, mag_b)
        elif order < 0:
            sign, magnitude = sign_b, _sub_abs(mag_b, mag_a)
        else:
            sign, magnitude = 1, "0"

    text = f"-{magnitude}" if sign < 0 and magnitude != "0" else magnitude
    return _write_result("add", (num_a, num_b), text)
def sub(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return ``num_a - num_b`` computed on sign/digit-string form.

    The difference is written to a result file and returned as a new
    ``StreamNumber``.
    """
    sign_a, mag_a = _normalize_stream(num_a)
    sign_b, mag_b = _normalize_stream(num_b)

    if sign_a != sign_b:
        # a - b with opposite signs: magnitudes add, result takes a's sign.
        sign, magnitude = sign_a, _add_abs(mag_a, mag_b)
    else:
        order = _compare_abs(mag_a, mag_b)
        if order > 0:
            sign, magnitude = sign_a, _sub_abs(mag_a, mag_b)
        elif order < 0:
            # |b| is larger, so the result flips to the opposite sign.
            sign, magnitude = -sign_a, _sub_abs(mag_b, mag_a)
        else:
            sign, magnitude = 1, "0"

    text = f"-{magnitude}" if sign < 0 and magnitude != "0" else magnitude
    return _write_result("sub", (num_a, num_b), text)
def mul(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return ``num_a * num_b`` using schoolbook multiplication.

    The product is written to a result file and returned as a new
    ``StreamNumber``.
    """
    sign_a, mag_a = _normalize_stream(num_a)
    sign_b, mag_b = _normalize_stream(num_b)

    product = _multiply_abs(mag_a, mag_b)
    # Zero is never written with a minus sign.
    negative = product != "0" and sign_a * sign_b < 0
    text = f"-{product}" if negative else product
    return _write_result("mul", (num_a, num_b), text)
def div(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return the floor division ``num_a // num_b``.

    The quotient is written to a result file and returned as a new
    ``StreamNumber``.  A zero divisor makes ``_divide_abs`` raise
    ``DivideByZeroError``.
    """
    sign_a, mag_a = _normalize_stream(num_a)
    sign_b, mag_b = _normalize_stream(num_b)
    quotient, remainder = _divide_abs(mag_a, mag_b)

    if quotient == "0" and remainder == "0":
        return _write_result("div", (num_a, num_b), "0")

    opposite_signs = sign_a * sign_b < 0
    if opposite_signs and remainder != "0":
        # Flooring a negative, inexact quotient rounds away from zero, so the
        # magnitude grows by one.
        quotient = _add_abs(quotient, "1")
        negative = True
    else:
        negative = opposite_signs and quotient != "0"

    text = f"-{quotient}" if negative else quotient
    return _write_result("div", (num_a, num_b), text)
def mod(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return ``num_a % num_b``; a non-zero result takes the divisor's sign.

    The result is written to a result file and returned as a new
    ``StreamNumber``.

    Raises:
        DivideByZeroError: If ``num_b`` is zero.
    """
    sign_a, mag_a = _normalize_stream(num_a)
    sign_b, mag_b = _normalize_stream(num_b)
    if mag_b == "0":
        raise DivideByZeroError("modulo by zero")

    remainder = _divide_abs(mag_a, mag_b)[1]
    if remainder == "0":
        return _write_result("mod", (num_a, num_b), "0")

    # With matching signs the magnitude remainder is the answer; otherwise the
    # floored result is the divisor's magnitude minus that remainder.
    magnitude = remainder if sign_a == sign_b else _sub_abs(mag_b, remainder)
    text = magnitude if sign_b > 0 else f"-{magnitude}"
    return _write_result("mod", (num_a, num_b), text)
def pow(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return ``num_a ** num_b`` for a non-negative integer exponent.

    Uses square-and-multiply on digit strings.  The result is written to a
    result file and returned as a new ``StreamNumber``.

    Raises:
        ValueError: If the exponent is negative.
    """
    base_sign, base_mag = _normalize_stream(num_a)
    exp_sign, exp_mag = _normalize_stream(num_b)

    if exp_sign < 0:
        raise ValueError("Negative exponents are not supported for integer streams.")
    if exp_mag == "0":
        return _write_result("pow", (num_a, num_b), "1")

    accumulator = "1"
    square = base_mag
    remaining = exp_mag
    # The exponent is non-zero here, so at least one round always runs.
    while True:
        if _is_odd(remaining):
            accumulator = _multiply_abs(accumulator, square)
        remaining = _halve(remaining)
        if _is_zero(remaining):
            break
        # Square only while exponent bits remain, skipping a wasted final
        # multiplication.
        square = _multiply_abs(square, square)

    # Only a negative base raised to an odd exponent gives a negative result.
    negative = base_sign < 0 and _is_odd(exp_mag) and accumulator != "0"
    text = f"-{accumulator}" if negative else accumulator
    return _write_result("pow", (num_a, num_b), text)
def is_even(num: StreamNumber) -> bool:
    """Return True when the streamed integer's last digit is even."""
    digits = _normalize_stream(num)[1]
    return not _is_odd(digits)
def is_odd(num: StreamNumber) -> bool:
    """Return True when the streamed integer is not even."""
    even = is_even(num)
    return not even