--- Value and container types built on the project's `oop` class helper.
-- Exports immutable primitives (Int, Float, Bool, Str, Tuple) and mutable
-- containers (List, Queue, Stack, Dict, Set) in a single `Types` table.
-- NOTE(review): definition functions below receive a class table `cls` and
-- attach methods taking an explicit `this`; how `oop.class` dispatches and
-- binds them is defined in "oop/oop" and is not visible here.
local oop = require("oop/oop")

local Types = {}

--- Wrap a table in a read-only proxy.
-- Reads, `#`, `pairs` and `ipairs` are forwarded to `t`; any assignment
-- through the proxy raises "Tuple is immutable" at the caller's level.
-- @tparam table t backing table (not copied)
-- @treturn table proxy table
local function freeze_table(t)
  return setmetatable({}, {
    __index = t,
    __newindex = function()
      error("Tuple is immutable", 2)
    end,
    __len = function()
      return #t
    end,
    __pairs = function()
      return pairs(t)
    end,
    __ipairs = function()
      return ipairs(t)
    end,
  })
end

--- Unwrap an arithmetic operand.
-- Plain Lua numbers are returned unchanged; anything else is expected to
-- carry its number in a `value` field (as Int and Float instances do).
-- @param x number or boxed numeric instance
-- @treturn number
local function operand_value(x)
  if type(x) == "number" then
    return x
  end
  return x.value
end

--- Methods shared by every primitive type.
local function PrimitiveDef(cls)
  --- Set the raw `__immutable` flag so guarded writes are rejected.
  function cls._freeze(this)
    rawset(this, "__immutable", true)
  end

  --- Clear the raw `__immutable` flag so guarded writes are allowed.
  function cls._thaw(this)
    rawset(this, "__immutable", false)
  end

  --- Build a new instance by calling the instance's metatable with its
  -- payload: the raw `__items` table when present, otherwise `value`.
  -- NOTE(review): assumes the metatable is the (callable) class — confirm
  -- against oop/oop.
  function cls.copy(this)
    local cls_ref = getmetatable(this)
    local items = rawget(this, "__items")
    if items then
      return cls_ref(items)
    end
    return cls_ref(this.value)
  end
end

Types.Primitive = oop.class(PrimitiveDef)
Types.Primitive.__name = "Primitive"

--- Create a primitive class from a definition function.
-- The class inherits from `Types.Primitive`, rejects field writes while the
-- raw `__immutable` flag is set (the key "__fields" is exempt), and when
-- called as an instance returns its payload.
-- @tparam function def definition function receiving the class table
-- @tparam string name class name, also used in the immutability error
-- @return the new class
local function make_primitive(def, name)
  local cls = oop.class(def)
  cls.__name = name
  cls.inherit(Types.Primitive)

  -- Keep whatever __newindex the class had before wrapping so that permitted
  -- writes still go through it.
  local base_newindex = cls.__newindex
  cls.__newindex = function(self, key, value)
    if rawget(self, "__immutable") and key ~= "__fields" then
      local mt = getmetatable(self)
      local cname = mt and mt.__name or name or "Primitive"
      error(cname .. " is immutable", 2)
    end
    return base_newindex(self, key, value)
  end

  --- Calling an instance yields `items` (tuple-like) or `value`.
  function cls.__call(this)
    local items = rawget(this, "__items")
    if items then
      return this.items
    end
    return this.value
  end

  return cls
end

--- Integer primitive: stores `math.floor(tonumber(value) or 0)` in `value`.
local function IntDef(cls)
  function cls.__init(this, value)
    -- Thaw for the single assignment, then freeze again.
    rawset(this, "__immutable", false)
    this.value = math.floor(tonumber(value) or 0)
    rawset(this, "__immutable", true)
  end

  -- Arithmetic accepts Int instances and plain numbers on either side
  -- (previously only __mul did; __add/__sub raised on a number operand).
  function cls.__add(a, b)
    return Types.Int(operand_value(a) + operand_value(b))
  end

  function cls.__sub(a, b)
    return Types.Int(operand_value(a) - operand_value(b))
  end

  function cls.__mul(a, b)
    return Types.Int(operand_value(a) * operand_value(b))
  end

  function cls.__eq(a, b)
    return a.value == b.value
  end

  function cls.__tostring(this)
    return tostring(this.value)
  end

  --- @treturn number the stored value
  function cls.to_number(this)
    return this.value
  end
