-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathverifier.go
142 lines (120 loc) · 3.69 KB
/
verifier.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package ios
import (
"context"
"errors"
"strings"
devicecheck "github.com/rinchsan/device-check-go"
"github.com/code-payments/code-server/pkg/device"
"github.com/code-payments/code-server/pkg/grpc/client"
"github.com/code-payments/code-server/pkg/metrics"
)
// metricsStructName is the struct name reported to metrics.TraceMethodCall
// for every traced method of this verifier.
const metricsStructName = "device.ios.verifier"
// AppleEnv selects which Apple DeviceCheck environment NewIOSDeviceVerifier
// configures its DeviceCheck client for.
type AppleEnv uint8

// Known AppleEnv values. The zero value is AppleEnvDevelopment.
const (
	// AppleEnvDevelopment maps to the DeviceCheck development environment.
	AppleEnvDevelopment AppleEnv = 0
	// AppleEnvProduction maps to the DeviceCheck production environment.
	AppleEnvProduction AppleEnv = 1
)
// iOSDeviceVerifier is the device.Verifier implementation for iOS clients,
// backed by Apple's DeviceCheck API.
//
// Current two-bit configuration:
//   - bit0: Free account flag
//   - bit1: Unused
//
// todo: May need a small refactor if bit1 is used (MarkCreatedFreeAccount
// currently always writes bit1 as false).
type iOSDeviceVerifier struct {
	// DeviceCheck API client used for token validation and two-bit reads/writes.
	client *devicecheck.Client
	// Oldest client version IsValid accepts; older user agents are rejected.
	minVersion *client.Version
}
// NewIOSDeviceVerifier returns a new device.Verifier for iOS devices that
// checks device tokens against Apple's DeviceCheck API.
//
// Parameters:
//   - env selects the DeviceCheck environment (development or production).
//   - keyIssuer and keyId are passed to the DeviceCheck client configuration.
//   - privateKeyFile is the path handed to the DeviceCheck credential loader.
//   - minVersion is the oldest client version IsValid will accept.
//
// An error is returned if env is not one of the known AppleEnv values.
func NewIOSDeviceVerifier(
	env AppleEnv,
	keyIssuer string,
	keyId string,
	privateKeyFile string,
	minVersion *client.Version,
) (device.Verifier, error) {
	var dcEnv devicecheck.Environment
	switch env {
	case AppleEnvDevelopment:
		dcEnv = devicecheck.Development
	case AppleEnvProduction:
		dcEnv = devicecheck.Production
	default:
		return nil, errors.New("invalid environment")
	}

	// Named dcClient rather than client so the local does not shadow the
	// imported grpc/client package for the remainder of the function.
	dcClient := devicecheck.New(
		devicecheck.NewCredentialFile(privateKeyFile),
		devicecheck.NewConfig(keyIssuer, keyId, dcEnv),
	)

	return &iOSDeviceVerifier{
		client:     dcClient,
		minVersion: minVersion,
	}, nil
}
// IsValid implements device.Verifier.IsValid.
//
// A token is accepted only when the request carries an iOS user agent whose
// version is not before the configured minimum and DeviceCheck accepts the
// token. Every rejection is reported as (false, reason, nil). An error is
// returned only when DeviceCheck fails in a way that cannot be classified as
// a bad token (e.g. an API or network problem); such errors are also
// recorded on the trace.
func (v *iOSDeviceVerifier) IsValid(ctx context.Context, token string) (bool, string, error) {
	tracer := metrics.TraceMethodCall(ctx, metricsStructName, "IsValid")
	defer tracer.End()

	check := func() (bool, string, error) {
		ua, uaErr := client.GetUserAgent(ctx)
		// Cases are evaluated in order, so ua is only dereferenced once
		// uaErr is known to be nil.
		switch {
		case uaErr != nil:
			return false, "user agent not set", nil
		case ua.DeviceType != client.DeviceTypeIOS:
			return false, "user agent is not ios", nil
		case ua.Version.Before(v.minVersion):
			return false, "minimum client version not met", nil
		}

		apiErr := v.client.ValidateDeviceToken(token)
		if apiErr == nil {
			return true, "", nil
		}

		// DeviceCheck reports token problems only through the error text, so
		// the message has to be inspected to tell an invalid token apart from
		// an API/network failure.
		//
		// https://developer.apple.com/documentation/devicecheck/accessing_and_modifying_per-device_data#2910408
		lowered := strings.ToLower(apiErr.Error())
		for _, marker := range []string{
			"bad device token",
			"missing or incorrectly formatted device token payload",
		} {
			if strings.Contains(lowered, marker) {
				return false, "invalid device token", nil
			}
		}
		return false, "", apiErr
	}

	valid, reason, err := check()
	if err != nil {
		tracer.OnError(err)
	}
	return valid, reason, err
}
// HasCreatedFreeAccount implements device.Verifier.HasCreatedFreeAccount.
//
// The answer is bit0 of the device's DeviceCheck two-bit state. When
// DeviceCheck reports that no bit state exists for the device, the result is
// treated as false rather than as an error. Any other query failure is
// recorded on the trace and returned.
func (v *iOSDeviceVerifier) HasCreatedFreeAccount(ctx context.Context, token string) (bool, error) {
	tracer := metrics.TraceMethodCall(ctx, metricsStructName, "HasCreatedFreeAccount")
	defer tracer.End()

	query := func() (bool, error) {
		var bits devicecheck.QueryTwoBitsResult
		queryErr := v.client.QueryTwoBits(token, &bits)
		switch {
		case queryErr == nil:
			return bits.Bit0, nil
		case strings.Contains(strings.ToLower(queryErr.Error()), "bit state not found"):
			// No stored state for this device: it was never marked.
			return false, nil
		default:
			return false, queryErr
		}
	}

	created, err := query()
	if err != nil {
		tracer.OnError(err)
	}
	return created, err
}
// MarkCreatedFreeAccount implements device.Verifier.MarkCreatedFreeAccount.
//
// It writes the device's DeviceCheck two bits with bit0 (the free account
// flag) set to true and bit1 set to false. Because bit1 is always
// overwritten, this needs revisiting if bit1 is ever given a meaning. Any
// DeviceCheck error is recorded on the trace and returned unchanged.
func (v *iOSDeviceVerifier) MarkCreatedFreeAccount(ctx context.Context, token string) error {
	tracer := metrics.TraceMethodCall(ctx, metricsStructName, "MarkCreatedFreeAccount")
	defer tracer.End()

	const (
		freeAccountFlag = true  // bit0
		unusedFlag      = false // bit1
	)
	if err := v.client.UpdateTwoBits(token, freeAccountFlag, unusedFlag); err != nil {
		tracer.OnError(err)
		return err
	}
	return nil
}