Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove chained ops usage from core #3706

Merged
merged 6 commits into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lint fix
  • Loading branch information
tafsiri committed Jul 31, 2020
commit 1cdd86b3a000bcf8f3b16bff46d932a3e9c63384
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/backend_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export function castTensor<T extends Tensor>(
const real = backend.real(x);
const result = cast(real, dtype);
real.dispose();
return result as T;
return result;
}
if (dtype === 'int32') {
return backend.int(x);
Expand Down
3 changes: 3 additions & 0 deletions tfjs-core/src/ops/batchnorm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,15 @@ function as1DOr4D(x: Tensor): Tensor4D|Tensor1D {
return null;
}
if (x.rank === 0) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return reshape(x, [x.size]) as Tensor1D;
} else if (x.rank === 1) {
return x as Tensor1D;
} else if (x.rank === 2) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return reshape(x, [1, 1, x.shape[0], x.shape[1]]) as Tensor4D;
} else if (x.rank === 3) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]) as Tensor4D;
}
return x as Tensor4D;
Expand Down
5 changes: 3 additions & 2 deletions tfjs-core/src/ops/eye.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import {op} from './operation';
import {reshape} from './reshape';
import {tile} from './tile';


/**
* Create an identity matrix.
*
Expand Down Expand Up @@ -54,17 +53,19 @@ function eye_(
for (let i = 0; i < n; ++i) {
buff.set(1, i, i);
}
const out = reshape(buff.toTensor(), [numRows, numColumns]) as Tensor2D;
const out: Tensor2D = reshape(buff.toTensor(), [numRows, numColumns]);
if (batchShape == null) {
return out;
} else {
if (batchShape.length === 1) {
return tile(expandDims(out, 0), [batchShape[0], 1, 1]) as Tensor2D;
} else if (batchShape.length === 2) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return tile(
expandDims(expandDims(out, 0), 0),
[batchShape[0], batchShape[1], 1, 1]) as Tensor2D;
} else if (batchShape.length === 3) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [
batchShape[0], batchShape[1], batchShape[2], 1, 1
]) as Tensor2D;
Expand Down
1 change: 1 addition & 0 deletions tfjs-core/src/ops/multinomial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ function multinomial_(
const res = ENGINE.runKernelFunc(
backend => backend.multinomial(logits2D, normalized, numSamples, seed),
{logits2D});
// tslint:disable-next-line:no-unnecessary-type-assertion
return origRank === 1 ? reshape(res, [res.size]) as Tensor1D : res;
}

Expand Down