330 lines
9.5 KiB
Python
330 lines
9.5 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Iterable, Tuple
|
|
|
|
from .exceptions import DivideByZeroError
|
|
from .number import StreamNumber, LOG_DIR
|
|
from .utils import register_log_file, wipe_log_records
|
|
|
|
|
|
def _ensure_log_dir() -> None:
    """Create LOG_DIR, including any missing parent directories, if absent."""
    LOG_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
def _strip_leading_zeros(digits: str) -> str:
    """Drop leading "0" characters; an all-zero or empty input becomes "0"."""
    trimmed = digits.lstrip("0")
    if trimmed:
        return trimmed
    return "0"
|
|
|
|
|
|
def _normalize_stream(num: StreamNumber) -> Tuple[int, str]:
    """Read *num*'s stream and return its ``(sign, digits)`` decomposition.

    Each chunk yielded by ``num.stream()`` is stripped of surrounding
    whitespace and the chunks are concatenated; whitespace-only chunks
    therefore contribute nothing. An optional leading ``+`` or ``-`` sets the
    sign. Everything after it must be ASCII digits ``0``-``9``.

    Returns:
        ``(sign, digits)``. ``sign`` is ``1`` or ``-1``. ``digits`` has no
        leading zeros. Zero is always reported as ``(1, "0")``, so ``-0`` and
        ``+0`` normalize identically.

    Raises:
        ValueError: if the stream is empty, or if it contains anything other
            than an optional sign followed by ASCII digits.
    """
    # Empty stripped chunks add "" to the join, so they need no special case.
    raw = "".join(chunk.strip() for chunk in num.stream())

    if not raw:
        raise ValueError(f"Stream for {num.path} is empty")

    sign = 1
    if raw[0] in "+-":
        sign = -1 if raw[0] == "-" else 1
        raw = raw[1:]

    # str.isdigit() alone also accepts non-ASCII digits (e.g. "²", "٣").
    # The digit-string helpers map characters with ord(ch) - 48, which would
    # turn those into meaningless values, so require plain ASCII digits.
    if not (raw.isascii() and raw.isdigit()):
        raise ValueError(f"Non-digit characters found in stream for {num.path}")

    digits = _strip_leading_zeros(raw)
    if digits == "0":
        sign = 1
    return sign, digits
|
|
|
|
|
|
def _compare_abs(a: str, b: str) -> int:
    """Three-way compare two non-negative digit strings without leading zeros.

    Returns 1 if a > b, -1 if a < b, and 0 if they are equal. A longer
    string is the larger number. Strings of equal length compare
    lexicographically, which matches numeric order for ASCII digits.
    """
    key_a = (len(a), a)
    key_b = (len(b), b)
    if key_a == key_b:
        return 0
    return 1 if key_a > key_b else -1
|
|
|
|
|
|
def _add_abs(a: str, b: str) -> str:
    """Add two non-negative decimal digit strings with schoolbook carrying.

    Digit pairs are consumed from the least significant end. The sum's digits
    are collected in reverse order and flipped once at the end.
    """
    reversed_digits: list[str] = []
    carry = 0
    width = max(len(a), len(b))
    for offset in range(1, width + 1):
        # A shorter operand behaves as if it were left-padded with zeros.
        left = ord(a[-offset]) - 48 if offset <= len(a) else 0
        right = ord(b[-offset]) - 48 if offset <= len(b) else 0
        carry, digit = divmod(left + right + carry, 10)
        reversed_digits.append(str(digit))
    if carry:
        reversed_digits.append(str(carry))
    return "".join(reversed(reversed_digits))
|
|
|
|
|
|
def _sub_abs(a: str, b: str) -> str:
    """Return ``a - b`` for non-negative digit strings; the caller ensures ``a >= b``.

    Subtracts digit by digit from the least significant end, borrowing as
    needed. Leading zeros are removed from the result, so equal inputs
    yield "0".
    """
    borrow = 0
    pieces: list[str] = []
    for offset in range(len(a)):
        top = ord(a[-1 - offset]) - 48
        # Once b runs out of digits it contributes zero.
        bottom = ord(b[-1 - offset]) - 48 if offset < len(b) else 0
        value = top - borrow - bottom
        borrow = 1 if value < 0 else 0
        pieces.append(str(value + 10 * borrow))
    return _strip_leading_zeros("".join(pieces[::-1]))
