-- lua_oop/oop/oop.lua
-- (source-viewer metadata: 334 lines, 9.1 KiB, Lua)

-- Module table. `Visibility` holds the access levels accepted by setattr,
-- cls.method and the access checks below.
local oop = {
  Visibility = {
    PUBLIC = "public",
    PROTECTED = "protected",
    PRIVATE = "private",
  },
}

-- Lua 5.1 has no table.pack; the fallback records the true argument count in
-- `n` so trailing nils survive a pack/unpack round trip.
local tpack = table.pack
if not tpack then
  tpack = function(...)
    return { n = select("#", ...), ... }
  end
end
local tunpack = table.unpack or unpack

-- Declared up front so functions defined before their bodies (e.g. oop.new)
-- can capture them as upvalues.
local lookup_in_class, lookup_with_owner, lookup_visibility
--- Call `method_name` as defined on the direct parent of `self`'s class.
-- The parent is the raw `__base` field of `getmetatable(self)`; the method is
-- fetched with a normal index on that parent and invoked with `self`
-- followed by the remaining arguments.
-- Raises if `self` has no metatable, the class has no `__base`, or the
-- parent does not provide the method.
local function super_call(self, method_name, ...)
  local klass = getmetatable(self)
  if not klass then
    error("super() called on non-instance")
  end

  local parent = rawget(klass, "__base")
  if not parent then
    error("super() called but no base class")
  end

  local method = parent[method_name]
  if method then
    return method(self, ...)
  end
  error("base class has no method '" .. tostring(method_name) .. "'")
end
-- Class-definition function -> class built from it (used by oop.new,
-- cls.inherit and oop.install). Keys are weak (`__mode = "k"`), so this cache
-- does not by itself keep a definition function alive.
local function_def_cache = {}
setmetatable(function_def_cache, { __mode = "k" })

-- Classes whose methods are currently executing; the top entry is treated as
-- the caller for private/protected access checks.
oop._call_stack = {}
--- Validate a visibility value, defaulting nil to PUBLIC.
-- @param visibility nil or one of the oop.Visibility values
-- @return the validated visibility string
-- Raises for any other value.
local function normalize_visibility(visibility)
  local V = oop.Visibility
  if visibility == nil then
    return V.PUBLIC
  end
  local known = visibility == V.PUBLIC
    or visibility == V.PROTECTED
    or visibility == V.PRIVATE
  if not known then
    error("visibility must be Visibility.PUBLIC, Visibility.PROTECTED, or Visibility.PRIVATE")
  end
  return visibility
end
--- Whether `target` is a class table, i.e. a table with a raw `__visibility`.
-- @return boolean
local function is_class_table(target)
  if type(target) ~= "table" then
    return false
  end
  return rawget(target, "__visibility") ~= nil
