225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
import inspect
|
|
import typing as t
|
|
import random
|
|
from functools import wraps
|
|
from collections import deque
|
|
|
|
# ``typing.get_origin`` / ``typing.get_args`` exist from Python 3.8 on; older
# interpreters get minimal stand-ins that read the alias's dunder attributes.
try:
    from typing import get_args, get_origin
except ImportError:
    def get_origin(tp):
        """Return ``tp.__origin__`` when present, otherwise ``None``."""
        return getattr(tp, "__origin__", None)

    def get_args(tp):
        """Return ``tp.__args__`` when present, otherwise an empty tuple."""
        return getattr(tp, "__args__", ())
|
|
|
|
# Module-level alias for typing.Any; is_any() compares against it by identity.
Any = t.Any


def is_any(tp):
    """Report whether *tp* is the very ``typing.Any`` object (identity test)."""
    return Any is tp
|
|
|
|
def match_union(tp, value_type):
    """Return True if *tp* is a ``typing.Union`` with a member accepting *value_type*.

    Each member is checked with ``type_matches``; any *tp* whose origin is
    not ``typing.Union`` yields False without further checks.
    """
    if get_origin(tp) is not t.Union:
        return False
    for member in get_args(tp):
        if type_matches(member, value_type):
            return True
    return False
|
|
|
|
# Origins that mark a union: typing.Union, plus types.UnionType for PEP 604
# ``X | Y`` unions on Python 3.10+ (where ``int | str`` is valid).
try:
    _UNION_ORIGINS = (t.Union, type(int | str))
except TypeError:  # Python < 3.10: ``type | type`` is unsupported.
    _UNION_ORIGINS = (t.Union,)


def type_matches(expected, actual):
    """Return True if a value whose runtime class is *actual* satisfies *expected*.

    *expected* is a parameter/return annotation:

    * a missing annotation or ``typing.Any`` accepts everything;
    * ``None`` (PEP 484 shorthand) accepts only ``NoneType``;
    * ``Union[...]`` / ``Optional[...]`` / ``X | Y`` accept *actual* if any
      member does (checked recursively);
    * a plain class is checked with ``issubclass``;
    * a parameterized generic such as ``List[int]`` or ``dict[str, int]`` is
      checked against its runtime origin class (``list``, ``dict``); element
      types are not inspected.

    Annotations that cannot be checked at runtime (``Literal[...]``,
    ``TypeVar``, string forward references, ...) are accepted permissively.
    """
    if expected is inspect.Parameter.empty or expected is t.Any:
        return True
    if expected is None:
        return actual is type(None)

    origin = get_origin(expected)
    if origin in _UNION_ORIGINS:
        # Optional[X] is Union[X, None], so this also covers Optional.
        return any(type_matches(option, actual) for option in get_args(expected))

    try:
        return issubclass(actual, expected)
    except TypeError:
        pass

    # issubclass() refuses subscripted generics; fall back to the bare
    # runtime class they alias (List[int] -> list, Type[int] -> type).
    if isinstance(origin, type):
        try:
            return issubclass(actual, origin)
        except TypeError:
            pass

    # Not runtime-checkable: stay permissive, as the original did.
    return True
|
|
|
|
def exact_match(expected, actual):
    """Tell whether *actual* is literally the plain class named by *expected*.

    A missing annotation, ``typing.Any``, or any subscripted/special form
    (anything ``get_origin`` recognises) never counts as an exact match.
    """
    if expected is inspect.Parameter.empty or expected is t.Any:
        return False
    if get_origin(expected) is not None:
        return False
    return actual is expected