|
|
|
|
|
|
def _multiply_abs(a: str, b: str) -> str:
    """Multiply two non-negative digit strings with schoolbook multiplication.

    ``cells`` holds one base-10 digit per position of the
    ``len(a) + len(b)``-wide product, most significant first. Each row of
    partial products is added into it, with carries propagated immediately.
    """
    if "0" in (a, b):
        return "0"
    a_vals = [ord(ch) - 48 for ch in a]
    b_vals = [ord(ch) - 48 for ch in b]
    cells = [0] * (len(a_vals) + len(b_vals))
    for i in reversed(range(len(a_vals))):
        carry = 0
        for j in reversed(range(len(b_vals))):
            slot = i + j + 1
            carry, cells[slot] = divmod(cells[slot] + a_vals[i] * b_vals[j] + carry, 10)
        # Any carry left over from this row lands one position further left.
        cells[i] += carry
    return _strip_leading_zeros("".join(map(str, cells)))
|
|
|
|
|
|
def _multiply_digit(num: str, digit: int) -> str:
    """Multiply the digit string *num* by the single integer *digit*.

    Returns "0" when either factor is zero. Otherwise returns the product as
    a digit string.
    """
    if num == "0" or digit == 0:
        return "0"
    carry = 0
    backwards: list[str] = []
    for ch in reversed(num):
        carry, low = divmod((ord(ch) - 48) * digit + carry, 10)
        backwards.append(str(low))
    if carry:
        backwards.append(str(carry))
    return "".join(backwards[::-1])
|
|
|
|
|
|
def _divide_abs(dividend: str, divisor: str) -> Tuple[str, str]:
    """Long-divide two non-negative digit strings.

    Args:
        dividend: digit string without leading zeros.
        divisor: digit string without leading zeros.

    Returns:
        ``(quotient, remainder)`` as digit strings without leading zeros, with
        ``dividend == quotient * divisor + remainder`` and
        ``0 <= remainder < divisor``.

    Raises:
        DivideByZeroError: if *divisor* is "0".
    """
    if divisor == "0":
        raise DivideByZeroError("division by zero")
    if dividend == "0":
        return "0", "0"

    # divisor * 0 .. divisor * 9 are the only candidates ever compared against
    # the running remainder. Compute them once here rather than once per
    # dividend digit.
    multiples = [_multiply_digit(divisor, k) for k in range(10)]

    quotient_digits: list[str] = []
    remainder = "0"
    for digit in dividend:
        remainder = _strip_leading_zeros(remainder + digit)
        # Find the largest k with divisor * k <= remainder.
        # k = 0 always qualifies and leaves the remainder unchanged.
        q_digit = 0
        for guess in range(9, 0, -1):
            if _compare_abs(multiples[guess], remainder) <= 0:
                q_digit = guess
                remainder = _sub_abs(remainder, multiples[guess])
                break
        quotient_digits.append(str(q_digit))

    quotient = _strip_leading_zeros("".join(quotient_digits))
    return quotient, _strip_leading_zeros(remainder)
|
|
|
|
|
|
def _is_zero(digits: str) -> bool:
    """Return True only for the exact string "0" (a normalized zero)."""
    zero_literal = "0"
    return zero_literal == digits
|
|
|
|
|
|
def _is_odd(digits: str) -> bool:
    """Return True if the final (least significant) digit of *digits* is odd."""
    last_value = ord(digits[-1]) - 48
    return bool(last_value & 1)
|
|
|
|
|
|
def _halve(digits: str) -> str:
    """Return ``floor(digits / 2)`` as a digit string without leading zeros.

    Performs short division by 2 from the most significant digit, carrying
    the remainder (0 or 1) into the next position.
    """
    remainder = 0
    halves: list[str] = []
    for ch in digits:
        half, remainder = divmod(remainder * 10 + ord(ch) - 48, 2)
        halves.append(str(half))
    return _strip_leading_zeros("".join(halves))