end
--- Run `fn(self, ...)` with `owner` pushed on oop._call_stack.
-- The entry is popped whether `fn` returns or raises. Errors are re-raised
-- unchanged (level 0, so no extra position prefix), and all return values,
-- including trailing nils, are passed through.
local function call_with_context(owner, fn, self, ...)
  local frames = oop._call_stack
  frames[#frames + 1] = owner
  local packed = tpack(pcall(fn, self, ...))
  frames[#frames] = nil
  local ok = packed[1]
  if ok then
    return tunpack(packed, 2, packed.n)
  end
  error(packed[2], 0)
end
--- The class whose method is currently executing, or nil outside any method.
local function current_caller()
  local stack = oop._call_stack
  return stack[#stack]
end
--- Decide whether `caller` may touch a member declared on `owner`.
-- PUBLIC is always allowed. Otherwise both classes must be known: PRIVATE
-- requires the exact owner, and PROTECTED requires the owner or a subclass.
-- Unknown visibility values are denied.
-- @return boolean
local function can_access(owner, visibility, caller)
  local V = oop.Visibility
  if visibility == V.PUBLIC then
    return true
  end
  if not (caller and owner) then
    return false
  end
  if visibility == V.PRIVATE then
    return caller == owner
  elseif visibility == V.PROTECTED then
    return oop.issubclass(caller, owner)
  end
  return false
end
--- Human-readable name for a class, used in error messages.
-- Tables report their raw `__name` (or "Anonymous"); other values go
-- through tostring.
local function class_name(cls)
  if type(cls) == "table" then
    return rawget(cls, "__name") or "Anonymous"
  end
  return tostring(cls)
end
--- Raise an access-violation error for `member`.
-- Uses error level 3 so the message is attributed two frames above this
-- function, i.e. past the metamethod that detected the violation.
-- @param kind "field" or "member"
-- @param member the key being accessed
-- @param visibility the declared visibility string
-- @param owner declaring class, or nil ("unknown")
-- @param caller accessing class, or nil ("outside")
local function visibility_error(kind, member, visibility, owner, caller)
  local from = "outside"
  if caller then
    from = class_name(caller)
  end
  local where = "unknown"
  if owner then
    where = class_name(owner)
  end
  local msg = table.concat({
    "access violation: ", kind, " '", tostring(member), "' is ", visibility,
    " in ", where, " (from ", from, ")",
  })
  error(msg, 3)
end
--- Assign `target[key] = value`, recording a visibility on class tables.
-- On a class table (one with a raw `__visibility`), the visibility is
-- validated and stored in `target.__visibility[key]`. On any other target
-- (e.g. an instance), `visibility` is ignored and the assignment goes through
-- the target's normal `__newindex` handling.
-- @param target class table or instance
-- @param key member name
-- @param value value to store
-- @param visibility optional Visibility.* value (defaults to PUBLIC)
function oop.setattr(target, key, value, visibility)
  local vis
  if is_class_table(target) then
    -- Validate before mutating: previously an invalid visibility raised only
    -- after the value was stored, leaving the member registered without a
    -- visibility entry (i.e. silently public).
    vis = normalize_visibility(visibility)
  end
  target[key] = value
  if vis ~= nil then
    target.__visibility[key] = vis
  end
end
--- Create a new instance of `class`.
-- `class` may be a class table or a class-definition function. A function is
-- turned into a class once and cached in `function_def_cache`. The instance
-- gets an empty `__fields` table; then `__init` (looked up along the `__base`
-- chain) runs with its defining class pushed as the call context.
-- @param class class table or class-definition function
-- @param ... arguments forwarded to `__init`
-- @return the new instance
-- Raises if `class` is neither a table nor a function.
function oop.new(class, ...)
  local actual_class = class
  if type(class) == "function" then
    actual_class = function_def_cache[class]
    if not actual_class then
      actual_class = oop.class(class)
      function_def_cache[class] = actual_class
    end
  end
  -- Without this check `new(nil)` (e.g. a misspelled class name) silently
  -- returned a plain table with no metatable instead of failing.
  if type(actual_class) ~= "table" then
    error("new() expects a class table or class definition function, got "
      .. type(class), 2)
  end
  local obj = setmetatable({}, actual_class)
  rawset(obj, "__fields", {})
  local init_fn, init_owner = lookup_with_owner(actual_class, "__init")
  if init_fn then
    call_with_context(init_owner, init_fn, obj, ...)
  end
  return obj
end
--- Find `key` on `cls` or its ancestors, using raw reads along `__base`.
-- Returns the first non-nil value (false counts as found), or nil.
function lookup_in_class(cls, key)
  if not cls then
    return nil
  end
  local found = rawget(cls, key)
  if found == nil then
    -- Tail call: walking a long chain does not grow the stack.
    return lookup_in_class(rawget(cls, "__base"), key)
  end
  return found
end
--- Like lookup_in_class, but also returns the class that holds the value.
-- @return value, owning class; or nil, nil when not found
function lookup_with_owner(cls, key)
  if not cls then
    return nil, nil
  end
  local found = rawget(cls, key)
  if found ~= nil then
    return found, cls
  end
  return lookup_with_owner(rawget(cls, "__base"), key)
end
--- Find the declared visibility of `key` along the `__base` chain.
-- Each class's raw `__visibility` table is consulted in order.
-- @return visibility, declaring class; or Visibility.PUBLIC, nil when no
--         class in the chain declares one
function lookup_visibility(cls, key)
  if not cls then
    return oop.Visibility.PUBLIC, nil
  end
  local declared = rawget(cls, "__visibility")
  local vis = declared and declared[key]
  if vis then
    return vis, cls
  end
  return lookup_visibility(rawget(cls, "__base"), key)
end
--- Build a new class table.
-- `def` may be a name string (stored as `__name`), a function called once
-- with the new class table so it can register members (e.g. through
-- cls.method, cls.setattr, cls.inherit), or nil. Any other value raises.
-- When `base` is given, it is applied with cls.inherit(base) after `def`
-- runs, so it takes precedence over an inherit() call made inside `def`.
-- The returned table is both the class and the metatable of its instances.
-- @param def name string, definition function, or nil
-- @param base optional parent class table or definition function
-- @return the class table
function oop.class(def, base)
  local cls = {}
  cls.__name = "Anonymous"
  cls.__base = base
  -- Member name -> Visibility.* value declared on this class.
  cls.__visibility = {}
  -- Exposes the Visibility constants on the class (e.g. for use inside `def`).
  cls.visibility = oop.Visibility

  -- Instance read path: instance fields first, then members found along the
  -- class's `__base` chain. Both are checked against the visibility declared
  -- in the class chain and the class currently on the call stack.
  cls.__index = function(self, key)
    local fields = rawget(self, "__fields")
    if fields then
      local field_val = fields[key]
      if field_val ~= nil then
        local vis, owner = lookup_visibility(cls, key)
        if not can_access(owner, vis, current_caller()) then
          visibility_error("field", key, vis, owner, current_caller())
        end
        return field_val
      end
    end
    local val, owner = lookup_with_owner(cls, key)
    if val == nil then
      return nil
    end
    -- The visibility may be declared on a different class in the chain than
    -- the one holding the value; prefer the declaring class when known.
    local vis, vis_owner = lookup_visibility(cls, key)
    local access_owner = vis_owner or owner
    if not can_access(access_owner, vis, current_caller()) then
      visibility_error("member", key, vis, access_owner, current_caller())
    end
    if type(val) == "function" then
      -- Bound method: `self` is prepended to the arguments, and the call runs
      -- with the defining class pushed as context. Call it as `obj.m(x)`;
      -- `obj:m(x)` would pass `obj` a second time.
      return function(...)
        return call_with_context(owner, val, self, ...)
      end
    end
    return val
  end

  -- Instance write path: values go into the instance's `__fields` table
  -- (created on demand). The write is rejected only when a visibility for
  -- `key` is declared somewhere in the class chain and the current caller
  -- may not access it.
  cls.__newindex = function(self, key, value)
    if key == "__fields" then
      rawset(self, key, value)
      return
    end
    local vis, owner = lookup_visibility(cls, key)
    if owner and not can_access(owner, vis, current_caller()) then
      visibility_error("field", key, vis, owner, current_caller())
    end
    local fields = rawget(self, "__fields")
    if not fields then
      fields = {}
      rawset(self, "__fields", fields)
    end
    fields[key] = value
  end

  -- Calls `method_name` on the parent of the instance's class (super_call).
  -- NOTE: cls.inherit replaces `cls.super` with a proxy table, so this
  -- method is only reachable while no parent has been set through inherit.
  function cls:super(method_name, ...)
    return super_call(self, method_name, ...)
  end

  -- Alias for oop.setattr, convenient inside a definition function.
  function cls.setattr(target, key, value, visibility)
    return oop.setattr(target, key, value, visibility)
  end

  -- Register `fn` as method `name` with `visibility` (default PUBLIC).
  function cls.method(name, fn, visibility)
    if type(name) ~= "string" then
      error("method name must be a string")
    end
    if type(fn) ~= "function" then
      error("method must be a function")
    end
    -- Validate before storing: previously an invalid visibility raised only
    -- after `fn` was stored, leaving it registered with no visibility entry
    -- (i.e. silently public).
    local vis = normalize_visibility(visibility)
    cls[name] = fn
    cls.__visibility[name] = vis
  end

  -- Make `parent` (a class table, or a definition function resolved through
  -- function_def_cache) the parent of this class. Replaces `cls.super` with a
  -- proxy whose functions take the instance explicitly
  -- (`self.super.m(self, ...)`) and run with the parent as call context, and
  -- points the class metatable's `__index` at the parent.
  function cls.inherit(parent)
    local actual_parent = parent
    if type(parent) == "function" then
      actual_parent = function_def_cache[parent]
      if not actual_parent then
        actual_parent = oop.class(parent)
        function_def_cache[parent] = actual_parent
      end
    end
    cls.__base = actual_parent
    cls.super = setmetatable({ __class = actual_parent }, {
      __index = function(_, key)
        local val = actual_parent[key]
        if type(val) == "function" then
          return function(this, ...)
            return call_with_context(actual_parent, val, this, ...)
          end
        end
        return val
      end
    })
    local mt = getmetatable(cls) or {}
    mt.__index = actual_parent
    mt.__call = mt.__call or function(c, ...)
      return oop.new(c, ...)
    end
    setmetatable(cls, mt)
  end

  -- Class-level metatable: unresolved class keys fall back to `base`, and
  -- calling the class (`Cls(...)`) constructs an instance via oop.new.
  setmetatable(cls, {
    __index = base,
    __call = function(c, ...)
      return oop.new(c, ...)
    end
  })

  if type(def) == "string" then
    cls.__name = def
  elseif type(def) == "function" then
    def(cls)
  elseif def ~= nil then
    error("class definition must be a function or name string")
  end
  -- Applied after `def`, so an explicit `base` argument wins over any
  -- inherit() call made inside `def`.
  if base then
    cls.inherit(base)
  end
  return cls
end
--- Test whether `obj` is an instance of `class` or of one of its subclasses.
-- Walks the `__base` chain starting at `getmetatable(obj)`. `class` may also
-- be a class-definition function that oop.new / inherit / install already
-- turned into a class (resolved through function_def_cache), matching how
-- oop.new accepts such functions.
-- Values whose metatable is not a table (e.g. hidden behind a string
-- `__metatable` field) are reported as non-instances instead of making
-- rawget raise.
-- @param obj any value
-- @param class class table or class-definition function
-- @treturn boolean
function oop.isinstance(obj, class)
  if type(class) == "function" then
    class = function_def_cache[class] or class
  end
  local mt = getmetatable(obj)
  while type(mt) == "table" do
    if mt == class then
      return true
    end
    mt = rawget(mt, "__base")
  end
  return false
end
--- Test whether `class` is `base` or reaches it through `__base` links.
-- A class counts as a subclass of itself; nil/false yields false.
-- @param class class table
-- @param base class table to look for
-- @treturn boolean
function oop.issubclass(class, base)
  local node = class
  repeat
    if not node then
      return false
    end
    if node == base then
      return true
    end
    node = rawget(node, "__base")
  until false
end
--- Publish the oop API into `env` (defaults to `_G`) and hook assignments.
-- Copies class/new/setattr/isinstance/issubclass/Visibility into the target.
-- It then wraps the target's `__newindex` so that assigning a function to a
-- key starting with an uppercase ASCII letter stores a class built from that
-- function instead. The class is named after the key and cached, so
-- new/inherit on the original function reuse it.
-- `__newindex` only fires for keys absent from the target, so reassigning an
-- existing name is not converted. Calling install again chains another hook
-- in front of the previous one.
-- @param env optional table to install into
function oop.install(env)
  local target = env or _G
  target.class = oop.class
  target.new = oop.new
  target.setattr = oop.setattr
  target.isinstance = oop.isinstance
  target.issubclass = oop.issubclass
  target.Visibility = oop.Visibility

  local mt = getmetatable(target)
  if not mt then
    mt = {}
    setmetatable(target, mt)
  end

  local prev_newindex = mt.__newindex
  mt.__newindex = function(t, k, v)
    local val = v
    if type(k) == "string" and type(v) == "function" and k:match("^[A-Z]") then
      val = oop.class(v)
      val.__name = k
      function_def_cache[v] = val
    end
    if not prev_newindex then
      rawset(t, k, val)
    elseif type(prev_newindex) == "function" then
      return prev_newindex(t, k, val)
    else
      -- A non-function __newindex (typically a table) means "repeat the
      -- assignment on that value"; calling it, as before, raised an error.
      prev_newindex[k] = val
    end
  end
end
-- Loading this module installs the API into the global table as a side
-- effect (same as oop.install() with no argument), then exports the module.
oop.install(_G)
return oop