Compare commits

...

10 Commits

4 changed files with 601 additions and 66 deletions

104
main.lua
View File

@ -1,59 +1,61 @@
oop = require("oop/oop")
local oop = require("oop/oop")
local Types = require("oop/types")
function BankAccount(cls)
cls.setattr(cls, "owner", nil, cls.visibility.PUBLIC)
cls.setattr(cls, "balance", nil, cls.visibility.PUBLIC)
cls.setattr(cls, "pin", nil, cls.visibility.PROTECTED)
cls.setattr(cls, "token", nil, cls.visibility.PRIVATE)
cls.method("deposit", function(this, amount)
this.balance = this.balance + amount
print("deposit:", amount, "balance:", this.balance)
end, cls.visibility.PUBLIC)
cls.method("withdraw", function(this, amount, pin)
if pin ~= this.pin then
print("bad pin")
return
end
this.balance = this.balance - amount
print("withdraw:", amount, "balance:", this.balance)
end, cls.visibility.PUBLIC)
cls.method("audit", function(this)
print("audit:", this.owner, "balance:", this.balance)
end, cls.visibility.PROTECTED)
cls.method("rotate_token", function(this)
this.token = "t-" .. tostring(os.time())
print("token rotated")
end, cls.visibility.PRIVATE)
function cls.__init(this, owner, pin)
this.owner = owner
this.balance = 0
this.pin = pin
this.token = "t-" .. tostring(os.time())
local function attempt(label, fn)
local ok, err = pcall(fn)
if ok then
print(label .. ": ok")
else
print(label .. ": " .. err)
end
end
function PremiumAccount(cls)
cls.inherit(BankAccount)
local i = Types.Int(10)
local f = Types.Float(2.5)
local s = Types.Str("hi")
local b = Types.Bool(true)
local t = Types.Tuple({ 1, 2, 3 })
local l = Types.List({ "a", "b" })
local q = Types.Queue({ 1, 2 })
local st = Types.Stack({ 9, 8 })
local d = Types.Dict({ one = 1, two = 2 })
local se = Types.Set({ "x", "y" })
function cls.monthly_bonus(this)
this.deposit(25)
this.audit()
end
end
print("Int:", i)
print("Float:", f)
print("Str:", s, "#", #s)
print("Bool:", b)
print("Tuple:", t, "#", #t, "get2", t.get(2))
print("List:", l, "#", #l, "get1", l.get(1))
print("Queue:", q, "#", #q, "peek", q.peek())
print("Stack:", st, "#", #st, "peek", st.peek())
print("Dict:", d, "#", #d, "get(two)", d.get("two"))
print("Set:", se, "#", #se, "has(x)", se.has("x"))
local acct = oop.new(BankAccount, "Kim", 1234)
acct.deposit(100)
acct.withdraw(30, 1234)
print("Int add:", i + Types.Int(5))
print("Float mul:", f * Types.Float(2.0))
print("Str concat:", s .. Types.Str("!"))
print("Bool eq:", b == Types.Bool(true))
print("Int call:", i())
print("Float call:", f())
print("Str call:", s())
print("Bool call:", b())
print("Tuple call:", t())
print("Int copy:", i.copy())
print("Tuple copy:", t.copy())
local vip = oop.new(PremiumAccount, "Sam", 9999)
vip.monthly_bonus()
local ok, err = pcall(function()
vip.rotate_token()
attempt("Int mutate", function()
i.value = 99
end)
attempt("Float mutate", function()
f.value = 1.25
end)
attempt("Str mutate", function()
s.value = "nope"
end)
attempt("Bool mutate", function()
b.value = false
end)
attempt("Tuple mutate", function()
t.items[1] = 99
end)
print("rotate_token from outside:", ok, err)

View File

@ -34,6 +34,38 @@ end
local function_def_cache = setmetatable({}, { __mode = "k" })
oop._call_stack = {}
local default_dunders = {
"__add",
"__sub",
"__mul",
"__div",
"__idiv",
"__mod",
"__pow",
"__unm",
"__band",
"__bor",
"__bxor",
"__bnot",
"__shl",
"__shr",
"__concat",
"__len",
"__eq",
"__lt",
"__le",
"__pairs",
"__ipairs",
"__call",
"__tostring"
}
local function deny_op(name)
return function()
error("operation not allowed: " .. name, 2)
end
end
local function normalize_visibility(visibility)
if visibility == nil then
return oop.Visibility.PUBLIC
@ -263,6 +295,12 @@ function oop.class(def, base)
end
})
for _, name in ipairs(default_dunders) do
if rawget(cls, name) == nil then
cls[name] = deny_op(name)
end
end
if type(def) == "string" then
cls.__name = def
elseif type(def) == "function" then

510
oop/types.lua Normal file
View File

@ -0,0 +1,510 @@
local oop = require("oop/oop")
local Types = {}
local function freeze_table(t)
return setmetatable({}, {
__index = t,
__newindex = function()
error("Tuple is immutable", 2)
end,
__len = function()
return #t
end,
__pairs = function()
return pairs(t)
end,
__ipairs = function()
return ipairs(t)
end
})
end
local function PrimitiveDef(cls)
function cls._freeze(this)
rawset(this, "__immutable", true)
end
function cls._thaw(this)
rawset(this, "__immutable", false)
end
function cls.copy(this)
local cls_ref = getmetatable(this)
local items = rawget(this, "__items")
if items then
return cls_ref(items)
end
return cls_ref(this.value)
end
end
Types.Primitive = oop.class(PrimitiveDef)
Types.Primitive.__name = "Primitive"
local function make_primitive(def, name)
local cls = oop.class(def)
cls.__name = name
cls.inherit(Types.Primitive)
local base_newindex = cls.__newindex
cls.__newindex = function(self, key, value)
if rawget(self, "__immutable") and key ~= "__fields" then
local mt = getmetatable(self)
local cname = mt and mt.__name or name or "Primitive"
error(cname .. " is immutable", 2)
end
return base_newindex(self, key, value)
end
function cls.__call(this)
local items = rawget(this, "__items")
if items then
return this.items
end
return this.value
end
return cls
end
local function IntDef(cls)
function cls.__init(this, value)
rawset(this, "__immutable", false)
this.value = math.floor(tonumber(value) or 0)
rawset(this, "__immutable", true)
end
function cls.__add(a, b)
return Types.Int(a.value + b.value)
end
function cls.__sub(a, b)
return Types.Int(a.value - b.value)
end
function cls.__mul(a, b)
if type(a) == "number" then
return Types.Int(a * b.value)
end
if type(b) == "number" then
return Types.Int(a.value * b)
end
return Types.Int(a.value * b.value)
end
function cls.__eq(a, b)
return a.value == b.value
end
function cls.__tostring(this)
return tostring(this.value)
end
function cls.to_number(this)
return this.value
end
end
local function FloatDef(cls)
function cls.__init(this, value)
rawset(this, "__immutable", false)
this.value = tonumber(value) or 0.0
rawset(this, "__immutable", true)
end
function cls.__add(a, b)
return Types.Float(a.value + b.value)
end
function cls.__sub(a, b)
return Types.Float(a.value - b.value)
end
function cls.__mul(a, b)
if type(a) == "number" then
return Types.Float(a * b.value)
end
if type(b) == "number" then
return Types.Float(a.value * b)
end
return Types.Float(a.value * b.value)
end
function cls.__eq(a, b)
return a.value == b.value
end
function cls.__tostring(this)
return tostring(this.value)
end
function cls.to_number(this)
return this.value
end
end
local function ListDef(cls)
function cls.__init(this, items)
this.head = nil
this.tail = nil
this.size = 0
if type(items) == "table" then
for _, v in ipairs(items) do
this.append(v)
end
end
end
function cls.__len(this)
return this.size
end
function cls.__tostring(this)
local parts = {}
local node = this.head
while node do
parts[#parts + 1] = tostring(node.value)
node = node.next
end
return "[" .. table.concat(parts, ", ") .. "]"
end
function cls.append(this, value)
local node = { value = value, next = nil }
if this.tail then
this.tail.next = node
else
this.head = node
end
this.tail = node
this.size = this.size + 1
end
function cls.get(this, index)
if type(index) ~= "number" or index < 1 or index > this.size then
return nil
end
local node = this.head
local i = 1
while node do
if i == index then
return node.value
end
node = node.next
i = i + 1
end
return nil
end
function cls.set(this, index, value)
if type(index) ~= "number" or index < 1 or index > this.size then
return false
end
local node = this.head
local i = 1
while node do
if i == index then
node.value = value
return true
end
node = node.next
i = i + 1
end
return false
end
function cls.iter(this)
local node = this.head
return function()
if not node then
return nil
end
local value = node.value
node = node.next
return value
end
end
end
local function TupleDef(cls)
function cls.__init(this, items)
rawset(this, "__immutable", false)
if type(items) == "table" then
local raw_items = { table.unpack(items) }
rawset(this, "__items", raw_items)
this.items = freeze_table(raw_items)
else
local raw_items = {}
rawset(this, "__items", raw_items)
this.items = freeze_table(raw_items)
end
rawset(this, "__immutable", true)
end
function cls.__len(this)
return #this.items
end
function cls.__tostring(this)
local parts = {}
for i, v in ipairs(this.items) do
parts[i] = tostring(v)
end
return "(" .. table.concat(parts, ", ") .. ")"
end
function cls.get(this, index)
return this.items[index]
end
function cls.iter(this)
return ipairs(this.items)
end
end
local function QueueDef(cls)
function cls.__init(this, items)
if type(items) == "table" then
this.items = { table.unpack(items) }
else
this.items = {}
end
this.head = 1
end
function cls.__len(this)
return #this.items - this.head + 1
end
function cls.__tostring(this)
local parts = {}
for i = this.head, #this.items do
parts[#parts + 1] = tostring(this.items[i])
end
return "Queue([" .. table.concat(parts, ", ") .. "])"
end
function cls.enqueue(this, value)
this.items[#this.items + 1] = value
end
function cls.dequeue(this)
if this.head > #this.items then
return nil
end
local value = this.items[this.head]
this.items[this.head] = nil
this.head = this.head + 1
if this.head > 32 and this.head > (#this.items / 2) then
local new_items = {}
for i = this.head, #this.items do
new_items[#new_items + 1] = this.items[i]
end
this.items = new_items
this.head = 1
end
return value
end
function cls.peek(this)
return this.items[this.head]
end
end
local function StackDef(cls)
function cls.__init(this, items)
if type(items) == "table" then
this.items = { table.unpack(items) }
else
this.items = {}
end
end
function cls.__len(this)
return #this.items
end
function cls.__tostring(this)
local parts = {}
for i, v in ipairs(this.items) do
parts[i] = tostring(v)
end
return "Stack([" .. table.concat(parts, ", ") .. "])"
end
function cls.push(this, value)
this.items[#this.items + 1] = value
end
function cls.pop(this)
local n = #this.items
if n == 0 then
return nil
end
local value = this.items[n]
this.items[n] = nil
return value
end
function cls.peek(this)
return this.items[#this.items]
end
end
local function DictDef(cls)
function cls.__init(this, items)
if type(items) == "table" then
this.items = items
else
this.items = {}
end
end
function cls.__len(this)
local count = 0
for _ in pairs(this.items) do
count = count + 1
end
return count
end
function cls.__tostring(this)
local parts = {}
for k, v in pairs(this.items) do
parts[#parts + 1] = tostring(k) .. "=" .. tostring(v)
end
return "Dict({" .. table.concat(parts, ", ") .. "})"
end
function cls.get(this, key, default)
local val = this.items[key]
if val == nil then
return default
end
return val
end
function cls.set(this, key, value)
this.items[key] = value
end
function cls.keys(this)
local keys = {}
for k in pairs(this.items) do
keys[#keys + 1] = k
end
return keys
end
end
local function SetDef(cls)
function cls.__init(this, items)
this.items = {}
if type(items) == "table" then
for _, v in ipairs(items) do
this.items[v] = true
end
end
end
function cls.__len(this)
local count = 0
for _ in pairs(this.items) do
count = count + 1
end
return count
end
function cls.__tostring(this)
local parts = {}
for k in pairs(this.items) do
parts[#parts + 1] = tostring(k)
end
return "Set({" .. table.concat(parts, ", ") .. "})"
end
function cls.add(this, value)
this.items[value] = true
end
function cls.has(this, value)
return this.items[value] == true
end
function cls.remove(this, value)
this.items[value] = nil
end
end
local function BoolDef(cls)
function cls.__init(this, value)
rawset(this, "__immutable", false)
this.value = not not value
rawset(this, "__immutable", true)
end
function cls.__tostring(this)
return tostring(this.value)
end
function cls.__eq(a, b)
return a.value == b.value
end
end
local function StrDef(cls)
function cls.__init(this, value)
rawset(this, "__immutable", false)
this.value = tostring(value or "")
rawset(this, "__immutable", true)
end
function cls.__len(this)
return #this.value
end
function cls.__concat(a, b)
return Types.Str(tostring(a) .. tostring(b))
end
function cls.__tostring(this)
return this.value
end
end
Types.Int = make_primitive(IntDef, "Int")
Types.Float = make_primitive(FloatDef, "Float")
Types.List = oop.class(ListDef)
Types.List.__name = "List"
Types.Tuple = make_primitive(TupleDef, "Tuple")
Types.Tupple = Types.Tuple
Types.Queue = oop.class(QueueDef)
Types.Queue.__name = "Queue"
Types.Que = Types.Queue
Types.Stack = oop.class(StackDef)
Types.Stack.__name = "Stack"
Types.Dict = oop.class(DictDef)
Types.Dict.__name = "Dict"
Types.Set = oop.class(SetDef)
Types.Set.__name = "Set"
Types.Bool = make_primitive(BoolDef, "Bool")
Types.Str = make_primitive(StrDef, "Str")
Types.String = Types.Str
return Types

View File

@ -1,15 +0,0 @@
local function deepcopy(t)
local new = {}
for k, v in pairs(t) do
if type(v) == "table" then
new[k] = deepcopy(v)
else
new[k] = v
end
end
return new
end
return {
deepcopy = deepcopy
}