|
|
|
|
|
|
def _write_result(operation: str, operands: Iterable[StreamNumber], digits: str) -> StreamNumber:
    """Persist *digits* as a log file for *operation* and wrap it as a StreamNumber.

    The file is created in LOG_DIR and named from *operation* followed by each
    operand's ``hash`` attribute, all joined by "_", with a ".bin" suffix.
    It is opened in text mode "w" as UTF-8, so an existing file with the same
    name is overwritten. The path is passed to ``register_log_file`` before
    the StreamNumber is returned.
    """
    _ensure_log_dir()
    hashes = [operand.hash for operand in operands]
    operand_hash = "_".join(hashes)
    target = LOG_DIR / f"{operation}_{operand_hash}.bin"
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(digits)
    register_log_file(target)
    return StreamNumber(target)
|
|
|
|
|
|
def clear_logs() -> None:
    """Empty LOG_DIR, make sure it exists, then call ``wipe_log_records()``.

    Every entry matched by ``LOG_DIR.glob("*")`` is removed with ``unlink()``.
    NOTE(review): ``Path.unlink`` does not remove directories, so a
    subdirectory inside LOG_DIR would make this raise.
    """
    if LOG_DIR.exists():
        for stale_entry in LOG_DIR.glob("*"):
            stale_entry.unlink()
    _ensure_log_dir()
    wipe_log_records()
