diff --git a/src/db/Migration.java b/src/db/Migration.java deleted file mode 100644 index fd2ee7f..0000000 --- a/src/db/Migration.java +++ /dev/null @@ -1,59 +0,0 @@ -package src.db; - -import java.sql.Connection; -import java.sql.Statement; - -public class Migration { - public static void run(Connection conn) { - try (Statement stmt = conn.createStatement()) { - stmt.execute(""" - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - email TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL - ) - """); - - stmt.execute(""" - CREATE TABLE IF NOT EXISTS accounts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - owner_id INTEGER NOT NULL, - type TEXT CHECK(type IN ('CHECKING','SAVINGS','CREDIT')) NOT NULL, - account_number TEXT NOT NULL UNIQUE, - bank_code TEXT NOT NULL, - balance REAL NOT NULL DEFAULT 0, - FOREIGN KEY(owner_id) REFERENCES users(id) - ) - """); - - stmt.execute(""" - CREATE TABLE IF NOT EXISTS giro_accounts ( - id INTEGER PRIMARY KEY, - overdraft_limit REAL DEFAULT 0, - FOREIGN KEY(id) REFERENCES accounts(id) ON DELETE CASCADE - ) - """); - - stmt.execute(""" - CREATE TABLE IF NOT EXISTS spar_accounts ( - id INTEGER PRIMARY KEY, - interest_rate REAL DEFAULT 0, - FOREIGN KEY(id) REFERENCES accounts(id) ON DELETE CASCADE - ) - """); - - stmt.execute(""" - CREATE TABLE IF NOT EXISTS kredit_accounts ( - id INTEGER PRIMARY KEY, - credit_limit REAL DEFAULT 0, - repayment_plan TEXT, - FOREIGN KEY(id) REFERENCES accounts(id) ON DELETE CASCADE - ) - """); - - } catch (Exception e) { - e.printStackTrace(); - } - } -} diff --git a/src/models/Model.java b/src/models/Model.java deleted file mode 100644 index e99ecab..0000000 --- a/src/models/Model.java +++ /dev/null @@ -1,16 +0,0 @@ -package src.models; - -import java.sql.Connection; -import src.db.Database; - -public abstract class Model { - protected Connection conn; - - public Model() { - try { - this.conn = Database.getConnection(); - } catch (java.sql.SQLException e) { - throw new RuntimeException("Failed to get DB connection", e); - } - } -} diff --git a/src/db/Database.java b/src/models/squirrel/Database.java similarity index 76% rename from src/db/Database.java rename to src/models/squirrel/Database.java index 92495ce..a7cd178 100644 --- a/src/db/Database.java +++ b/src/models/squirrel/Database.java @@ -1,19 +1,22 @@ -package src.db; +package src.models.squirrel; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; +import java.util.function.Consumer; public class Database { private static Connection conn; + public static Consumer migrate; public static Connection getConnection() throws SQLException { if (conn == null || conn.isClosed()) { String url = "jdbc:sqlite:instance/test.db"; conn = DriverManager.getConnection(url); - Migration.run(conn); + if (migrate != null) { + migrate.accept(conn); + } } - return conn; } diff --git a/src/models/squirrel/Model.java b/src/models/squirrel/Model.java new file mode 100644 index 0000000..68970a0 --- /dev/null +++ b/src/models/squirrel/Model.java @@ -0,0 +1,259 @@ +package src.models.squirrel; + +import java.sql.*; +import java.util.*; + +public abstract class Model { + protected static Connection conn; + + protected String tableName; + protected Set columns = new HashSet<>(); + protected Map attributes = new LinkedHashMap<>(); + + protected boolean rowMode = false; + + public Model() { + if (conn == null) { + try { + conn = Database.getConnection(); + } catch (SQLException e) { + throw new RuntimeException("Failed to get DB connection for model '" + getTableName() + "'", e); + } + } + } + + // ------------------------------- + // Query (class/global mode only) + // ------------------------------- + public List all() { + ensureClassMode("all()"); + String sql = "SELECT * FROM " + getTableName(); + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + ResultSet rs = stmt.executeQuery(); + List results = new ArrayList<>(); + while (rs.next()) { + @SuppressWarnings("unchecked") + T instance = (T) this.getClass().getDeclaredConstructor().newInstance(); + instance.seedFromRow(rs); + instance.rowMode = true; + results.add(instance); + } + rs.close(); + return results; + } catch (Exception e) { + throw new RuntimeException("Failed to execute all() for model '" + getTableName() + "'", e); + } + } + + public T find(int id) { + ensureClassMode("find()"); + String sql = "SELECT * FROM " + getTableName() + " WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + stmt.setInt(1, id); + ResultSet rs = stmt.executeQuery(); + if (rs.next()) { + @SuppressWarnings("unchecked") + T instance = (T) this.getClass().getDeclaredConstructor().newInstance(); + instance.seedFromRow(rs); + instance.rowMode = true; + rs.close(); + return instance; + } + } catch (Exception e) { + throw new RuntimeException("Failed to execute find() for model '" + getTableName() + "'", e); + } + return null; + } + + public List where(Map filters) { + ensureClassMode("where()"); + StringBuilder where = new StringBuilder("1=1"); + List values = new ArrayList<>(); + for (var entry : filters.entrySet()) { + where.append(" AND ").append(entry.getKey()).append(" = ?"); + values.add(entry.getValue()); + } + + String sql = "SELECT * FROM " + getTableName() + " WHERE " + where; + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + for (int i = 0; i < values.size(); i++) { + stmt.setObject(i + 1, values.get(i)); + } + ResultSet rs = stmt.executeQuery(); + List results = new ArrayList<>(); + while (rs.next()) { + @SuppressWarnings("unchecked") + T instance = (T) this.getClass().getDeclaredConstructor().newInstance(); + instance.seedFromRow(rs); + instance.rowMode = true; + results.add(instance); + } + rs.close(); + return results; + } catch (Exception e) { + throw new RuntimeException("Failed to execute where() for model '" + getTableName() + "'", e); + } + } + + // ------------------------------- + // Row-only actions + // ------------------------------- + public boolean create() { + ensureRowMode("create()"); + validateColumns(attributes.keySet()); + StringBuilder cols = new StringBuilder(); + StringBuilder vals = new StringBuilder(); + List params = new ArrayList<>(); + for (String col : attributes.keySet()) { + if (col.equals("id")) continue; + if (cols.length() > 0) { + cols.append(", "); + vals.append(", "); + } + cols.append(col); + vals.append("?"); + params.add(attributes.get(col)); + } + String sql = "INSERT INTO " + getTableName() + " (" + cols + ") VALUES (" + vals + ")"; + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { + for (int i = 0; i < params.size(); i++) { + stmt.setObject(i + 1, params.get(i)); + } + int affected = stmt.executeUpdate(); + if (affected == 0) { + throw new RuntimeException("Failed to insert row for model '" + getTableName() + "'"); + } + ResultSet keys = stmt.getGeneratedKeys(); + if (keys.next()) { + set("id", keys.getInt(1)); + } + keys.close(); + return true; + } catch (SQLException e) { + throw new RuntimeException("Failed to execute create() for model '" + getTableName() + "'", e); + } + } + + public boolean save() { + ensureRowMode("save()"); + if (!attributes.containsKey("id")) { + throw new IllegalStateException("No id set for row instance"); + } + int id = (int) attributes.get("id"); + + StringBuilder set = new StringBuilder(); + List values = new ArrayList<>(); + + for (var entry : attributes.entrySet()) { + if (entry.getKey().equals("id")) continue; + if (set.length() > 0) set.append(", "); + set.append(entry.getKey()).append("=?"); + values.add(entry.getValue()); + } + + String sql = "UPDATE " + getTableName() + " SET " + set + " WHERE id=?"; + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + for (int i = 0; i < values.size(); i++) { + stmt.setObject(i + 1, values.get(i)); + } + stmt.setInt(values.size() + 1, id); + return stmt.executeUpdate() > 0; + } catch (SQLException e) { + throw new RuntimeException("Failed to execute save() for model '" + getTableName() + "'", e); + } + } + + public boolean delete() { + ensureRowMode("delete()"); + if (!attributes.containsKey("id")) { + throw new IllegalStateException("No id set for row instance"); + } + int id = (int) attributes.get("id"); + String sql = "DELETE FROM " + getTableName() + " WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + stmt.setInt(1, id); + return stmt.executeUpdate() > 0; + } catch (SQLException e) { + throw new RuntimeException("Failed to execute delete() for model '" + getTableName() + "'", e); + } + } + + public Object get(String column) { + ensureRowMode("get()"); + return attributes.get(column); + } + + public void set(String column, Object value) { + ensureRowMode("set()"); + if (!columns.contains(column)) { + throw new IllegalArgumentException("Unknown column '" + column + "' for model '" + getTableName() + "'"); + } + attributes.put(column, value); + } + + // ------------------------------- + // Helpers + // ------------------------------- + protected void seedFromRow(ResultSet rs) { + try { + attributes = new LinkedHashMap<>(); + ResultSetMetaData meta = rs.getMetaData(); + for (int i = 1; i <= meta.getColumnCount(); i++) { + attributes.put(meta.getColumnName(i), rs.getObject(i)); + } + } catch (SQLException e) { + throw new RuntimeException("Failed to seed from row for model '" + getTableName() + "'", e); + } + } + + protected void validateColumns(Collection keys) { + for (String key : keys) { + if (!columns.contains(key)) { + throw new IllegalArgumentException( + "Unknown column '" + key + "' for model '" + getTableName() + "'" + ); + } + } + } + + protected String getTableName() { + if (this.tableName != null && !this.tableName.isEmpty()) { + return this.tableName; + } + String className = this.getClass().getSimpleName(); + return className.toLowerCase() + "s"; + } + + private void ensureRowMode(String action) { + if (!rowMode) { + throw new IllegalStateException(action + " requires a row-bound instance"); + } + } + + private void ensureClassMode(String action) { + if (rowMode) { + throw new IllegalStateException(action + " requires a class/global instance"); + } + } + + // ------------------------------- + // toString + // ------------------------------- + @Override + public String toString() { + if (!rowMode) { + return "[Model: " + getClass().getSimpleName() + " | table=" + getTableName() + "]"; + } + StringBuilder sb = new StringBuilder(getClass().getSimpleName() + " { "); + if (attributes.containsKey("id")) { + sb.append("id: ").append(attributes.get("id")).append(", "); + } + for (var entry : attributes.entrySet()) { + if (entry.getKey().equals("id")) continue; + sb.append(entry.getKey()).append(": ").append(entry.getValue()).append(", "); + } + if (!attributes.isEmpty()) sb.setLength(sb.length() - 2); + sb.append(" }"); + return sb.toString(); + } +} diff --git a/src/models/squirrel/ModelManager.java b/src/models/squirrel/ModelManager.java new file mode 100644 index 0000000..78c8e65 --- /dev/null +++ b/src/models/squirrel/ModelManager.java @@ -0,0 +1,31 @@ +package src.models.squirrel; + +import java.util.*; + +import src.models.UserModel; + +public class ModelManager { + private static final Map, Model> models = new HashMap<>(); + + public static void initializeModels() { + List> modelClasses = Arrays.asList( + UserModel.class + ); + for (Class clazz : modelClasses) { + try { + Model instance = clazz.getDeclaredConstructor().newInstance(); + models.put(clazz, instance); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + public static T get(Class clazz) { + Model instance = models.get(clazz); + if (instance == null) { + throw new IllegalStateException("Model not initialized: " + clazz.getSimpleName()); + } + return clazz.cast(instance); + } +}