-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathapinode_client.py
262 lines (203 loc) · 12.7 KB
/
apinode_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import warnings
from .base_client import DSSBaseClient
class APINodeClient(DSSBaseClient):
    """Entry point for the DSS API Node client

    This is an API client for the user-facing API of DSS API Node server (user facing API)
    """

    def __init__(self, uri, service_id, api_key=None, bearer_token=None, no_check_certificate=False, **kwargs):
        """
        Instantiate a new DSS API client on the given base URI with the given API key.

        :param str uri: Base URI of the DSS API node server (http://host:port/ or https://host:port/)
        :param str service_id: Identifier of the service to query
        :param str api_key: Optional, API key for the service. Only required if the service has its authorization setup to API keys
        :param str bearer_token: Optional, The bearer token. Only required if the service has its authorization setup to OAuth2/JWT
        :param bool no_check_certificate: Optional, if True, disables TLS certificate verification when connecting to the server
        """
        if "insecure_tls" in kwargs:
            # Backward compatibility before removing insecure_tls option
            warnings.warn("insecure_tls field is now deprecated. It has been replaced by no_check_certificate.", DeprecationWarning)
            # Either flag set to True must disable certificate checking
            no_check_certificate = kwargs.get("insecure_tls") or no_check_certificate
        DSSBaseClient.__init__(self, "%s/%s" % (uri, "public/api/v1/%s" % service_id), api_key=api_key, bearer_token=bearer_token, no_check_certificate=no_check_certificate)

    @staticmethod
    def _set_dispatch(obj, forced_generation, dispatch_key):
        # Attach multi-version dispatch information to the request body.
        # forced_generation takes precedence over dispatch_key when both are given.
        if forced_generation is not None:
            obj["dispatch"] = {"forcedGeneration": forced_generation}
        elif dispatch_key is not None:
            obj["dispatch"] = {"dispatchKey": dispatch_key}

    def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None,
                       with_explanations=None, explanation_method=None, n_explanations=None, n_explanations_mc_steps=None):
        """
        Predicts a single record on a DSS API node endpoint (standard or custom prediction)

        :param str endpoint_id: Identifier of the endpoint to query
        :param features: Python dictionary of features of the record
        :param forced_generation: See documentation about multi-version prediction
        :param dispatch_key: See documentation about multi-version prediction
        :param context: Optional, Python dictionary of additional context information. The context information is logged, but not directly used.
        :param with_explanations: Optional, whether individual explanations should be computed for each record. The prediction endpoint must be compatible. If None, will use the value configured in the endpoint.
        :param explanation_method: Optional, method to compute explanations. Valid values are 'SHAPLEY' or 'ICE'. If None, will use the value configured in the endpoint.
        :param n_explanations: Optional, number of explanations to output per prediction. If None, will use the value configured in the endpoint.
        :param n_explanations_mc_steps: Optional, precision parameter for SHAPLEY method, higher means more precise but slower (between 25 and 1000).
                                        If None, will use the value configured in the endpoint.

        :return: a Python dict of the API answer. The answer contains a "result" key (itself a dict)
        """
        obj = {
            "features": features,
            "explanations": {
                "enabled": with_explanations,
                "method": explanation_method,
                "nExplanations": n_explanations,
                "nMonteCarloSteps": n_explanations_mc_steps
            }
        }
        self._set_dispatch(obj, forced_generation, dispatch_key)
        if context is not None:
            obj["context"] = context
        return self._perform_json("POST", "%s/predict" % endpoint_id, body=obj)

    def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None, with_explanations=None,
                        explanation_method=None, n_explanations=None, n_explanations_mc_steps=None):
        """
        Predicts a batch of records on a DSS API node endpoint (standard or custom prediction)

        :param str endpoint_id: Identifier of the endpoint to query
        :param records: Python list of records. Each record must be a Python dict. Each record must contain a "features" dict (see predict_record) and optionally a "context" dict.
        :param forced_generation: See documentation about multi-version prediction
        :param dispatch_key: See documentation about multi-version prediction
        :param with_explanations: Optional, whether individual explanations should be computed for each record. The prediction endpoint must be compatible. If None, will use the value configured in the endpoint.
        :param explanation_method: Optional, method to compute explanations. Valid values are 'SHAPLEY' or 'ICE'. If None, will use the value configured in the endpoint.
        :param n_explanations: Optional, number of explanations to output per prediction. If None, will use the value configured in the endpoint.
        :param n_explanations_mc_steps: Optional, precision parameter for SHAPLEY method, higher means more precise but slower (between 25 and 1000).
                                        If None, will use the value configured in the endpoint.

        :return: a Python dict of the API answer. The answer contains a "results" key (which is an array of result objects)
        """
        # Validate records client-side before issuing the request
        for record in records:
            if "features" not in record:
                raise ValueError("Each record must contain a 'features' dict")
        obj = {
            "items": records,
            "explanations": {
                "enabled": with_explanations,
                "method": explanation_method,
                "nExplanations": n_explanations,
                "nMonteCarloSteps": n_explanations_mc_steps
            }
        }
        self._set_dispatch(obj, forced_generation, dispatch_key)
        return self._perform_json("POST", "%s/predict-multi" % endpoint_id, body=obj)

    def forecast(self, endpoint_id, records, forced_generation=None, dispatch_key=None):
        """
        Forecast using a time series forecasting model on a DSS API node endpoint

        :param str endpoint_id: Identifier of the endpoint to query
        :param array records: List of time series data records to be used as an input for the
                              time series forecasting model. Each record should be a dict where
                              keys are feature names, and values feature values.

                              Example:

                              .. code-block:: python

                                    records = [
                                        {'date': '2015-01-04T00:00:00.000Z',
                                         'timeseries_id': 'A', 'target': 10.0},
                                        {'date': '2015-01-04T00:00:00.000Z',
                                         'timeseries_id': 'B', 'target': 4.5},
                                        {'date': '2015-01-05T00:00:00.000Z',
                                         'timeseries_id': 'A', 'target': 2.0},
                                        ...
                                        {'date': '2015-03-20T00:00:00.000Z',
                                         'timeseries_id': 'B', 'target': 1.3}
                                    ]

        :param forced_generation: See documentation about multi-version prediction
        :param dispatch_key: See documentation about multi-version prediction

        :return: a Python dict of the API answer. The answer contains a "results" key
                 (which is an array of result objects, corresponding to the forecast records)

                 Example:

                 .. code-block:: python

                    {'results': [
                        {'forecast': 12.57, 'ignored': False,
                         'quantiles': [0.0001, 0.5, 0.9999],
                         'quantilesValues': [3.0, 16.0, 16.0],
                         'time': '2015-03-21T00:00:00.000000Z',
                         'timeseriesIdentifier': {'timeseries_id': 'A'}},
                        {'forecast': 15.57, 'ignored': False,
                         'quantiles': [0.0001, 0.5, 0.9999],
                         'quantilesValues': [3.0, 18.0, 19.0],
                         'time': '2015-03-21T00:00:00.000000Z',
                         'timeseriesIdentifier': {'timeseries_id': 'B'}},
                        ...],
                     ...}
        """
        obj = {"items": records}
        self._set_dispatch(obj, forced_generation, dispatch_key)
        return self._perform_json("POST", "{}/forecast".format(endpoint_id), body=obj)

    def predict_effect(self, endpoint_id, features, forced_generation=None, dispatch_key=None):
        """
        Predicts the treatment effect of a single record on a DSS API node endpoint (standard causal prediction)

        :param str endpoint_id: Identifier of the endpoint to query
        :param features: Python dictionary of features of the record
        :param forced_generation: See documentation about multi-version prediction
        :param dispatch_key: See documentation about multi-version prediction

        :return: a Python dict of the API answer. The answer contains a "result" key (itself a dict)
        """
        obj = {
            "features": features,
        }
        self._set_dispatch(obj, forced_generation, dispatch_key)
        return self._perform_json("POST", "%s/predict-effect" % endpoint_id, body=obj)

    def predict_effects(self, endpoint_id, records, forced_generation=None, dispatch_key=None):
        """
        Predicts the treatment effects on a batch of records on a DSS API node endpoint (standard causal prediction)

        :param str endpoint_id: Identifier of the endpoint to query
        :param records: Python list of records. Each record must be a Python dict. Each record must contain a "features" dict (see predict_record) and optionally a "context" dict.
        :param forced_generation: See documentation about multi-version prediction
        :param dispatch_key: See documentation about multi-version prediction

        :return: a Python dict of the API answer. The answer contains a "results" key (which is an array of result objects)
        """
        # Validate records client-side before issuing the request
        for record in records:
            if "features" not in record:
                raise ValueError("Each record must contain a 'features' dict")
        obj = {
            "items": records,
        }
        self._set_dispatch(obj, forced_generation, dispatch_key)
        return self._perform_json("POST", "%s/predict-effect-multi" % endpoint_id, body=obj)

    def sql_query(self, endpoint_id, parameters):
        """
        Queries a "SQL query" endpoint on a DSS API node

        :param str endpoint_id: Identifier of the endpoint to query
        :param parameters: Python dictionary of the named parameters for the SQL query endpoint

        :return: a Python dict of the API answer. The answer is the a dict with a columns field and a rows field (list of rows as list of strings)
        """
        return self._perform_json("POST", "%s/query" % endpoint_id, body=parameters)

    def lookup_record(self, endpoint_id, record, context=None):
        """
        Lookup a single record on a DSS API node endpoint of "dataset lookup" type

        :param record: Python dictionary of features of the record
        :param str endpoint_id: Identifier of the endpoint to query
        :param context: Optional, Python dictionary of additional context information. The context information is logged, but not directly used.

        :return: a Python dict of the API answer. The answer contains a "data" key (itself a dict)
        """
        obj = {
            "data": record
        }
        if context is not None:
            obj["context"] = context
        # The server answers with a "results" array; a single-record lookup
        # maps to its first (and only) element.
        return self._perform_json("POST", "%s/lookup" % endpoint_id, body=obj).get("results", [])[0]

    def lookup_records(self, endpoint_id, records):
        """
        Lookups a batch of records on a DSS API node endpoint of "dataset lookup" type

        :param str endpoint_id: Identifier of the endpoint to query
        :param records: Python list of records. Each record must be a Python dict, containing at least one entry called "data": a dict containing the input columns

        :return: a Python dict of the API answer. The answer contains a "results" key, which is an array of result objects. Each result contains a "data" dict which is the output
        """
        # Validate records client-side before issuing the request
        for record in records:
            if "data" not in record:
                raise ValueError("Each record must contain a 'data' dict")
        obj = {
            "items": records
        }
        return self._perform_json("POST", "%s/lookup-multi" % endpoint_id, body=obj)

    def run_function(self, endpoint_id, **kwargs):
        """
        Calls a "Run function" endpoint on a DSS API node

        :param str endpoint_id: Identifier of the endpoint to query
        :param kwargs: Arguments of the function

        :return: The function result
        """
        # The function arguments are sent verbatim as the JSON request body
        obj = dict(kwargs)
        return self._perform_json("POST", "%s/run" % endpoint_id, body=obj)