made module better
This commit is contained in:
parent
76fe7fb668
commit
034fc2b8b6
@ -1,2 +1,2 @@
|
|||||||
from .engine import clear_logs, add, sub, mul, div
|
from .engine import clear_logs, add, sub, mul, div, pow
|
||||||
from .number import StreamNumber
|
from .number import StreamNumber
|
||||||
|
|||||||
@ -1,45 +1,291 @@
|
|||||||
from pathlib import Path
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Iterable, Tuple
|
||||||
|
|
||||||
from .number import StreamNumber, LOG_DIR
|
from .number import StreamNumber, LOG_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_log_dir() -> None:
|
||||||
|
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_leading_zeros(digits: str) -> str:
|
||||||
|
digits = digits.lstrip("0")
|
||||||
|
return digits or "0"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_stream(num: StreamNumber) -> Tuple[int, str]:
|
||||||
|
"""Return (sign, digits) tuple for the streamed number."""
|
||||||
|
parts: list[str] = []
|
||||||
|
for chunk in num.stream():
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
parts.append(chunk)
|
||||||
|
raw = "".join(parts)
|
||||||
|
|
||||||
|
if not raw:
|
||||||
|
raise ValueError(f"Stream for {num.path} is empty")
|
||||||
|
|
||||||
|
sign = 1
|
||||||
|
if raw[0] in "+-":
|
||||||
|
sign = -1 if raw[0] == "-" else 1
|
||||||
|
raw = raw[1:]
|
||||||
|
|
||||||
|
if not raw.isdigit():
|
||||||
|
raise ValueError(f"Non-digit characters found in stream for {num.path}")
|
||||||
|
|
||||||
|
digits = _strip_leading_zeros(raw)
|
||||||
|
if digits == "0":
|
||||||
|
sign = 1
|
||||||
|
return sign, digits
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_abs(a: str, b: str) -> int:
|
||||||
|
"""Compare two positive digit strings."""
|
||||||
|
if len(a) != len(b):
|
||||||
|
return 1 if len(a) > len(b) else -1
|
||||||
|
if a == b:
|
||||||
|
return 0
|
||||||
|
return 1 if a > b else -1
|
||||||
|
|
||||||
|
|
||||||
|
def _add_abs(a: str, b: str) -> str:
|
||||||
|
carry = 0
|
||||||
|
idx_a = len(a) - 1
|
||||||
|
idx_b = len(b) - 1
|
||||||
|
out: list[str] = []
|
||||||
|
|
||||||
|
while idx_a >= 0 or idx_b >= 0 or carry:
|
||||||
|
da = ord(a[idx_a]) - 48 if idx_a >= 0 else 0
|
||||||
|
db = ord(b[idx_b]) - 48 if idx_b >= 0 else 0
|
||||||
|
total = da + db + carry
|
||||||
|
carry, digit = divmod(total, 10)
|
||||||
|
out.append(str(digit))
|
||||||
|
idx_a -= 1
|
||||||
|
idx_b -= 1
|
||||||
|
|
||||||
|
return "".join(reversed(out))
|
||||||
|
|
||||||
|
|
||||||
|
def _sub_abs(a: str, b: str) -> str:
|
||||||
|
"""Return a - b for digit strings assuming a >= b."""
|
||||||
|
borrow = 0
|
||||||
|
idx_a = len(a) - 1
|
||||||
|
idx_b = len(b) - 1
|
||||||
|
out: list[str] = []
|
||||||
|
|
||||||
|
while idx_a >= 0:
|
||||||
|
da = ord(a[idx_a]) - 48
|
||||||
|
db = ord(b[idx_b]) - 48 if idx_b >= 0 else 0
|
||||||
|
diff = da - borrow - db
|
||||||
|
if diff < 0:
|
||||||
|
diff += 10
|
||||||
|
borrow = 1
|
||||||
|
else:
|
||||||
|
borrow = 0
|
||||||
|
out.append(str(diff))
|
||||||
|
idx_a -= 1
|
||||||
|
idx_b -= 1
|
||||||
|
|
||||||
|
return _strip_leading_zeros("".join(reversed(out)))
|
||||||
|
|
||||||
|
|
||||||
|
def _multiply_abs(a: str, b: str) -> str:
|
||||||
|
if a == "0" or b == "0":
|
||||||
|
return "0"
|
||||||
|
result = [0] * (len(a) + len(b))
|
||||||
|
for i in range(len(a) - 1, -1, -1):
|
||||||
|
ai = ord(a[i]) - 48
|
||||||
|
carry = 0
|
||||||
|
for j in range(len(b) - 1, -1, -1):
|
||||||
|
bj = ord(b[j]) - 48
|
||||||
|
pos = i + j + 1
|
||||||
|
total = result[pos] + ai * bj + carry
|
||||||
|
carry, result[pos] = divmod(total, 10)
|
||||||
|
result[i] += carry
|
||||||
|
return _strip_leading_zeros("".join(str(d) for d in result))
|
||||||
|
|
||||||
|
|
||||||
|
def _multiply_digit(num: str, digit: int) -> str:
|
||||||
|
if digit == 0 or num == "0":
|
||||||
|
return "0"
|
||||||
|
carry = 0
|
||||||
|
out: list[str] = []
|
||||||
|
for i in range(len(num) - 1, -1, -1):
|
||||||
|
total = (ord(num[i]) - 48) * digit + carry
|
||||||
|
carry, d = divmod(total, 10)
|
||||||
|
out.append(str(d))
|
||||||
|
if carry:
|
||||||
|
out.append(str(carry))
|
||||||
|
return "".join(reversed(out))
|
||||||
|
|
||||||
|
|
||||||
|
def _divide_abs(dividend: str, divisor: str) -> Tuple[str, str]:
|
||||||
|
if divisor == "0":
|
||||||
|
raise ZeroDivisionError("division by zero")
|
||||||
|
if dividend == "0":
|
||||||
|
return "0", "0"
|
||||||
|
|
||||||
|
quotient_digits: list[str] = []
|
||||||
|
remainder = "0"
|
||||||
|
|
||||||
|
for digit in dividend:
|
||||||
|
remainder = _strip_leading_zeros(remainder + digit)
|
||||||
|
q_digit = 0
|
||||||
|
for guess in range(9, -1, -1):
|
||||||
|
candidate = _multiply_digit(divisor, guess)
|
||||||
|
if _compare_abs(candidate, remainder) <= 0:
|
||||||
|
q_digit = guess
|
||||||
|
remainder = _sub_abs(remainder, candidate) if guess else remainder
|
||||||
|
break
|
||||||
|
quotient_digits.append(str(q_digit))
|
||||||
|
|
||||||
|
quotient = _strip_leading_zeros("".join(quotient_digits))
|
||||||
|
remainder = _strip_leading_zeros(remainder)
|
||||||
|
return quotient, remainder
|
||||||
|
|
||||||
|
|
||||||
|
def _is_zero(digits: str) -> bool:
|
||||||
|
return digits == "0"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_odd(digits: str) -> bool:
|
||||||
|
return (ord(digits[-1]) - 48) % 2 == 1
|
||||||
|
|
||||||
|
|
||||||
|
def _halve(digits: str) -> str:
|
||||||
|
carry = 0
|
||||||
|
out: list[str] = []
|
||||||
|
for ch in digits:
|
||||||
|
current = carry * 10 + (ord(ch) - 48)
|
||||||
|
quotient = current // 2
|
||||||
|
carry = current % 2
|
||||||
|
out.append(str(quotient))
|
||||||
|
return _strip_leading_zeros("".join(out))
|
||||||
|
|
||||||
|
|
||||||
|
def _write_result(operation: str, operands: Iterable[StreamNumber], digits: str) -> StreamNumber:
|
||||||
|
_ensure_log_dir()
|
||||||
|
operand_hash = "_".join(num.hash for num in operands)
|
||||||
|
out_file = LOG_DIR / f"{operation}_{operand_hash}.bin"
|
||||||
|
with open(out_file, "w", encoding="utf-8") as out:
|
||||||
|
out.write(digits)
|
||||||
|
return StreamNumber(out_file)
|
||||||
|
|
||||||
|
|
||||||
def clear_logs():
|
def clear_logs():
|
||||||
if LOG_DIR.exists():
|
if LOG_DIR.exists():
|
||||||
for p in LOG_DIR.glob("*"):
|
for p in LOG_DIR.glob("*"):
|
||||||
p.unlink()
|
p.unlink()
|
||||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
_ensure_log_dir()
|
||||||
|
|
||||||
|
|
||||||
def add(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
|
def add(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
|
||||||
"""Digit-by-digit streamed addition."""
|
"""Return num_a + num_b without loading full ints into memory."""
|
||||||
out_file = LOG_DIR / f"{num_a.hash}_add_{num_b.hash}.bin"
|
sign_a, a_digits = _normalize_stream(num_a)
|
||||||
|
sign_b, b_digits = _normalize_stream(num_b)
|
||||||
|
|
||||||
carry = 0
|
if sign_a == sign_b:
|
||||||
a_buf = list(num_a.stream(1))
|
digits = _add_abs(a_digits, b_digits)
|
||||||
b_buf = list(num_b.stream(1))
|
sign = sign_a
|
||||||
|
else:
|
||||||
|
cmp = _compare_abs(a_digits, b_digits)
|
||||||
|
if cmp == 0:
|
||||||
|
digits = "0"
|
||||||
|
sign = 1
|
||||||
|
elif cmp > 0:
|
||||||
|
digits = _sub_abs(a_digits, b_digits)
|
||||||
|
sign = sign_a
|
||||||
|
else:
|
||||||
|
digits = _sub_abs(b_digits, a_digits)
|
||||||
|
sign = sign_b
|
||||||
|
|
||||||
# align lengths
|
result = digits if sign > 0 or digits == "0" else f"-{digits}"
|
||||||
max_len = max(len(a_buf), len(b_buf))
|
return _write_result("add", (num_a, num_b), result)
|
||||||
a_buf = ["0"] * (max_len - len(a_buf)) + a_buf
|
|
||||||
b_buf = ["0"] * (max_len - len(b_buf)) + b_buf
|
|
||||||
|
|
||||||
with open(out_file, "wb") as out:
|
|
||||||
for i in range(max_len - 1, -1, -1):
|
|
||||||
s = int(a_buf[i]) + int(b_buf[i]) + carry
|
|
||||||
carry, digit = divmod(s, 10)
|
|
||||||
out.write(str(digit).encode())
|
|
||||||
if carry:
|
|
||||||
out.write(str(carry).encode())
|
|
||||||
return StreamNumber(out_file)
|
|
||||||
|
|
||||||
def sub(num_a, num_b):
|
def sub(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
|
||||||
"""Basic streamed subtraction (assumes a >= b)."""
|
"""Return num_a - num_b using streamed integer arithmetic."""
|
||||||
# similar pattern with borrow propagation...
|
sign_a, a_digits = _normalize_stream(num_a)
|
||||||
pass
|
sign_b, b_digits = _normalize_stream(num_b)
|
||||||
|
|
||||||
def mul(num_a, num_b):
|
if sign_a != sign_b:
|
||||||
"""Chunked multiplication using repeated addition."""
|
digits = _add_abs(a_digits, b_digits)
|
||||||
# create temporary stage files for partial sums
|
sign = sign_a
|
||||||
pass
|
else:
|
||||||
|
cmp = _compare_abs(a_digits, b_digits)
|
||||||
|
if cmp == 0:
|
||||||
|
digits = "0"
|
||||||
|
sign = 1
|
||||||
|
elif cmp > 0:
|
||||||
|
digits = _sub_abs(a_digits, b_digits)
|
||||||
|
sign = sign_a
|
||||||
|
else:
|
||||||
|
digits = _sub_abs(b_digits, a_digits)
|
||||||
|
sign = -sign_a
|
||||||
|
|
||||||
def div(num_a, num_b):
|
result = digits if sign > 0 or digits == "0" else f"-{digits}"
|
||||||
"""Long division, streamed stage by stage."""
|
return _write_result("sub", (num_a, num_b), result)
|
||||||
# create multiple intermediate files: div_stage_1, div_stage_2, etc.
|
|
||||||
pass
|
|
||||||
|
def mul(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
|
||||||
|
"""Return num_a * num_b with grade-school multiplication."""
|
||||||
|
sign_a, a_digits = _normalize_stream(num_a)
|
||||||
|
sign_b, b_digits = _normalize_stream(num_b)
|
||||||
|
|
||||||
|
digits = _multiply_abs(a_digits, b_digits)
|
||||||
|
sign = 1 if digits == "0" else sign_a * sign_b
|
||||||
|
result = digits if sign > 0 else f"-{digits}"
|
||||||
|
return _write_result("mul", (num_a, num_b), result)
|
||||||
|
|
||||||
|
|
||||||
|
def div(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
|
||||||
|
"""Return floor division num_a // num_b with streamed long division."""
|
||||||
|
sign_a, a_digits = _normalize_stream(num_a)
|
||||||
|
sign_b, b_digits = _normalize_stream(num_b)
|
||||||
|
|
||||||
|
quotient, remainder = _divide_abs(a_digits, b_digits)
|
||||||
|
|
||||||
|
if quotient == "0" and remainder == "0":
|
||||||
|
return _write_result("div", (num_a, num_b), "0")
|
||||||
|
|
||||||
|
sign_product = sign_a * sign_b
|
||||||
|
if sign_product < 0 and remainder != "0":
|
||||||
|
quotient = _add_abs(quotient, "1")
|
||||||
|
sign = -1
|
||||||
|
else:
|
||||||
|
sign = sign_product if quotient != "0" else 1
|
||||||
|
|
||||||
|
result = quotient if sign > 0 else f"-{quotient}"
|
||||||
|
return _write_result("div", (num_a, num_b), result)
|
||||||
|
|
||||||
|
|
||||||
|
def pow(num_a: StreamNumber, num_b: StreamNumber) -> StreamNumber:
|
||||||
|
"""Return num_a ** num_b using repeated squaring (integer exponent only)."""
|
||||||
|
base_sign, base_digits = _normalize_stream(num_a)
|
||||||
|
exp_sign, exp_digits = _normalize_stream(num_b)
|
||||||
|
|
||||||
|
if exp_sign < 0:
|
||||||
|
raise ValueError("Negative exponents are not supported for integer streams.")
|
||||||
|
|
||||||
|
if exp_digits == "0":
|
||||||
|
return _write_result("pow", (num_a, num_b), "1")
|
||||||
|
|
||||||
|
result_digits = "1"
|
||||||
|
base_abs = base_digits
|
||||||
|
exponent = exp_digits
|
||||||
|
|
||||||
|
while not _is_zero(exponent):
|
||||||
|
if _is_odd(exponent):
|
||||||
|
result_digits = _multiply_abs(result_digits, base_abs)
|
||||||
|
exponent = _halve(exponent)
|
||||||
|
if not _is_zero(exponent):
|
||||||
|
base_abs = _multiply_abs(base_abs, base_abs)
|
||||||
|
|
||||||
|
base_negative = base_sign < 0
|
||||||
|
result_sign = -1 if base_negative and _is_odd(exp_digits) else 1
|
||||||
|
if result_digits == "0":
|
||||||
|
result_sign = 1
|
||||||
|
result = result_digits if result_sign > 0 else f"-{result_digits}"
|
||||||
|
return _write_result("pow", (num_a, num_b), result)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user