diff --git a/oop/types.lua b/oop/types.lua index 6fdd12a..ef36efd 100644 --- a/oop/types.lua +++ b/oop/types.lua @@ -2,9 +2,31 @@ local oop = require("oop") local Types = {} +local function freeze_table(t) + return setmetatable({}, { + __index = t, + __newindex = function() + error("Tuple is immutable", 2) + end, + __len = function() + return #t + end, + __pairs = function() + return pairs(t) + end, + __ipairs = function() + return ipairs(t) + end + }) +end + local function IntDef(cls) + local base_newindex = cls.__newindex + function cls.__init(this, value) + rawset(this, "__frozen", false) this.value = math.floor(tonumber(value) or 0) + rawset(this, "__frozen", true) end function cls.__add(a, b) @@ -36,11 +58,22 @@ local function IntDef(cls) function cls.to_number(this) return this.value end + + function cls.__newindex(this, key, value) + if rawget(this, "__frozen") and key ~= "__fields" then + error("Int is immutable", 2) + end + return base_newindex(this, key, value) + end end local function FloatDef(cls) + local base_newindex = cls.__newindex + function cls.__init(this, value) + rawset(this, "__frozen", false) this.value = tonumber(value) or 0.0 + rawset(this, "__frozen", true) end function cls.__add(a, b) @@ -72,6 +105,13 @@ local function FloatDef(cls) function cls.to_number(this) return this.value end + + function cls.__newindex(this, key, value) + if rawget(this, "__frozen") and key ~= "__fields" then + error("Float is immutable", 2) + end + return base_newindex(this, key, value) + end end local function ListDef(cls) @@ -161,10 +201,11 @@ local function TupleDef(cls) local base_newindex = cls.__newindex function cls.__init(this, items) + rawset(this, "__frozen", false) if type(items) == "table" then - this.items = { table.unpack(items) } + this.items = freeze_table({ table.unpack(items) }) else - this.items = {} + this.items = freeze_table({}) end rawset(this, "__frozen", true) end @@ -372,8 +413,12 @@ local function SetDef(cls) end local function BoolDef(cls) + local base_newindex = cls.__newindex + function cls.__init(this, value) + rawset(this, "__frozen", false) this.value = not not value + rawset(this, "__frozen", true) end function cls.__tostring(this) @@ -383,11 +428,22 @@ local function BoolDef(cls) function cls.__eq(a, b) return a.value == b.value end + + function cls.__newindex(this, key, value) + if rawget(this, "__frozen") and key ~= "__fields" then + error("Bool is immutable", 2) + end + return base_newindex(this, key, value) + end end local function StrDef(cls) + local base_newindex = cls.__newindex + function cls.__init(this, value) + rawset(this, "__frozen", false) this.value = tostring(value or "") + rawset(this, "__frozen", true) end function cls.__len(this) @@ -401,6 +457,13 @@ local function StrDef(cls) function cls.__tostring(this) return this.value end + + function cls.__newindex(this, key, value) + if rawget(this, "__frozen") and key ~= "__fields" then + error("Str is immutable", 2) + end + return base_newindex(this, key, value) + end end Types.Int = oop.class(IntDef)