|
|
|
|
|
|
def add(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return a StreamNumber holding ``num_a + num_b``.

    Both operands are normalized into ``(sign, digits)`` pairs and combined
    using digit-string arithmetic rather than Python ints. With matching
    signs the magnitudes are added. With opposite signs the smaller
    magnitude is subtracted from the larger, and the result takes the sign
    of the larger one. The result is written via ``_write_result`` under the
    "add" operation name.
    """
    sign_a, a_digits = _normalize_stream(num_a)
    sign_b, b_digits = _normalize_stream(num_b)

    if sign_a == sign_b:
        sign, digits = sign_a, _add_abs(a_digits, b_digits)
    else:
        order = _compare_abs(a_digits, b_digits)
        if order > 0:
            sign, digits = sign_a, _sub_abs(a_digits, b_digits)
        elif order < 0:
            sign, digits = sign_b, _sub_abs(b_digits, a_digits)
        else:
            # Equal magnitudes with opposite signs cancel to zero.
            sign, digits = 1, "0"

    negative = sign < 0 and digits != "0"
    result = f"-{digits}" if negative else digits
    return _write_result("add", (num_a, num_b), result)
|
|
|
|
|
|
def sub(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return a StreamNumber holding ``num_a - num_b``.

    Works on normalized ``(sign, digits)`` pairs. With opposite signs the
    magnitudes add and the result keeps ``num_a``'s sign. With matching signs
    the smaller magnitude is subtracted from the larger. The result keeps
    ``num_a``'s sign when ``|a| > |b|``, takes the flipped sign when
    ``|a| < |b|``, and is zero when the magnitudes are equal. The result is
    written via ``_write_result`` under the "sub" operation name.
    """
    sign_a, a_digits = _normalize_stream(num_a)
    sign_b, b_digits = _normalize_stream(num_b)

    if sign_a != sign_b:
        sign, digits = sign_a, _add_abs(a_digits, b_digits)
    else:
        order = _compare_abs(a_digits, b_digits)
        if order == 0:
            sign, digits = 1, "0"
        elif order > 0:
            sign, digits = sign_a, _sub_abs(a_digits, b_digits)
        else:
            sign, digits = -sign_a, _sub_abs(b_digits, a_digits)

    negative = sign < 0 and digits != "0"
    result = f"-{digits}" if negative else digits
    return _write_result("sub", (num_a, num_b), result)
|
|
|
|
|
|
def mul(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return a StreamNumber holding ``num_a * num_b``.

    Magnitudes are multiplied with schoolbook digit-string multiplication.
    The product is negative only when it is non-zero and the operand signs
    differ. The result is written via ``_write_result`` under the "mul"
    operation name.
    """
    sign_a, a_digits = _normalize_stream(num_a)
    sign_b, b_digits = _normalize_stream(num_b)

    product = _multiply_abs(a_digits, b_digits)
    if product != "0" and sign_a != sign_b:
        product = f"-{product}"
    return _write_result("mul", (num_a, num_b), product)
|
|
|
|
|
|
def div(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return a StreamNumber holding the floor division ``num_a // num_b``.

    The magnitudes are long-divided. When the signs differ and the division
    is inexact, the truncated magnitude is bumped by one so the negative
    result rounds toward negative infinity, matching Python's ``//``. The
    result is written via ``_write_result`` under the "div" operation name.

    Raises:
        DivideByZeroError: propagated from ``_divide_abs`` when num_b is zero.
    """
    sign_a, a_digits = _normalize_stream(num_a)
    sign_b, b_digits = _normalize_stream(num_b)

    quotient, remainder = _divide_abs(a_digits, b_digits)
    if quotient == "0" and remainder == "0":
        return _write_result("div", (num_a, num_b), "0")

    opposite_signs = sign_a != sign_b
    if opposite_signs and remainder != "0":
        bumped = _add_abs(quotient, "1")
        result = f"-{bumped}"
    elif opposite_signs and quotient != "0":
        result = f"-{quotient}"
    else:
        result = quotient
    return _write_result("div", (num_a, num_b), result)
|
|
|
|
|
|
def mod(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return a StreamNumber holding ``num_a % num_b`` with Python's sign rules.

    A non-zero result takes the sign of ``num_b``. When the operand signs
    differ, its magnitude is ``|b| - (|a| mod |b|)``. The result is written
    via ``_write_result`` under the "mod" operation name.

    Raises:
        DivideByZeroError: when num_b is zero.
    """
    sign_a, a_digits = _normalize_stream(num_a)
    sign_b, b_digits = _normalize_stream(num_b)

    if b_digits == "0":
        raise DivideByZeroError("modulo by zero")

    _, remainder = _divide_abs(a_digits, b_digits)
    if remainder == "0":
        return _write_result("mod", (num_a, num_b), "0")

    magnitude = remainder if sign_a == sign_b else _sub_abs(b_digits, remainder)
    result = magnitude if sign_b > 0 else f"-{magnitude}"
    return _write_result("mod", (num_a, num_b), result)
|
|
|
|
|
|
def pow(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
    """Return a StreamNumber holding ``num_a ** num_b`` for a non-negative exponent.

    Uses binary exponentiation (repeated squaring) on digit strings. An
    exponent of zero yields "1", including for a zero base. The result is
    negative only when the base is negative, the exponent is odd and the
    magnitude is non-zero. The result is written via ``_write_result`` under
    the "pow" operation name.

    NOTE: this module-level name shadows the builtin ``pow`` within this module.

    Raises:
        ValueError: when the exponent is negative.
    """
    base_sign, base_digits = _normalize_stream(num_a)
    exp_sign, exp_digits = _normalize_stream(num_b)

    if exp_sign < 0:
        raise ValueError("Negative exponents are not supported for integer streams.")
    if exp_digits == "0":
        return _write_result("pow", (num_a, num_b), "1")

    accumulator = "1"
    square = base_digits
    remaining = exp_digits
    while True:
        if _is_odd(remaining):
            accumulator = _multiply_abs(accumulator, square)
        remaining = _halve(remaining)
        if _is_zero(remaining):
            break
        # Square only when another exponent bit remains to consume it.
        square = _multiply_abs(square, square)

    negate = base_sign < 0 and _is_odd(exp_digits) and accumulator != "0"
    result = f"-{accumulator}" if negate else accumulator
    return _write_result("pow", (num_a, num_b), result)
|
|
|
|
|
|
def is_even(num: StreamNumber) -> bool:
    """Return True when the integer in *num*'s stream is divisible by two.

    Only the last digit of the normalized magnitude is inspected; the sign
    does not affect parity.
    """
    magnitude = _normalize_stream(num)[1]
    last_digit_value = ord(magnitude[-1]) - 48
    return last_digit_value % 2 == 0
|
|
|
|
|
|
def is_odd(num: StreamNumber) -> bool:
    """Return True when the integer in *num*'s stream is not divisible by two."""
    if is_even(num):
        return False
    return True
|