diff --git a/main.lua b/main.lua index 192eb6f..074fc34 100644 --- a/main.lua +++ b/main.lua @@ -1,97 +1,48 @@ oop = require("oop/oop") -function Animal(cls) - function cls.__init(this, name) - this.name = name - end +function Vault(cls) + cls.setattr(cls, "public_note", nil, "public") + cls.setattr(cls, "protected_note", nil, "protected") + cls.setattr(cls, "private_note", nil, "private") - function cls.speak(this) - print(this.name .. " makes a noise.") - end + cls.method("public_method", function(this) + print("public_method:", this.public_note) + end, "public") - function cls.info(this) - print("Animal: " .. this.name) - end -end + cls.method("protected_method", function(this) + print("protected_method:", this.protected_note) + end, "protected") -function Dog(cls) - cls.inherit(Animal) - - function cls.speak(this) - cls.super.speak(this) - print(this.name .. " barks!") - end -end - -function Cat(cls) - cls.inherit(Animal) - - function cls.speak(this) - print(this.name .. " meows.") - end -end - -function Bird(cls) - cls.inherit(Animal) - - function cls.speak(this) - print(this.name .. " chirps.") - end -end - -function CLI(cls) - cls.setattr(cls, "classes", { - Dog = Dog, - Cat = Cat, - Bird = Bird - }) + cls.method("private_method", function(this) + print("private_method:", this.private_note) + end, "private") function cls.__init(this) - this.classes = cls.classes - end - - function cls.prompt(this, label) - io.write(label) - return io.read() - end - - function cls.handleActions(this, obj) - print("Type an action: speak, info (or 'back' to choose another animal).") - while true do - local action = this.prompt("action> ") - if not action or action == "back" then - return - end - local method = obj[action] - if type(method) == "function" then - method() - else - print("Unknown action.") - end - end - end - - function cls.handleAnimal(this, choice) - local cls_ref = this.classes[choice] - if not cls_ref then - print("Unknown animal.") - return - end - local obj = oop.new(cls_ref, choice) - this.handleActions(obj) - end - - function cls.run(this) - print("Type a class name: Dog, Cat, Bird (or 'quit' to exit).") - while true do - local choice = this.prompt("animal> ") - if not choice or choice == "quit" then - break - end - this.handleAnimal(choice) - end + this.public_note = "hello" + this.protected_note = "shielded" + this.private_note = "secret" end end -local cli = oop.new(CLI) -cli.run() +function SubVault(cls) + cls.inherit(Vault) + + function cls.show_all(this) + print("inside SubVault:") + this.public_method() + this.protected_method() + this.private_method() + end +end + +local v = oop.new(Vault) +print("outside Vault:") +print("public_note:", v.public_note) +v.public_method() +print("protected_note:", v.protected_note) +print("private_note:", v.private_note) +v.protected_method() +v.private_method() + +local s = oop.new(SubVault) +s.show_all() diff --git a/oop/oop.lua b/oop/oop.lua index 0fd42d7..f65faf4 100644 --- a/oop/oop.lua +++ b/oop/oop.lua @@ -1,5 +1,14 @@ local oop = {} +local tpack = table.pack or function(...) + return { n = select("#", ...), ... } +end +local tunpack = table.unpack or unpack + +local lookup_in_class +local lookup_with_owner +local lookup_visibility + local function super_call(self, method_name, ...) local class = getmetatable(self) if not class then @@ -17,9 +26,59 @@ local function super_call(self, method_name, ...) end local function_def_cache = setmetatable({}, { __mode = "k" }) +oop._call_stack = {} -function oop.setattr(target, key, value) +local function normalize_visibility(visibility) + if visibility == nil then + return "public" + end + if visibility == "public" or visibility == "protected" or visibility == "private" then + return visibility + end + error("visibility must be 'public', 'protected', or 'private'") +end + +local function is_class_table(target) + return type(target) == "table" and rawget(target, "__visibility") ~= nil +end + +local function call_with_context(owner, fn, self, ...) + local stack = oop._call_stack + stack[#stack + 1] = owner + local results = tpack(pcall(fn, self, ...)) + stack[#stack] = nil + if not results[1] then + error(results[2], 0) + end + return tunpack(results, 2, results.n) +end + +local function current_caller() + return oop._call_stack[#oop._call_stack] +end + +local function can_access(owner, visibility, caller) + if visibility == "public" then + return true + end + if not caller or not owner then + return false + end + if visibility == "private" then + return caller == owner + end + if visibility == "protected" then + return oop.issubclass(caller, owner) + end + return false +end + +function oop.setattr(target, key, value, visibility) target[key] = value + if is_class_table(target) then + local vis = normalize_visibility(visibility) + target.__visibility[key] = vis + end end function oop.new(class, ...) @@ -32,13 +91,15 @@ function oop.new(class, ...) end end local obj = setmetatable({}, actual_class) - if actual_class.__init then - actual_class.__init(obj, ...) + rawset(obj, "__fields", {}) + local init_fn, init_owner = lookup_with_owner(actual_class, "__init") + if init_fn then + call_with_context(init_owner, init_fn, obj, ...) end return obj end -local function lookup_in_class(cls, key) +function lookup_in_class(cls, key) local cur = cls while cur do local val = rawget(cur, key) @@ -50,25 +111,99 @@ local function lookup_in_class(cls, key) return nil end +function lookup_with_owner(cls, key) + local cur = cls + while cur do + local val = rawget(cur, key) + if val ~= nil then + return val, cur + end + cur = rawget(cur, "__base") + end + return nil, nil +end + +function lookup_visibility(cls, key) + local cur = cls + while cur do + local vis_table = rawget(cur, "__visibility") + if vis_table then + local vis = vis_table[key] + if vis then + return vis, cur + end + end + cur = rawget(cur, "__base") + end + return "public", nil +end + function oop.class(def, base) local cls = {} cls.__name = "Anonymous" cls.__base = base + cls.__visibility = {} cls.__index = function(self, key) - local val = lookup_in_class(cls, key) + local fields = rawget(self, "__fields") + if fields then + local field_val = fields[key] + if field_val ~= nil then + local vis, owner = lookup_visibility(cls, key) + if not can_access(owner, vis, current_caller()) then + error("attempt to access " .. vis .. " member '" .. tostring(key) .. "'") + end + return field_val + end + end + + local val, owner = lookup_with_owner(cls, key) + if val == nil then + return nil + end + local vis, vis_owner = lookup_visibility(cls, key) + local access_owner = vis_owner or owner + if not can_access(access_owner, vis, current_caller()) then + error("attempt to access " .. vis .. " member '" .. tostring(key) .. "'") + end if type(val) == "function" then return function(...) - return val(self, ...) + return call_with_context(owner, val, self, ...) end end return val end + cls.__newindex = function(self, key, value) + if key == "__fields" then + rawset(self, key, value) + return + end + local vis, owner = lookup_visibility(cls, key) + if owner and not can_access(owner, vis, current_caller()) then + error("attempt to set " .. vis .. " member '" .. tostring(key) .. "'") + end + local fields = rawget(self, "__fields") + if not fields then + fields = {} + rawset(self, "__fields", fields) + end + fields[key] = value + end function cls:super(method_name, ...) return super_call(self, method_name, ...) end - function cls.setattr(target, key, value) - return oop.setattr(target, key, value) + function cls.setattr(target, key, value, visibility) + return oop.setattr(target, key, value, visibility) + end + function cls.method(name, fn, visibility) + if type(name) ~= "string" then + error("method name must be a string") + end + if type(fn) ~= "function" then + error("method must be a function") + end + cls[name] = fn + cls.__visibility[name] = normalize_visibility(visibility) end function cls.inherit(parent) local actual_parent = parent @@ -80,7 +215,17 @@ function oop.class(def, base) end end cls.__base = actual_parent - cls.super = actual_parent + cls.super = setmetatable({ __class = actual_parent }, { + __index = function(_, key) + local val = actual_parent[key] + if type(val) == "function" then + return function(this, ...) + return call_with_context(actual_parent, val, this, ...) + end + end + return val + end + }) local mt = getmetatable(cls) or {} mt.__index = actual_parent mt.__call = mt.__call or function(c, ...)