end

--- Float primitive: stores `tonumber(value) or 0.0` in `value`.
local function FloatDef(cls)
  function cls.__init(this, value)
    rawset(this, "__immutable", false)
    this.value = tonumber(value) or 0.0
    rawset(this, "__immutable", true)
  end

  -- Arithmetic accepts Float instances and plain numbers on either side
  -- (previously only __mul did; __add/__sub raised on a number operand).
  function cls.__add(a, b)
    return Types.Float(operand_value(a) + operand_value(b))
  end

  function cls.__sub(a, b)
    return Types.Float(operand_value(a) - operand_value(b))
  end

  function cls.__mul(a, b)
    return Types.Float(operand_value(a) * operand_value(b))
  end

  function cls.__eq(a, b)
    return a.value == b.value
  end

  function cls.__tostring(this)
    return tostring(this.value)
  end

  --- @treturn number the stored value
  function cls.to_number(this)
    return this.value
  end
end

--- Singly linked list with `head`, `tail` and `size` fields.
local function ListDef(cls)
  --- @tparam[opt] table items sequence of initial values (read with ipairs)
  function cls.__init(this, items)
    this.head = nil
    this.tail = nil
    this.size = 0
    if type(items) == "table" then
      for _, v in ipairs(items) do
        -- NOTE(review): called without an explicit receiver; presumably the
        -- oop library binds `this` on instance lookups — confirm.
        this.append(v)
      end
    end
  end

  function cls.__len(this)
    return this.size
  end

  --- Render as "[a, b, c]".
  function cls.__tostring(this)
    local parts = {}
    local node = this.head
    while node do
      parts[#parts + 1] = tostring(node.value)
      node = node.next
    end
    return "[" .. table.concat(parts, ", ") .. "]"
  end

  --- Append `value` at the tail.
  function cls.append(this, value)
    local node = { value = value, next = nil }
    if this.tail then
      this.tail.next = node
    else
      this.head = node
    end
    this.tail = node
    this.size = this.size + 1
  end

  --- Return the value at 1-based `index`, or nil when `index` is not a
  -- number or lies outside 1..size.
  function cls.get(this, index)
    if type(index) ~= "number" or index < 1 or index > this.size then
      return nil
    end
    local node = this.head
    local i = 1
    while node do
      if i == index then
        return node.value
      end
      node = node.next
      i = i + 1
    end
    return nil
  end

  --- Replace the value at 1-based `index`.
  -- @treturn boolean true when a node was updated, false otherwise
  function cls.set(this, index, value)
    if type(index) ~= "number" or index < 1 or index > this.size then
      return false
    end
    local node = this.head
    local i = 1
    while node do
      if i == index then
        node.value = value
        return true
      end
      node = node.next
      i = i + 1
    end
    return false
  end

  --- Return a stateful iterator yielding each value from head to tail.
  function cls.iter(this)
    local node = this.head
    return function()
      if not node then
        return nil
      end
      local value = node.value
      node = node.next
      return value
    end
  end
end

--- Immutable tuple: a private copy of the input is kept in the raw field
-- `__items`, and a read-only proxy over it is stored in `items`.
local function TupleDef(cls)
  --- @tparam[opt] table items initial values; non-tables give an empty tuple
  function cls.__init(this, items)
    rawset(this, "__immutable", false)
    local raw_items
    if type(items) == "table" then
      raw_items = { table.unpack(items) }
    else
      raw_items = {}
    end
    rawset(this, "__items", raw_items)
    this.items = freeze_table(raw_items)
    rawset(this, "__immutable", true)
  end

  function cls.__len(this)
    return #this.items
  end

  --- Render as "(a, b, c)".
  function cls.__tostring(this)
    local parts = {}
    for i, v in ipairs(this.items) do
      parts[i] = tostring(v)
    end
    return "(" .. table.concat(parts, ", ") .. ")"
  end

  --- Return the element at `index` (nil when absent).
  function cls.get(this, index)
    return this.items[index]
  end

  --- Return an `ipairs` iterator over the frozen items.
  function cls.iter(this)
    return ipairs(this.items)
  end
end

--- FIFO queue backed by an array with explicit `head` and `tail` indices.
-- Live elements occupy items[head..tail]. Slots before `head` are cleared on
-- dequeue, so the table has leading holes and `#items` is never used: the
-- length operator is only defined for sequences.
local function QueueDef(cls)
  --- @tparam[opt] table items sequence of initial values (copied)
  function cls.__init(this, items)
    if type(items) == "table" then
      this.items = { table.unpack(items) }
    else
      this.items = {}
    end
    this.head = 1
    -- The fresh copy has no leading holes yet, so `#` is safe here.
    this.tail = #this.items
  end

  function cls.__len(this)
    return this.tail - this.head + 1
  end

  --- Render as "Queue([a, b, c])" in dequeue order.
  function cls.__tostring(this)
    local parts = {}
    for i = this.head, this.tail do
      parts[#parts + 1] = tostring(this.items[i])
    end
    return "Queue([" .. table.concat(parts, ", ") .. "])"
  end

  --- Add `value` at the back. A nil value is ignored, matching the effect
  -- of the previous `items[#items + 1] = nil` assignment.
  function cls.enqueue(this, value)
    if value == nil then
      return
    end
    local tail = this.tail + 1
    this.items[tail] = value
    this.tail = tail
  end

  --- Remove and return the front element, or nil when the queue is empty.
  function cls.dequeue(this)
    if this.head > this.tail then
      return nil
    end
    local value = this.items[this.head]
    this.items[this.head] = nil
    this.head = this.head + 1
    -- Compact once more than 32 slots, and more than half of the used index
    -- range, are dead space at the front.
    if this.head > 32 and this.head > (this.tail / 2) then
      local new_items = {}
      local count = 0
      for i = this.head, this.tail do
        count = count + 1
        new_items[count] = this.items[i]
      end
      this.items = new_items
      this.head = 1
      this.tail = count
    end
    return value
  end

  --- Return the front element without removing it (nil when empty).
  function cls.peek(this)
    return this.items[this.head]
  end
end

--- LIFO stack backed by an array in `items`.
local function StackDef(cls)
  --- @tparam[opt] table items sequence of initial values (copied)
  function cls.__init(this, items)
    if type(items) == "table" then
      this.items = { table.unpack(items) }
    else
      this.items = {}
    end
  end

  function cls.__len(this)
    return #this.items
  end

  --- Render as "Stack([a, b, c])", bottom first.
  function cls.__tostring(this)
    local parts = {}
    for i, v in ipairs(this.items) do
      parts[i] = tostring(v)
    end
    return "Stack([" .. table.concat(parts, ", ") .. "])"
  end

  --- Push `value` on top.
  function cls.push(this, value)
    this.items[#this.items + 1] = value
  end

  --- Remove and return the top element, or nil when empty.
  function cls.pop(this)
    local n = #this.items
    if n == 0 then
      return nil
    end
    local value = this.items[n]
    this.items[n] = nil
    return value
  end

  --- Return the top element without removing it (nil when empty).
  function cls.peek(this)
    return this.items[#this.items]
  end
end

--- Key/value dictionary stored in `items`.
local function DictDef(cls)
  --- @tparam[opt] table items initial contents. The table is stored by
  -- reference, not copied, so later changes to it are visible here.
  function cls.__init(this, items)
    if type(items) == "table" then
      this.items = items
    else
      this.items = {}
    end
  end

  --- Number of entries, counted by traversing with `pairs`.
  function cls.__len(this)
    local count = 0
    for _ in pairs(this.items) do
      count = count + 1
    end
    return count
  end

  --- Render as "Dict({k=v, ...})"; entry order follows `pairs`.
  function cls.__tostring(this)
    local parts = {}
    for k, v in pairs(this.items) do
      parts[#parts + 1] = tostring(k) .. "=" .. tostring(v)
    end
    return "Dict({" .. table.concat(parts, ", ") .. "})"
  end

  --- Return the value for `key`, or `default` when the key is absent.
  function cls.get(this, key, default)
    local val = this.items[key]
    if val == nil then
      return default
    end
    return val
  end

  --- Store `value` under `key`.
  function cls.set(this, key, value)
    this.items[key] = value
  end

  --- Return a new array of all keys (order follows `pairs`).
  function cls.keys(this)
    local keys = {}
    for k in pairs(this.items) do
      keys[#keys + 1] = k
    end
    return keys
  end
end

--- Set of values, stored as `items[value] = true`.
local function SetDef(cls)
  --- @tparam[opt] table items sequence of initial members (read with ipairs)
  function cls.__init(this, items)
    this.items = {}
    if type(items) == "table" then
      for _, v in ipairs(items) do
        this.items[v] = true
      end
    end
  end

  --- Number of members, counted by traversing with `pairs`.
  function cls.__len(this)
    local count = 0
    for _ in pairs(this.items) do
      count = count + 1
    end
    return count
  end

  --- Render as "Set({a, b})"; member order follows `pairs`.
  function cls.__tostring(this)
    local parts = {}
    for k in pairs(this.items) do
      parts[#parts + 1] = tostring(k)
    end
    return "Set({" .. table.concat(parts, ", ") .. "})"
  end

  --- Add `value` to the set.
  function cls.add(this, value)
    this.items[value] = true
  end

  --- @treturn boolean true when `value` is a member
  function cls.has(this, value)
    return this.items[value] == true
  end

  --- Remove `value` from the set.
  function cls.remove(this, value)
    this.items[value] = nil
  end
end

--- Boolean primitive: stores `not not value` in `value`.
local function BoolDef(cls)
  function cls.__init(this, value)
    rawset(this, "__immutable", false)
    this.value = not not value
    rawset(this, "__immutable", true)
  end

  function cls.__tostring(this)
    return tostring(this.value)
  end

  function cls.__eq(a, b)
    return a.value == b.value
  end
end

--- String primitive: stores `tostring(value or "")` in `value`.
local function StrDef(cls)
  function cls.__init(this, value)
    rawset(this, "__immutable", false)
    this.value = tostring(value or "")
    rawset(this, "__immutable", true)
  end

  --- Byte length of the stored string.
  function cls.__len(this)
    return #this.value
  end

  --- Concatenation stringifies both operands and returns a new Str.
  function cls.__concat(a, b)
    return Types.Str(tostring(a) .. tostring(b))
  end

  function cls.__tostring(this)
    return this.value
  end
end

-- Class construction and exports. Tupple, Que and String are aliases that
-- refer to the same class objects as Tuple, Queue and Str.
Types.Int = make_primitive(IntDef, "Int")
Types.Float = make_primitive(FloatDef, "Float")

Types.List = oop.class(ListDef)
Types.List.__name = "List"

Types.Tuple = make_primitive(TupleDef, "Tuple")
Types.Tupple = Types.Tuple

Types.Queue = oop.class(QueueDef)
Types.Queue.__name = "Queue"
Types.Que = Types.Queue

Types.Stack = oop.class(StackDef)
Types.Stack.__name = "Stack"

Types.Dict = oop.class(DictDef)
Types.Dict.__name = "Dict"

Types.Set = oop.class(SetDef)
Types.Set.__name = "Set"

Types.Bool = make_primitive(BoolDef, "Bool")
Types.Str = make_primitive(StrDef, "Str")
Types.String = Types.Str

return Types