package sqlite import ( "context" "database/sql" "fmt" "log" "os" _ "modernc.org/sqlite" "uvok.de/go/training_fellow/registration" ) type sqliteRepository struct { fileName string context context.Context } func NewRepository(fileName string, context context.Context) *sqliteRepository { repo := sqliteRepository{fileName: fileName, context: context} return &repo } func (repo *sqliteRepository) SaveRegistration(registration *registration.Registration) error { db, err := repo.checkInit() if err != nil { return err } defer db.Close() _, err = db.ExecContext( repo.context, `INSERT INTO registrations (reg_id, first_name, last_name, email, company, training_code, date, privacy_policy_accepted, confirmed) VALUES (?,?,?,?,?,?,?,?,?)`, registration.RegId, registration.FirstName, registration.LastName, registration.Email, registration.Company, registration.TrainingCode, registration.Date, registration.PrivacyPolicyAccepted, registration.Confirmed) return err } func rowToRegistration(scanner interface{ Scan(dest ...any) error }) (*registration.Registration, error) { var reg registration.Registration err := scanner.Scan(®.RegId, ®.FirstName, ®.LastName, ®.Email, ®.Company, ®.TrainingCode, ®.Date, ®.PrivacyPolicyAccepted) if err != nil { return nil, err } return ®, nil } func (repo *sqliteRepository) GetUnconfirmedRegistrations() ([]*registration.Registration, error) { db, err := repo.checkInit() if err != nil { return nil, err } res, err := db.QueryContext( repo.context, `SELECT reg_id, first_name, last_name, email, company, training_code, date, privacy_policy_accepted FROM registrations WHERE confirmed = 0 `) if err != nil { return nil, err } defer res.Close() regArray := make([]*registration.Registration, 0) for res.Next() { reg, err := rowToRegistration(res) if err != nil { log.Printf("Error scanning: %v", err) continue } regArray = append(regArray, reg) } return regArray, nil } func (repo *sqliteRepository) ConfirmRegistration(registrationId string) (*registration.Registration, error) { db, err := repo.checkInit() if err != nil { return nil, err } defer db.Close() transaction, err := db.BeginTx(repo.context, nil) if err != nil { return nil, err } defer transaction.Rollback() updateRes, err := transaction.Exec( "UPDATE registrations SET confirmed=1 WHERE reg_id=? AND confirmed=0", registrationId) if err != nil { return nil, err } count, err := updateRes.RowsAffected() switch { case err != nil: return nil, err case count == 0: return nil, sql.ErrNoRows case count > 1: newError := fmt.Errorf("more than one result for registration ID %v", registrationId) return nil, newError } transaction.Commit() queryRes := db.QueryRowContext( repo.context, `SELECT reg_id, first_name, last_name, email, company, training_code, date, privacy_policy_accepted FROM registrations WHERE reg_id=? `, registrationId) reg, err := rowToRegistration(queryRes) // Not included in result set reg.Confirmed = true switch { // case err == sql.ErrNoRows: // return nil, err case err != nil: return nil, err default: return reg, nil } } func (repo *sqliteRepository) checkInit() (*sql.DB, error) { db, err := sql.Open("sqlite", repo.fileName) if err != nil { return nil, err } err = db.PingContext(repo.context) if err != nil { return nil, err } return db, nil } func (repo *sqliteRepository) Migrate() error { db, err := repo.checkInit() if err != nil { return err } var schema_version int64 err = db.QueryRow("PRAGMA user_version").Scan(&schema_version) if err != nil { return err } // Initial / fresh table if schema_version == 0 { var sql []byte sql, err = os.ReadFile("./0001_Create_Initial_Tables.sql") if err != nil { return err } sqlString := string(sql) _, err = db.Exec(sqlString) } return err }