|
|
|
|
def subclass_depth(expected, actual):
    """Return how many MRO steps separate class *actual* from *expected*.

    0 means *actual* is *expected*, 1 a direct subclass, and so on. The
    sentinel 9999 is returned when *expected* is not in *actual*'s MRO or
    *actual* is not a class at all.
    """
    # Read ``__mro__`` instead of calling ``.mro()``: on ``type`` itself
    # ``type.mro()`` is an unbound descriptor call and raises TypeError,
    # and ``.mro()`` also rebuilds a fresh list on every call.
    mro = getattr(actual, "__mro__", None)
    if not isinstance(mro, tuple):
        return 9999
    for depth, klass in enumerate(mro):
        if klass == expected:
            return depth
    return 9999
|
|
|
|
# Case-insensitive string spellings understood by _coerce_bool.
_FALSE_WORDS = frozenset({"false", "no", "off"})
_TRUE_WORDS = frozenset({"true", "yes", "on"})


def _coerce_bool(value):
    """Convert *value* to ``bool``, honouring common textual spellings.

    Digit-only strings go through ``int`` ("0" -> False, "2" -> True).
    "false"/"no"/"off" give False and "true"/"yes"/"on" give True; case and
    surrounding whitespace are ignored. Anything else falls back to plain
    truthiness (so e.g. "abc" -> True and "" -> False).
    """
    if isinstance(value, str):
        if value.isdigit():
            return bool(int(value))
        word = value.strip().lower()
        # bool("false") is True, which is never what a caller means.
        if word in _FALSE_WORDS:
            return False
        if word in _TRUE_WORDS:
            return True
    return bool(value)


class Coercions:
    """Registry of fallback converters for arguments whose type does not match.

    ``table`` maps a target class to a tuple of converters tried in order;
    the first result that is an instance of the target is used.
    """

    table: dict[type, t.Tuple[t.Callable[[t.Any], t.Any], ...]] = {
        int: (int,),
        float: (float,),
        str: (str,),
        bool: (_coerce_bool,),
    }

    @classmethod
    def can_coerce(cls, target: t.Type, value):
        """Try to convert *value* to *target*.

        Returns ``(True, converted)`` on success, or ``(False, None)`` when
        *target* has no registered converters or every converter fails.
        """
        for convert in cls.table.get(target, ()):
            try:
                coerced = convert(value)
            except Exception:
                # Best effort by design: a raising converter simply means
                # "not coercible with this converter".
                continue
            if isinstance(coerced, target):
                return True, coerced
        return False, None
|
|
|
|
class TinyLRU:
    """Minimal bounded least-recently-used cache.

    Entries live in a plain dict (insertion-ordered since Python 3.7) kept
    in recency order: a hit re-inserts the key at the most-recent end, and
    inserting beyond ``maxsize`` evicts from the least-recent end.
    """

    def __init__(self, maxsize=128):
        self.maxsize = maxsize
        self.d = {}  # key -> value, least recently used first

    def get(self, key):
        """Return the value cached under *key* (marking it most recently used), or None."""
        try:
            value = self.d.pop(key)
        except KeyError:
            return None
        self.d[key] = value
        return value

    def put(self, key, value):
        """Store *value* under *key* as most recently used, evicting LRU entries if full."""
        # Pop first so an update both replaces the value and refreshes recency.
        self.d.pop(key, None)
        self.d[key] = value
        while self.d and len(self.d) > self.maxsize:
            del self.d[next(iter(self.d))]
