--- Value and container types (Int, Float, List, Tuple, Queue, Stack, Dict,
-- Set, Bool, Str) declared through the project's `oop.class` helper.
--
-- Each `*Def` function receives the class table `cls` and installs the
-- initializer, metamethods and methods on it. Methods take the instance as an
-- explicit first parameter named `this`.
-- NOTE(review): the exact contract of `oop.class` (how it invokes the
-- definition function and `__init`, and what the `__fields` key is used for)
-- is not visible in this file -- confirm against the `oop` module.
local oop = require("oop")

local Types = {}

--- Copy the sequence part (indices 1..#t) of a table into a new table.
-- Replaces the `{ table.unpack(t) }` idiom, which raises "too many results
-- to unpack" once the input grows past the stack limit.
-- @tparam table t source table
-- @treturn table a fresh table holding t[1]..t[#t]
local function copy_sequence(t)
  local out = {}
  for i = 1, #t do
    out[i] = t[i]
  end
  return out
end

--- Resolve an arithmetic operand to a plain number.
-- Arithmetic metamethods fire when either operand carries them, so the other
-- side may be a bare Lua number rather than a wrapper instance.
-- @param x a number, or an object exposing a numeric `value` field
-- @return the number itself, or `x.value`
local function operand_value(x)
  if type(x) == "number" then
    return x
  end
  return x.value
end

--- Wrap a table in a read-only proxy.
-- Reads, `#`, `pairs` and `ipairs` are forwarded to the wrapped table; any
-- assignment through the proxy raises "Tuple is immutable".
-- @tparam table t table to protect (it is wrapped, not copied)
-- @treturn table the proxy
local function freeze_table(t)
  return setmetatable({}, {
    __index = t,
    __newindex = function()
      error("Tuple is immutable", 2)
    end,
    __len = function()
      return #t
    end,
    __pairs = function()
      return pairs(t)
    end,
    __ipairs = function()
      return ipairs(t)
    end,
  })
end

--- Install a `__newindex` guard that rejects writes on frozen instances.
-- An instance counts as frozen when its raw `__frozen` field is truthy. Writes
-- to the `__fields` key are always let through. Allowed writes are delegated
-- to whatever `cls.__newindex` was before this call.
-- @tparam table cls class table being defined
-- @tparam string type_name name used in the "<name> is immutable" message
local function install_immutability(cls, type_name)
  local base_newindex = cls.__newindex
  function cls.__newindex(this, key, value)
    if rawget(this, "__frozen") and key ~= "__fields" then
      -- Level 2 attributes the error to the code performing the assignment.
      error(type_name .. " is immutable", 2)
    end
    return base_newindex(this, key, value)
  end
end

--- Immutable integer wrapper; `value` holds the floored number.
local function IntDef(cls)
  --- @param value anything `tonumber` accepts; otherwise 0 is stored
  function cls.__init(this, value)
    rawset(this, "__frozen", false)
    this.value = math.floor(tonumber(value) or 0)
    rawset(this, "__frozen", true)
  end

  -- Either operand may be a plain number (e.g. `Int(1) + 2`, `2 - Int(1)`).
  function cls.__add(a, b)
    return Types.Int(operand_value(a) + operand_value(b))
  end

  function cls.__sub(a, b)
    return Types.Int(operand_value(a) - operand_value(b))
  end

  function cls.__mul(a, b)
    return Types.Int(operand_value(a) * operand_value(b))
  end

  function cls.__eq(a, b)
    return a.value == b.value
  end

  function cls.__tostring(this)
    return tostring(this.value)
  end

  --- @treturn number the wrapped value
  function cls.to_number(this)
    return this.value
  end

  install_immutability(cls, "Int")
end

--- Immutable float wrapper; `value` holds the number as converted.
local function FloatDef(cls)
  --- @param value anything `tonumber` accepts; otherwise 0.0 is stored
  function cls.__init(this, value)
    rawset(this, "__frozen", false)
    this.value = tonumber(value) or 0.0
    rawset(this, "__frozen", true)
  end

  -- Either operand may be a plain number, mirroring Int.
  function cls.__add(a, b)
    return Types.Float(operand_value(a) + operand_value(b))
  end

  function cls.__sub(a, b)
    return Types.Float(operand_value(a) - operand_value(b))
  end

  function cls.__mul(a, b)
    return Types.Float(operand_value(a) * operand_value(b))
  end

  function cls.__eq(a, b)
    return a.value == b.value
  end

  function cls.__tostring(this)
    return tostring(this.value)
  end

  --- @treturn number the wrapped value
  function cls.to_number(this)
    return this.value
  end

  install_immutability(cls, "Float")
end

--- Singly linked list with `head`, `tail` and `size` fields.
local function ListDef(cls)
  --- @param items optional table; its `ipairs` elements are appended in order
  function cls.__init(this, items)
    this.head = nil
    this.tail = nil
    this.size = 0
    if type(items) == "table" then
      for _, v in ipairs(items) do
        -- Call through the class with an explicit receiver: `this.append(v)`
        -- would bind `v` as the instance.
        cls.append(this, v)
      end
    end
  end

  function cls.__len(this)
    return this.size
  end

  function cls.__tostring(this)
    local parts = {}
    local node = this.head
    while node do
      parts[#parts + 1] = tostring(node.value)
      node = node.next
    end
    return "[" .. table.concat(parts, ", ") .. "]"
  end

  --- Append a value at the end of the list.
  function cls.append(this, value)
    local node = { value = value, next = nil }
    if this.tail then
      this.tail.next = node
    else
      this.head = node
    end
    this.tail = node
    this.size = this.size + 1
  end

  --- Return the value at a 1-based index.
  -- @return the value, or nil when `index` is not a number in 1..size
  function cls.get(this, index)
    if type(index) ~= "number" or index < 1 or index > this.size then
      return nil
    end
    local node = this.head
    local i = 1
    while node do
      if i == index then
        return node.value
      end
      node = node.next
      i = i + 1
    end
    return nil
  end

  --- Overwrite the value at a 1-based index.
  -- @treturn boolean true when a node was updated, false otherwise
  function cls.set(this, index, value)
    if type(index) ~= "number" or index < 1 or index > this.size then
      return false
    end
    local node = this.head
    local i = 1
    while node do
      if i == index then
        node.value = value
        return true
      end
      node = node.next
      i = i + 1
    end
    return false
  end

  --- Return an iterator function yielding each value from head to tail.
  function cls.iter(this)
    local node = this.head
    return function()
      if not node then
        return nil
      end
      local value = node.value
      node = node.next
      return value
    end
  end
end

--- Immutable tuple; `items` is a read-only proxy over a copy of the input.
local function TupleDef(cls)
  --- @param items optional table whose sequence part is copied
  function cls.__init(this, items)
    rawset(this, "__frozen", false)
    if type(items) == "table" then
      this.items = freeze_table(copy_sequence(items))
    else
      this.items = freeze_table({})
    end
    rawset(this, "__frozen", true)
  end

  function cls.__len(this)
    return #this.items
  end

  function cls.__tostring(this)
    local parts = {}
    for i, v in ipairs(this.items) do
      parts[i] = tostring(v)
    end
    return "(" .. table.concat(parts, ", ") .. ")"
  end

  --- @return the element stored at `index`, or nil
  function cls.get(this, index)
    return this.items[index]
  end

  --- @return the `ipairs` iterator triple over the elements
  function cls.iter(this)
    return ipairs(this.items)
  end

  install_immutability(cls, "Tuple")
end

--- FIFO queue stored in `items[head..tail]`.
-- The live range is tracked with explicit `head` and `tail` indices. Consumed
-- slots are set to nil, which leaves a hole at the front of `items`, so the
-- `#` operator cannot be used to find the end of the queue.
local function QueueDef(cls)
  --- @param items optional table whose sequence part is copied
  function cls.__init(this, items)
    if type(items) == "table" then
      this.items = copy_sequence(items)
    else
      this.items = {}
    end
    this.head = 1
    this.tail = #this.items
  end

  function cls.__len(this)
    return this.tail - this.head + 1
  end

  function cls.__tostring(this)
    local parts = {}
    for i = this.head, this.tail do
      parts[#parts + 1] = tostring(this.items[i])
    end
    return "Queue([" .. table.concat(parts, ", ") .. "])"
  end

  --- Add a value at the back of the queue.
  function cls.enqueue(this, value)
    -- Storing nil never added an element before; keep that so a nil result
    -- from dequeue still means "empty".
    if value == nil then
      return
    end
    local tail = this.tail + 1
    this.items[tail] = value
    this.tail = tail
  end

  --- Remove and return the front value.
  -- @return the value, or nil when the queue is empty
  function cls.dequeue(this)
    if this.head > this.tail then
      return nil
    end
    local value = this.items[this.head]
    this.items[this.head] = nil
    this.head = this.head + 1
    -- Once more than 32 slots are consumed and they outnumber the live ones,
    -- rebuild the table starting at index 1.
    if this.head > 32 and this.head > (this.tail / 2) then
      local new_items = {}
      local count = 0
      for i = this.head, this.tail do
        count = count + 1
        new_items[count] = this.items[i]
      end
      this.items = new_items
      this.head = 1
      this.tail = count
    end
    return value
  end

  --- @return the front value without removing it, or nil when empty
  function cls.peek(this)
    return this.items[this.head]
  end
end

--- LIFO stack stored as a sequence in `items`; the top is the last element.
local function StackDef(cls)
  --- @param items optional table whose sequence part is copied
  function cls.__init(this, items)
    if type(items) == "table" then
      this.items = copy_sequence(items)
    else
      this.items = {}
    end
  end

  function cls.__len(this)
    return #this.items
  end

  function cls.__tostring(this)
    local parts = {}
    for i, v in ipairs(this.items) do
      parts[i] = tostring(v)
    end
    return "Stack([" .. table.concat(parts, ", ") .. "])"
  end

  --- Push a value on top of the stack.
  function cls.push(this, value)
    this.items[#this.items + 1] = value
  end

  --- Remove and return the top value.
  -- @return the value, or nil when the stack is empty
  function cls.pop(this)
    local n = #this.items
    if n == 0 then
      return nil
    end
    local value = this.items[n]
    this.items[n] = nil
    return value
  end

  --- @return the top value without removing it, or nil when empty
  function cls.peek(this)
    return this.items[#this.items]
  end
end

--- Key/value dictionary stored in `items`.
local function DictDef(cls)
  --- @param items optional table of initial entries
  function cls.__init(this, items)
    -- Shallow-copy so later `set` calls do not mutate the caller's table,
    -- matching the other containers, which all copy their input.
    local copy = {}
    if type(items) == "table" then
      for k, v in pairs(items) do
        copy[k] = v
      end
    end
    this.items = copy
  end

  --- Number of entries, counted by traversing `items`.
  function cls.__len(this)
    local count = 0
    for _ in pairs(this.items) do
      count = count + 1
    end
    return count
  end

  -- Entry order follows `pairs` and is therefore unspecified.
  function cls.__tostring(this)
    local parts = {}
    for k, v in pairs(this.items) do
      parts[#parts + 1] = tostring(k) .. "=" .. tostring(v)
    end
    return "Dict({" .. table.concat(parts, ", ") .. "})"
  end

  --- @return the stored value, or `default` when the key is absent
  function cls.get(this, key, default)
    local val = this.items[key]
    if val == nil then
      return default
    end
    return val
  end

  function cls.set(this, key, value)
    this.items[key] = value
  end

  --- @treturn table array of all keys, in `pairs` order
  function cls.keys(this)
    local keys = {}
    for k in pairs(this.items) do
      keys[#keys + 1] = k
    end
    return keys
  end
end

--- Set of values; members are stored as keys of `items` mapped to true.
local function SetDef(cls)
  --- @param items optional table; its `ipairs` elements become members
  function cls.__init(this, items)
    this.items = {}
    if type(items) == "table" then
      for _, v in ipairs(items) do
        this.items[v] = true
      end
    end
  end

  --- Number of members, counted by traversing `items`.
  function cls.__len(this)
    local count = 0
    for _ in pairs(this.items) do
      count = count + 1
    end
    return count
  end

  -- Member order follows `pairs` and is therefore unspecified.
  function cls.__tostring(this)
    local parts = {}
    for k in pairs(this.items) do
      parts[#parts + 1] = tostring(k)
    end
    return "Set({" .. table.concat(parts, ", ") .. "})"
  end

  function cls.add(this, value)
    this.items[value] = true
  end

  --- @treturn boolean whether `value` is a member
  function cls.has(this, value)
    return this.items[value] == true
  end

  function cls.remove(this, value)
    this.items[value] = nil
  end
end

--- Immutable boolean wrapper; `value` is coerced to a real boolean.
local function BoolDef(cls)
  function cls.__init(this, value)
    rawset(this, "__frozen", false)
    this.value = not not value
    rawset(this, "__frozen", true)
  end

  function cls.__tostring(this)
    return tostring(this.value)
  end

  function cls.__eq(a, b)
    return a.value == b.value
  end

  install_immutability(cls, "Bool")
end

--- Immutable string wrapper; `value` holds `tostring(value or "")`.
local function StrDef(cls)
  function cls.__init(this, value)
    rawset(this, "__frozen", false)
    this.value = tostring(value or "")
    rawset(this, "__frozen", true)
  end

  --- Byte length of the wrapped string.
  function cls.__len(this)
    return #this.value
  end

  --- Concatenate the `tostring` forms of both operands into a new Str.
  function cls.__concat(a, b)
    return Types.Str(tostring(a) .. tostring(b))
  end

  function cls.__tostring(this)
    return this.value
  end

  install_immutability(cls, "Str")
end

Types.Int = oop.class(IntDef)
Types.Int.__name = "Int"

Types.Float = oop.class(FloatDef)
Types.Float.__name = "Float"

Types.List = oop.class(ListDef)
Types.List.__name = "List"

Types.Tuple = oop.class(TupleDef)
Types.Tuple.__name = "Tuple"
-- Alternate spelling exported alongside Tuple.
Types.Tupple = Types.Tuple

Types.Queue = oop.class(QueueDef)
Types.Queue.__name = "Queue"
-- Short alias for Queue.
Types.Que = Types.Queue

Types.Stack = oop.class(StackDef)
Types.Stack.__name = "Stack"

Types.Dict = oop.class(DictDef)
Types.Dict.__name = "Dict"

Types.Set = oop.class(SetDef)
Types.Set.__name = "Set"

Types.Bool = oop.class(BoolDef)
Types.Bool.__name = "Bool"

Types.Str = oop.class(StrDef)
Types.Str.__name = "Str"
-- Long alias for Str.
Types.String = Types.Str

return Types