-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
gather.ts
67 lines (61 loc) · 2.41 KB
/
gather.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
 * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from '../engine';
import {GatherV2, GatherV2Attrs, GatherV2Inputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {op} from './operation';
/**
 * Gathers slices of `x` along dimension `axis`, selecting the positions
 * listed in `indices`.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * const indices = tf.tensor1d([1, 3, 3], 'int32');
 *
 * x.gather(indices).print();
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const indices = tf.tensor1d([1, 1, 0], 'int32');
 *
 * x.gather(indices).print();
 * ```
 * @param x Tensor (or tensor-like value) to gather slices from.
 * @param indices Positions to select along `axis`. They are converted to an
 *     `int32` tensor before the kernel runs.
 * @param axis Dimension of `x` to select along. Defaults to 0.
 * @param batchDims Optional number of leading batch dimensions. It must not
 *     exceed rank(indices). Defaults to 0. The result has shape
 *     `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]`.
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function gather_<T extends Tensor>(
    x: T|TensorLike, indices: Tensor|TensorLike, axis = 0, batchDims = 0): T {
  // Normalize both arguments to tensors (x first, then indices); the
  // indices are requested as int32.
  const kernelInputs: GatherV2Inputs = {
    x: convertToTensor(x, 'x', 'gather'),
    indices: convertToTensor(indices, 'indices', 'gather', 'int32'),
  };
  const kernelAttrs: GatherV2Attrs = {axis, batchDims};

  // The strongly typed input/attr records are widened to the generic
  // NamedTensorMap / NamedAttrMap types passed to runKernel.
  return ENGINE.runKernel(
      GatherV2, kernelInputs as unknown as NamedTensorMap,
      kernelAttrs as unknown as NamedAttrMap);
}
// Public `gather` op: `gather_` wrapped by `op` from './operation'. The
// wrapper's exact behavior is defined there; presumably it names and
// registers the op. NOTE(review): confirm in operation.ts.
// The `/* @__PURE__ */` annotation marks this call as side-effect-free, so
// bundlers may tree-shake the export when it is unused.
export const gather = /* @__PURE__ */ op({gather_});