|
|
|
|
class Dispatcher:
    """Method descriptor that routes each call to the best-matching overload.

    Implementations are added with :meth:`register`. On every call each
    overload is scored against the arguments the caller actually supplied
    and the highest-scoring one is invoked:

    * exact type match: +30 per argument;
    * other accepted match (subclass, generic, union): +5 .. +15, closer in
      the MRO scoring higher;
    * unannotated or ``Any`` parameter: +0;
    * value converted through ``Coercions``: +8 (+6 via a ``Union`` member);
    * -1 for every parameter that declares a default;
    * +1000 per unit of the overload's ``priority``.

    Parameters left at their defaults are neither type-checked nor coerced.
    ``*args`` / ``**kwargs`` annotations are checked per element (no
    coercion, no score). Ties go to the overload registered first. A caller
    may pass ``__expect__=SomeType``: overloads whose return annotation is
    incompatible are dropped, compatible ones get +5 and unannotated ones -1.
    """

    # Sentinels used by _score_argument in place of a converted value.
    _UNCHANGED = object()  # argument is passed through as-is
    _FAILED = object()     # a coercion was attempted and failed for this value

    def __init__(self, name):
        self.name = name
        self.overloads: list[dict] = []
        self._cache = TinyLRU(256)

    def register(self, func, *, priority=0):
        """Add *func* as an overload; *priority* adds ``priority * 1000`` to its score."""
        sig = inspect.signature(func)
        entry = {"sig": sig, "func": func, "priority": int(priority)}
        self.overloads.append(entry)
        # Resolutions cached so far were made without this overload.
        self._cache = TinyLRU(256)

    def __get__(self, instance, owner):
        @wraps(self)
        def bound(*args, **kwargs):
            return self._dispatch(instance, owner, *args, **kwargs)

        return bound

    def _score_argument(self, ann, value, allow_coercion):
        """Score one supplied argument against its parameter annotation.

        Returns None when the argument can never be accepted (for any value
        of its type). Otherwise returns ``(points, coerced)``: *coerced* is
        ``_UNCHANGED`` for a plain type match, the converted value after a
        successful coercion, or ``_FAILED`` when a coercion was possible in
        principle but failed for this particular value (*points* is then
        what a successful coercion would have scored).
        """
        actual_t = type(value)
        if exact_match(ann, actual_t):
            return 30, self._UNCHANGED

        origin = get_origin(ann)
        if (
            ann is inspect.Parameter.empty
            or is_any(ann)
            or (origin is t.Union and any(is_any(a) for a in get_args(ann)))
        ):
            return 0, self._UNCHANGED

        if type_matches(ann, actual_t):
            dist = subclass_depth(ann, actual_t)
            return max(15 - min(dist, 10), 5), self._UNCHANGED

        if not allow_coercion:
            return None

        if isinstance(ann, type):
            targets, points = (ann,), 8
        elif origin is t.Union:
            targets = tuple(
                opt for opt in get_args(ann)
                if opt is not type(None) and isinstance(opt, type)
            )
            points = 6
        else:
            return None

        # A target without registered converters can never succeed, so that
        # verdict does not depend on the value.
        targets = tuple(opt for opt in targets if opt in Coercions.table)
        if not targets:
            return None
        for target in targets:
            ok, coerced = Coercions.can_coerce(target, value)
            if ok:
                return points, coerced
        return points, self._FAILED

    def _evaluate(self, entry, instance, args, kwargs, expect_type, allow_coercion=True):
        """Score *entry* for one call.

        Returns ``(scored, ceiling)``. *scored* is ``(score, coercions)``
        when the overload accepts the call, otherwise None. *ceiling* is
        None when the verdict depends only on the argument types; when a
        coercion was attempted it is the score the overload would reach if
        every coercion succeeded, i.e. the most it can score for any values
        of these same types.
        """
        sig: inspect.Signature = entry["sig"]
        try:
            # No apply_defaults(): only caller-supplied arguments are scored,
            # so defaults (e.g. ``x: int = None``) are never checked/coerced.
            bound = sig.bind(instance, *args, **kwargs)
        except TypeError:
            return None, None

        score = -sum(
            1 for p in sig.parameters.values()
            if p.default is not inspect.Parameter.empty
        )
        coercions = {}
        probed = failed = False
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        for name, value in bound.arguments.items():
            if name == "self":
                continue
            param = sig.parameters[name]
            ann = param.annotation

            if param.kind in variadic:
                # PEP 484: ``*args: T`` / ``**kw: T`` annotate each element.
                if param.kind is inspect.Parameter.VAR_KEYWORD:
                    items = value.values()
                else:
                    items = value
                if not all(type_matches(ann, type(item)) for item in items):
                    return None, None
                continue

            outcome = self._score_argument(ann, value, allow_coercion)
            if outcome is None:
                return None, None
            points, coerced = outcome
            score += points
            if coerced is self._UNCHANGED:
                continue
            probed = True
            if coerced is self._FAILED:
                failed = True
            else:
                coercions[name] = coerced

        if expect_type is not None:
            ret_ann = sig.return_annotation
            if ret_ann is inspect.Signature.empty or is_any(ret_ann):
                score -= 1
            else:
                if get_origin(ret_ann) is t.Union:
                    ok = any(type_matches(opt, expect_type) for opt in get_args(ret_ann))
                else:
                    try:
                        ok = issubclass(expect_type, ret_ann) or issubclass(ret_ann, expect_type)
                    except TypeError:
                        ok = True
                if not ok:
                    return None, None
                score += 5

        score += entry["priority"] * 1000
        if failed:
            return None, score
        return (score, coercions), (score if probed else None)

    def _score_entry(self, entry, instance, args, kwargs, expect_type, allow_coercion=True):
        """Return ``(score, coercions)`` if *entry* accepts the call, else None."""
        scored, _ceiling = self._evaluate(
            entry, instance, args, kwargs, expect_type, allow_coercion
        )
        return scored

    @staticmethod
    def _invoke(entry, instance, args, kwargs, coercions):
        """Bind the call to *entry*, fill in defaults, substitute *coercions*, and call it."""
        bound = entry["sig"].bind(instance, *args, **kwargs)
        bound.apply_defaults()
        bound.arguments.update(coercions)
        return entry["func"](*bound.args, **bound.kwargs)

    def _dispatch(self, instance, owner, *args, **kwargs):
        """Resolve and invoke the best overload for this call.

        ``__expect__`` is removed from *kwargs* and used as the desired
        return type. Raises TypeError when no overload accepts the call.
        """
        expect_type = kwargs.pop("__expect__", None)

        # Everything a type-only resolution can depend on. Keyword names are
        # unique, so sorting the pairs never compares the types themselves.
        key = (
            type(instance),
            tuple(type(a) for a in args),
            tuple(sorted((name, type(value)) for name, value in kwargs.items())),
            expect_type,
        )
        cached = self._cache.get(key)
        if cached is not None:
            # Only coercion-free, value-independent resolutions are cached.
            return self._invoke(cached, instance, args, kwargs, {})

        evaluations = []
        best_index = best_score = None
        for index, entry in enumerate(self.overloads):
            scored, ceiling = self._evaluate(entry, instance, args, kwargs, expect_type)
            evaluations.append((scored, ceiling))
            # Strict ">" keeps the earliest-registered overload on ties,
            # making dispatch deterministic.
            if scored is not None and (best_index is None or scored[0] > best_score):
                best_index, best_score = index, scored[0]

        if best_index is None:
            raise TypeError(f"No matching overload for {self.name}{args}")

        entry = self.overloads[best_index]
        (_, coercions), winner_ceiling = evaluations[best_index]

        # Cache only when other values of these same types could not change
        # the outcome: the winner needed no coercion, and no value-dependent
        # rival could reach the winning score.
        cacheable = winner_ceiling is None and all(
            ceiling is None or ceiling < best_score
            for index, (_, ceiling) in enumerate(evaluations)
            if index != best_index
        )
        if cacheable:
            self._cache.put(key, entry)
        return self._invoke(entry, instance, args, kwargs, coercions)
|
|
|
|
def overload(dispatcher: Dispatcher, *, priority: int = 0):
    """Return a decorator that registers the decorated function on *dispatcher*.

    The decorator returns *dispatcher* itself rather than the function, so
    the decorated name is rebound to the dispatcher.
    """
    def _register(func):
        dispatcher.register(func, priority=priority)
        return dispatcher

    return _register
|