334 lines
9.1 KiB
Lua
334 lines
9.1 KiB
Lua
--- Module table; the library's public functions are attached to it below.
local oop = {
  --- Access levels accepted wherever a visibility argument is taken.
  Visibility = {
    PUBLIC = "public",
    PROTECTED = "protected",
    PRIVATE = "private",
  },
}
|
|
|
|
--- Pack varargs into a table that records the argument count in `n`.
-- Uses `table.pack` when the runtime provides it, otherwise a local
-- equivalent built on `select("#", ...)`.
local tpack = table.pack
if not tpack then
  tpack = function(...)
    return { n = select("#", ...), ... }
  end
end

--- Unpack a table into multiple values (`table.unpack`, or the global
-- `unpack` on runtimes that only have the latter).
local tunpack = table.unpack or unpack
|
|
|
|
-- Forward declarations: these lookup helpers are assigned further down but
-- are referenced by functions defined before them.
local lookup_in_class, lookup_with_owner, lookup_visibility
|
|
|
|
--- Call a member of the base class of `self`'s class, passing `self`.
-- The class is `getmetatable(self)`; its base is that class's raw `__base`
-- field; the member is read from the base with a normal index.
-- @param self        value whose metatable is treated as its class
-- @param method_name key to look up on the base class
-- @param ...         extra arguments forwarded after `self`
-- @return whatever the base member returns
-- Raises when `self` has no metatable, the class has no `__base`, or the
-- base has no truthy value under `method_name`.
local function super_call(self, method_name, ...)
  local class = getmetatable(self)
    or error("super() called on non-instance")
  local base = rawget(class, "__base")
    or error("super() called but no base class")
  local fn = base[method_name]
    or error("base class has no method '" .. tostring(method_name) .. "'")
  return fn(self, ...)
end
|
|
|
|
-- Maps a class-definition function to the class built from it, so the same
-- function always resolves to the same class. Keys are weak (`__mode = "k"`).
local function_def_cache = {}
setmetatable(function_def_cache, { __mode = "k" })

-- Stack of owning classes for the method calls currently in progress; the
-- top entry is what visibility checks treat as the caller.
oop._call_stack = {}
|
|
|
|
--- Validate a visibility argument and apply the default.
-- @param visibility one of the `oop.Visibility` values, or nil
-- @return `oop.Visibility.PUBLIC` for nil, otherwise `visibility` unchanged
-- Raises when `visibility` is neither nil nor one of the three levels.
local function normalize_visibility(visibility)
  local levels = oop.Visibility
  if visibility == nil then
    return levels.PUBLIC
  end

  local recognised = visibility == levels.PUBLIC
    or visibility == levels.PROTECTED
    or visibility == levels.PRIVATE
  if not recognised then
    error("visibility must be Visibility.PUBLIC, Visibility.PROTECTED, or Visibility.PRIVATE")
  end
  return visibility
end
|
|
|
|
--- Tell whether `target` is a class table.
-- A class is recognised purely by owning a raw (non-inherited)
-- `__visibility` field.
-- @return boolean
local function is_class_table(target)
  if type(target) ~= "table" then
    return false
  end
  return rawget(target, "__visibility") ~= nil
