-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathstore.go
101 lines (81 loc) · 2.33 KB
/
store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package postgres
import (
"context"
"database/sql"
"time"
"github.com/jmoiron/sqlx"
"github.com/code-payments/code-server/pkg/database/query"
"github.com/code-payments/code-server/pkg/code/data/currency"
pg "github.com/code-payments/code-server/pkg/database/postgres"
)
// store implements currency.Store on top of Postgres, issuing every query
// through a single sqlx handle.
type store struct {
	db *sqlx.DB
}

// New wraps the caller-supplied *sql.DB in an sqlx handle registered under
// the "pgx" driver name and returns it as a Postgres-backed currency.Store.
// The underlying *sql.DB is shared, not copied; its lifecycle stays with the
// caller.
func New(db *sql.DB) currency.Store {
	handle := sqlx.NewDb(db, "pgx")
	s := &store{}
	s.db = handle
	return s
}
func (s store) Put(ctx context.Context, obj *currency.MultiRateRecord) error {
return pg.ExecuteInTx(ctx, s.db, sql.LevelDefault, func(tx *sqlx.Tx) error {
// Loop through all rates and save individual records (within a transaction)
for symbol, item := range obj.Rates {
err := toModel(¤cy.ExchangeRateRecord{
Time: obj.Time,
Rate: item,
Symbol: symbol,
}).txSave(ctx, tx)
if err != nil {
return pg.CheckUniqueViolation(err, currency.ErrExists)
}
}
return nil
})
}
// Get looks up the exchange-rate row for symbol relative to time t via
// dbGetBySymbolAndTime (descending ordering) and converts it to a
// currency.ExchangeRateRecord. Lookup errors are returned unchanged.
func (s store) Get(ctx context.Context, symbol string, t time.Time) (*currency.ExchangeRateRecord, error) {
	row, err := dbGetBySymbolAndTime(ctx, s.db, symbol, t, query.Descending)
	if err == nil {
		return fromModel(row), nil
	}
	return nil, err
}
func (s store) GetAll(ctx context.Context, t time.Time) (*currency.MultiRateRecord, error) {
list, err := dbGetAllByTime(ctx, s.db, t, query.Descending)
if err != nil {
return nil, err
}
res := ¤cy.MultiRateRecord{
Time: list[0].ForTimestamp,
Rates: map[string]float64{},
}
for _, item := range list {
res.Rates[item.CurrencyCode] = item.CurrencyRate
}
return res, nil
}
// GetRange returns the exchange-rate rows for symbol between start and end,
// fetched by dbGetAllForRange at the given interval and in the given
// ordering.
//
// Validation:
//   - interval values greater than query.IntervalMonth yield
//     currency.ErrInvalidInterval;
//   - a zero start or end yields currency.ErrInvalidRange.
//
// start and end may be passed in either order; the earlier instant is always
// used as the lower bound. Database errors are returned unchanged. An empty
// result is returned as a non-nil, zero-length slice.
func (s store) GetRange(ctx context.Context, symbol string, interval query.Interval, start time.Time, end time.Time, ordering query.Ordering) ([]*currency.ExchangeRateRecord, error) {
	if interval > query.IntervalMonth {
		return nil, currency.ErrInvalidInterval
	}
	if start.IsZero() || end.IsZero() {
		return nil, currency.ErrInvalidRange
	}

	// Normalise the bounds. Compare full instants with After rather than
	// Unix() seconds, so bounds that differ only below one second are still
	// put in the right order.
	if start.After(end) {
		start, end = end, start
	}

	// TODO: check that the range is reasonable
	list, err := dbGetAllForRange(ctx, s.db, symbol, interval, start, end, ordering)
	if err != nil {
		return nil, err
	}

	res := make([]*currency.ExchangeRateRecord, 0, len(list))
	for _, item := range list {
		res = append(res, fromModel(item))
	}
	return res, nil
}