summaryrefslogtreecommitdiff
path: root/persistence/sqlite/sqliteRepository.go
blob: 9e0e3f5bfa9544e2959b9eb07b1abafb7142c455 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package sqlite

import (
	"context"
	"database/sql"
	"log"
	"os"

	_ "modernc.org/sqlite"

	"uvok.de/go/training_fellow/registration"
)

// sqliteRepository stores registrations in a SQLite database file.
//
// fileName is the data source name passed to sql.Open("sqlite", ...), and
// context is the Context handed to the repository's context-aware database
// calls. The repository holds no open connection itself; the methods open a
// *sql.DB on demand through checkInit.
type sqliteRepository struct {
	fileName string
	context  context.Context
}

// NewRepository returns a repository backed by the SQLite database at
// fileName. ctx is stored and later passed to the repository's
// context-aware database calls.
//
// NOTE(review): the context package docs advise against storing a Context in
// a struct and recommend passing it as the first argument of each call;
// consider migrating the API when callers can be updated.
func NewRepository(fileName string, ctx context.Context) *sqliteRepository {
	// ctx (not "context") avoids shadowing the imported context package.
	return &sqliteRepository{fileName: fileName, context: ctx}
}

// SaveRegistration inserts reg as a new row in the registrations table,
// writing all of its fields including Confirmed.
//
// The database is opened for this call only and closed before returning.
// Errors from opening/pinging the database or from the INSERT are returned
// unchanged.
func (repo *sqliteRepository) SaveRegistration(reg *registration.Registration) error {
	// reg (not "registration") avoids shadowing the imported registration package.
	db, err := repo.checkInit()
	if err != nil {
		return err
	}
	defer db.Close()

	// Placeholders keep the registration's values out of the SQL text.
	_, err = db.ExecContext(
		repo.context,
		`INSERT INTO registrations
		  (reg_id, first_name, last_name, email, company, training_code, date, privacy_policy_accepted, confirmed)
		  VALUES (?,?,?,?,?,?,?,?,?)`,
		reg.RegId,
		reg.FirstName,
		reg.LastName,
		reg.Email,
		reg.Company,
		reg.TrainingCode,
		reg.Date,
		reg.PrivacyPolicyAccepted,
		reg.Confirmed,
	)

	return err
}


// GetUnconfirmedRegistrations returns every registration whose confirmed
// column is 0.
//
// The confirmed column is not selected, so the Confirmed field of each
// returned value is left at its zero value. A row that fails to scan is
// logged and skipped rather than returned half-populated. An error reported
// by the result set after iteration is returned. On success the slice is
// non-nil (possibly empty).
func (repo *sqliteRepository) GetUnconfirmedRegistrations() ([]*registration.Registration, error) {
	db, err := repo.checkInit()
	if err != nil {
		return nil, err
	}
	// Deferred before rows.Close, so it runs after the rows are released.
	defer db.Close()

	rows, err := db.QueryContext(
		repo.context,
		`SELECT reg_id, first_name, last_name, email, company, training_code, date, privacy_policy_accepted
		FROM registrations
		WHERE confirmed = 0
		`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	regs := make([]*registration.Registration, 0)
	for rows.Next() {
		reg := &registration.Registration{}
		if err := rows.Scan(&reg.RegId, &reg.FirstName, &reg.LastName, &reg.Email, &reg.Company, &reg.TrainingCode, &reg.Date, &reg.PrivacyPolicyAccepted); err != nil {
			// Best effort: keep going, but don't hand out a partially filled value.
			log.Printf("Error scanning: %v", err)
			continue
		}
		regs = append(regs, reg)
	}
	// Next returns false both at the end of the rows and on failure;
	// Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return regs, nil
}

// ConfirmRegistration is not implemented yet.
//
// NOTE(review): presumably meant to mark the registration identified by
// registrationId as confirmed and return it — confirm the intended contract.
//
// Currently it opens and pings the database, returning any error from that
// step. Otherwise it closes the handle and panics with "not implemented".
func (repo *sqliteRepository) ConfirmRegistration(registrationId string) (*registration.Registration, error) {
	db, err := repo.checkInit()
	if err != nil {
		return nil, err
	}
	// Deferred calls still run while the panic unwinds, so the handle is released.
	defer db.Close()

	panic("not implemented") // TODO: Implement
}

// checkInit opens the SQLite database named by repo.fileName with the
// "sqlite" driver and verifies the connection by pinging it with
// repo.context.
//
// On success the caller owns the returned *sql.DB and must Close it. On
// failure nil is returned with the error. Any handle created by sql.Open is
// closed here so it does not leak.
func (repo *sqliteRepository) checkInit() (*sql.DB, error) {
	db, err := sql.Open("sqlite", repo.fileName)
	if err != nil {
		return nil, err
	}
	if err := db.PingContext(repo.context); err != nil {
		// The ping error is the one worth reporting; a Close error would only
		// obscure it, so it is deliberately dropped.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}

// Migrate brings the database schema up to date.
//
// It reads SQLite's PRAGMA user_version. When the value is 0 (a fresh
// database), it reads ./0001_Create_Initial_Tables.sql and executes its
// contents. Any other version is left alone. The script path is relative, so
// it is resolved against the process's current working directory.
//
// NOTE(review): this function does not set user_version itself; presumably
// the script does — confirm, otherwise the script runs again on every call.
func (repo *sqliteRepository) Migrate() error {
	db, err := repo.checkInit()
	if err != nil {
		return err
	}
	defer db.Close()

	var schemaVersion int64
	if err := db.QueryRowContext(repo.context, "PRAGMA user_version").Scan(&schemaVersion); err != nil {
		return err
	}
	if schemaVersion != 0 {
		return nil
	}

	// Initial / fresh table
	script, err := os.ReadFile("./0001_Create_Initial_Tables.sql")
	if err != nil {
		return err
	}
	_, err = db.ExecContext(repo.context, string(script))
	return err
}