end
|
|
|
|
--- Run `fn(self, ...)` with `owner` recorded as the current caller.
-- `owner` is pushed on `oop._call_stack` for the duration of the call and
-- removed afterwards, whether `fn` returns or raises.
-- @param owner class to expose as the caller while `fn` runs
-- @param fn    function to invoke
-- @param self  first argument passed to `fn`
-- @param ...   remaining arguments passed to `fn`
-- @return every value returned by `fn` (trailing nils included)
-- Re-raises any error from `fn` with level 0, so the original error value
-- is propagated without additional position information.
local function call_with_context(owner, fn, self, ...)
  local stack = oop._call_stack

  -- Storing nil in a table adds nothing, so with a nil owner there is no
  -- frame to remove afterwards; popping anyway would discard the frame of
  -- whoever called us.
  local pushed = owner ~= nil
  if pushed then
    stack[#stack + 1] = owner
  end

  -- pcall lets the stack be unwound before an error continues upwards.
  local results = tpack(pcall(fn, self, ...))

  if pushed then
    stack[#stack] = nil
  end

  if not results[1] then
    error(results[2], 0)
  end
  return tunpack(results, 2, results.n)
end
|
|
|
|
--- Return the class on top of `oop._call_stack`, or nil when the stack is
-- empty (no tracked method call is in progress).
local function current_caller()
  local stack = oop._call_stack
  return stack[#stack]
end
|
|
|
|
--- Decide whether `caller` may touch a member declared by `owner`.
-- @param owner      class that declared the member (may be nil)
-- @param visibility one of the `oop.Visibility` values
-- @param caller     class whose code is running (may be nil)
-- @return true for public members; for private members only when caller
--   and owner are the same class; for protected members the result of
--   `oop.issubclass(caller, owner)`; false in every other case, including
--   a missing caller or owner.
local function can_access(owner, visibility, caller)
  local levels = oop.Visibility
  if visibility == levels.PUBLIC then
    return true
  end

  -- Non-public access needs both sides to be known.
  if not (caller and owner) then
    return false
  end

  if visibility == levels.PRIVATE then
    return caller == owner
  elseif visibility == levels.PROTECTED then
    return oop.issubclass(caller, owner)
  end
  return false
end
|
|
|
|
--- Produce a printable name for a class.
-- @param cls class table, or any other value
-- @return the table's raw `__name` field, "Anonymous" when that field is
--   missing or falsy, or `tostring(cls)` when `cls` is not a table
local function class_name(cls)
  if type(cls) == "table" then
    return rawget(cls, "__name") or "Anonymous"
  end
  return tostring(cls)
end
|
|
|
|
--- Raise an "access violation" error for a denied member access.
-- @param kind       short noun placed in the message (callers pass
--                   "field" or "member")
-- @param member     key that was accessed
-- @param visibility declared visibility of the member
-- @param owner      declaring class, or nil (reported as "unknown")
-- @param caller     accessing class, or nil (reported as "outside")
-- Never returns. The error is raised with level 3, i.e. attributed two
-- call frames above this function.
local function visibility_error(kind, member, visibility, owner, caller)
  local from = caller and class_name(caller) or "outside"
  local where = owner and class_name(owner) or "unknown"

  local subject = kind .. " '" .. tostring(member) .. "'"
  local location = where .. " (from " .. from .. ")"
  error("access violation: " .. subject .. " is " .. visibility .. " in " .. location, 3)
end
|
|
|
|
--- Assign `target[key] = value`, recording a visibility for class members.
-- When `target` is a class table, `visibility` (default PUBLIC) is stored in
-- the class's `__visibility` table under `key`. For any other target the
-- assignment is a plain `target[key] = value` and `visibility` is ignored.
-- @param target     class table or any other indexable value
-- @param key        member name
-- @param value      value to store
-- @param visibility optional `oop.Visibility` value
-- Raises when `target` is a class and `visibility` is not a valid level; in
-- that case nothing has been modified.
function oop.setattr(target, key, value, visibility)
  -- Validate before touching the class so that a bad visibility cannot
  -- leave the member assigned without a matching visibility entry.
  local vis
  if is_class_table(target) then
    vis = normalize_visibility(visibility)
  end

  target[key] = value

  -- Re-checked after the assignment because the assignment itself may have
  -- added or removed the `__visibility` field that marks a class.
  if is_class_table(target) then
    target.__visibility[key] = vis or normalize_visibility(visibility)
  end
end
|
|
|
|
--- Instantiate a class.
-- @param class class table, or a class-definition function (resolved to a
--              class through `function_def_cache`, building and caching it
--              on first use)
-- @param ...   arguments forwarded to the class's `__init`, if it has one
-- @return the new instance: a table whose metatable is the class and whose
--   raw `__fields` table holds its per-instance values
function oop.new(class, ...)
  local resolved = class
  if type(class) == "function" then
    local cached = function_def_cache[class]
    if cached then
      resolved = cached
    else
      resolved = oop.class(class)
      function_def_cache[class] = resolved
    end
  end

  local instance = setmetatable({ __fields = {} }, resolved)

  -- The initializer runs with the class that defines it as call context.
  local initializer, defining_class = lookup_with_owner(resolved, "__init")
  if initializer then
    call_with_context(defining_class, initializer, instance, ...)
  end
  return instance
end
|
|
|
|
--- Find a member by walking a class and its `__base` chain.
-- Only raw fields are inspected at each level.
-- @param cls class to start from (nil yields nil)
-- @param key member name
-- @return the first non-nil raw value found, or nil
function lookup_in_class(cls, key)
  if not cls then
    return nil
  end
  local hit = rawget(cls, key)
  if hit == nil then
    -- Tail call: continue with the base class.
    return lookup_in_class(rawget(cls, "__base"), key)
  end
  return hit
end
|
|
|
|
--- Find a member and the class that defines it.
-- Walks `cls` and then its `__base` chain, inspecting raw fields only.
-- @param cls class to start from
-- @param key member name
-- @return value, defining class -- or nil, nil when no class has the key
function lookup_with_owner(cls, key)
  local owner = cls
  while owner do
    local member = rawget(owner, key)
    if member == nil then
      owner = rawget(owner, "__base")
    else
      return member, owner
    end
  end
  return nil, nil
end
|
|
|
|
--- Find the declared visibility of a member.
-- Walks `cls` and then its `__base` chain, consulting each class's raw
-- `__visibility` table.
-- @param cls class to start from
-- @param key member name
-- @return visibility, declaring class -- or `oop.Visibility.PUBLIC`, nil
--   when no class in the chain declares a visibility for `key`
function lookup_visibility(cls, key)
  local owner = cls
  while owner do
    local declared = rawget(owner, "__visibility")
    local vis = declared and declared[key]
    if vis then
      return vis, owner
    end
    owner = rawget(owner, "__base")
  end
  return oop.Visibility.PUBLIC, nil
end
|
|
|
|
--- Resolve a class reference that may be given as a definition function.
-- A non-function value is returned unchanged. A function is looked up in
-- `function_def_cache`; on a miss it is built with `oop.class` and cached.
local function resolve_class_ref(ref)
  if type(ref) ~= "function" then
    return ref
  end
  local resolved = function_def_cache[ref]
  if not resolved then
    resolved = oop.class(ref)
    function_def_cache[ref] = resolved
  end
  return resolved
end

--- Create a class.
-- The returned table is used as the metatable of its instances and can be
-- called (`cls(...)`) as a shorthand for `oop.new(cls, ...)`.
--
-- Helpers placed on the class:
--   cls.method(name, fn, visibility)  define a method with a visibility
--   cls.setattr(target, key, value, visibility)  forwards to `oop.setattr`
--   cls.inherit(parent)  set the base class and install `cls.super`
--   cls.visibility       alias of `oop.Visibility`
--   cls:super(name, ...) forwards to `super_call`; replaced by the `super`
--                        proxy table once `cls.inherit` has run
--
-- @param def  nil, a string (used as the class name), or a function that is
--             called with the new class to populate it
-- @param base optional parent: a class table or a class-definition function
-- @return the new class table
-- Raises when `def` is neither nil, a string, nor a function.
function oop.class(def, base)
  -- A definition function given as `base` is turned into its class before
  -- any wiring. Otherwise `__base` and the metatable `__index` would hold the
  -- raw function while `def` runs: a missing-key read on the class would
  -- invoke the parent definition as an __index metamethod, and the lookup
  -- helpers would call rawget on a function.
  base = resolve_class_ref(base)

  local cls = {}
  cls.__name = "Anonymous"
  cls.__base = base
  cls.__visibility = {}
  cls.visibility = oop.Visibility

  -- Instance reads: per-instance fields first, then class members along the
  -- __base chain. Both are subject to the visibility check against the
  -- current caller.
  cls.__index = function(self, key)
    local fields = rawget(self, "__fields")
    if fields then
      local field_val = fields[key]
      if field_val ~= nil then
        local vis, owner = lookup_visibility(cls, key)
        if not can_access(owner, vis, current_caller()) then
          visibility_error("field", key, vis, owner, current_caller())
        end
        return field_val
      end
    end

    local val, owner = lookup_with_owner(cls, key)
    if val == nil then
      return nil
    end
    -- Prefer the class that declared the visibility; fall back to the class
    -- that defines the value.
    local vis, vis_owner = lookup_visibility(cls, key)
    local access_owner = vis_owner or owner
    if not can_access(access_owner, vis, current_caller()) then
      visibility_error("member", key, vis, access_owner, current_caller())
    end
    if type(val) == "function" then
      -- The wrapper supplies `self` itself and appends the call arguments,
      -- running the method with its defining class as call context.
      -- NOTE(review): because `self` is already bound, `obj:method(a)` passes
      -- the instance a second time; presumably methods are meant to be
      -- invoked as `obj.method(a)` -- confirm against callers.
      return function(...)
        return call_with_context(owner, val, self, ...)
      end
    end
    return val
  end

  -- Instance writes go into the raw `__fields` table after a visibility
  -- check; the check is skipped when no class declares a visibility for key.
  cls.__newindex = function(self, key, value)
    if key == "__fields" then
      rawset(self, key, value)
      return
    end
    local vis, owner = lookup_visibility(cls, key)
    if owner and not can_access(owner, vis, current_caller()) then
      visibility_error("field", key, vis, owner, current_caller())
    end
    local fields = rawget(self, "__fields")
    if not fields then
      fields = {}
      rawset(self, "__fields", fields)
    end
    fields[key] = value
  end

  function cls:super(method_name, ...)
    return super_call(self, method_name, ...)
  end

  function cls.setattr(target, key, value, visibility)
    return oop.setattr(target, key, value, visibility)
  end

  --- Define method `name` on this class with the given visibility
  -- (default PUBLIC). Raises on a non-string name, a non-function `fn`, or
  -- an invalid visibility; nothing is stored in any of those cases.
  function cls.method(name, fn, visibility)
    if type(name) ~= "string" then
      error("method name must be a string")
    end
    if type(fn) ~= "function" then
      error("method must be a function")
    end
    -- Validated before the assignment so a bad visibility cannot leave the
    -- method installed without a visibility entry.
    local vis = normalize_visibility(visibility)
    cls[name] = fn
    cls.__visibility[name] = vis
  end

  --- Make `parent` (class table or definition function) the base class.
  -- Updates `__base`, points the class metatable's `__index` at the parent,
  -- and installs `cls.super`, a proxy whose function members take the
  -- receiver as an explicit first argument.
  function cls.inherit(parent)
    local actual_parent = resolve_class_ref(parent)
    cls.__base = actual_parent
    cls.super = setmetatable({ __class = actual_parent }, {
      __index = function(_, key)
        local val = actual_parent[key]
        if type(val) == "function" then
          -- Use the class that really defines the member as call context.
          -- With the direct parent as context, a method the parent itself
          -- inherited could not reach its own private members.
          local _, defining_class = lookup_with_owner(actual_parent, key)
          local context = defining_class or actual_parent
          return function(this, ...)
            return call_with_context(context, val, this, ...)
          end
        end
        return val
      end,
    })
    local mt = getmetatable(cls) or {}
    mt.__index = actual_parent
    mt.__call = mt.__call or function(c, ...)
      return oop.new(c, ...)
    end
    setmetatable(cls, mt)
  end

  setmetatable(cls, {
    __index = base,
    __call = function(c, ...)
      return oop.new(c, ...)
    end,
  })

  if type(def) == "string" then
    cls.__name = def
  elseif type(def) == "function" then
    def(cls)
  elseif def ~= nil then
    error("class definition must be a function or name string")
  end

  -- Runs after `def`, so an explicit `base` takes precedence over any
  -- `cls.inherit` call made inside the definition function.
  if base then
    cls.inherit(base)
  end

  return cls
end
|
|
|
|
--- Tell whether `obj` is an instance of `class` or of one of its subclasses.
-- Starts at `getmetatable(obj)` and follows raw `__base` links.
-- @param obj   any value
-- @param class class to test against
-- @return boolean
function oop.isinstance(obj, class)
  local mt = getmetatable(obj)
  while mt do
    if mt == class then
      return true
    end
    -- getmetatable returns the `__metatable` field when one is set, which
    -- need not be a table; rawget would raise on it, so stop here instead.
    if type(mt) ~= "table" then
      return false
    end
    mt = rawget(mt, "__base")
  end
  return false
end
|
|
|
|
--- Tell whether `class` is `base` or derives from it.
-- Follows raw `__base` links starting at `class`.
-- @param class class to test
-- @param base  candidate ancestor
-- @return boolean
function oop.issubclass(class, base)
  local cur = class
  while cur do
    if cur == base then
      return true
    end
    -- A non-table link (for example a definition function or a string
    -- passed by mistake) has no `__base` to follow; report "not a
    -- subclass" rather than letting rawget raise.
    if type(cur) ~= "table" then
      return false
    end
    cur = rawget(cur, "__base")
  end
  return false
end
|
|
|
|
--- Export the library into an environment table and hook new assignments.
-- Copies `class`, `new`, `setattr`, `isinstance`, `issubclass` and
-- `Visibility` into `env`, then wraps the `__newindex` of the environment's
-- metatable (creating one if needed): when a function is assigned to a new
-- string key starting with an uppercase ASCII letter, the function is
-- treated as a class definition -- it is passed to `oop.class`, the
-- resulting class is named after the key, cached in `function_def_cache`,
-- and stored in place of the function.
-- @param env environment table; defaults to `_G`
function oop.install(env)
  local target = env or _G
  target.class = oop.class
  target.new = oop.new
  target.setattr = oop.setattr
  target.isinstance = oop.isinstance
  target.issubclass = oop.issubclass
  target.Visibility = oop.Visibility

  local mt = getmetatable(target)
  if not mt then
    mt = {}
    setmetatable(target, mt)
  end

  local prev_newindex = mt.__newindex
  mt.__newindex = function(t, k, v)
    local val = v
    if type(k) == "string" and type(v) == "function" and k:match("^[A-Z]") then
      val = oop.class(v)
      val.__name = k
      function_def_cache[v] = val
    end

    -- Chain to whatever handler was installed before. Lua allows __newindex
    -- to be a table (the assignment is redirected to it), so only call the
    -- previous handler when it is a function.
    if type(prev_newindex) == "function" then
      return prev_newindex(t, k, val)
    elseif prev_newindex then
      prev_newindex[k] = val
      return
    end
    rawset(t, k, val)
  end
end
|
|
|
|
-- Load-time side effect: with no argument, install() targets _G, so loading
-- this module exports the helpers globally and hooks global assignment.
oop.install()
|
|
|
|
-- Module result: the table carrying the library's public functions.
return oop
|