-
Notifications
You must be signed in to change notification settings - Fork 507
/
Copy pathtorch_xla_test.cpp
112 lines (92 loc) · 3.73 KB
/
torch_xla_test.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include "test/cpp/torch_xla_test.h"
#include <ATen/ATen.h>
#include "absl/memory/memory.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/xla_backend_impl.h"
#include "torch_xla/csrc/xla_graph_executor.h"
namespace torch_xla {
namespace cpp_test {
// Calls InitXlaBackend() while this translation unit's static objects are
// being initialized and stores the returned flag. Nothing in the visible part
// of this file reads the flag afterwards; the call is made for its effect.
// NOTE(review): presumably this has to happen before any test body touches an
// XLA device — confirm against InitXlaBackend()'s contract.
static bool xla_backend_inited = InitXlaBackend();
// Per-test fixture setup: seeds the ATen generator and the XLA graph
// executor's RNG for the current device with the same fixed value, then
// captures the metrics snapshot that later counter checks compare against.
void XlaTest::SetUp() {
  constexpr int kTestSeed = 42;
  at::manual_seed(kTestSeed);
  XLAGraphExecutor::Get()->SetRngSeed(bridge::GetCurrentDevice(), kTestSeed);
  start_msnap_ = absl::make_unique<MetricsSnapshot>();
}
// Per-test fixture teardown. When the XLA_TEST_DUMP_METRICS environment flag
// is set, logs the metrics that differ between the start snapshot and the end
// snapshot, prefixed with the name of the test that just ran.
void XlaTest::TearDown() {
  // The environment is consulted only on the first call; the result is cached
  // in a function-local static for every later teardown.
  static bool dump_metrics =
      torch_xla::runtime::sys_util::GetEnvBool("XLA_TEST_DUMP_METRICS", false);
  if (!dump_metrics) {
    return;
  }

  MakeEndSnapshot();
  std::string metric_delta =
      start_msnap_->DumpDifferences(*end_msnap_, /*ignore_se=*/nullptr);
  if (metric_delta.empty()) {
    return;
  }

  TF_LOG(INFO)
      << ::testing::UnitTest::GetInstance()->current_test_info()->name()
      << " Metrics Differences:\n"
      << metric_delta;
}
// Logs each entry of `changed` (counter name, value before, value after) and
// then registers a test expectation that `changed` is empty.
static void ExpectCounterNotChanged_(
    const std::vector<MetricsSnapshot::ChangedCounter>& changed) {
  for (const auto& entry : changed) {
    TF_LOG(INFO) << "Counter '" << entry.name << "' changed: " << entry.before
                 << " -> " << entry.after;
  }
  EXPECT_TRUE(changed.empty());
}
// Expects that no counter matching `counter_regex` differs between the start
// snapshot and the end snapshot (which is taken now if it does not exist yet).
// Counters listed in `ignore_set` are forwarded to CounterChanged() as the set
// to skip.
void XlaTest::ExpectCounterNotChanged(
    const std::string& counter_regex,
    const std::unordered_set<std::string>* ignore_set) {
  MakeEndSnapshot();

  ExpectCounterNotChanged_(
      start_msnap_->CounterChanged(counter_regex, *end_msnap_, ignore_set));

  // Transitional check: an operator may already have been renamed to
  // `opName_symint` while the tests still pass the old name, so the same
  // expectation is applied to the `_symint`-suffixed pattern. This goes away
  // once the symint migration is done and the tests use the new names.
  ExpectCounterNotChanged_(start_msnap_->CounterChanged(
      counter_regex + "_symint", *end_msnap_, ignore_set));
}
// Expects that counters matching `counter_regex` did change between the start
// snapshot and the end snapshot (which is taken now if it does not exist yet).
// `ignore_set` is forwarded to CounterChanged() unchanged.
void XlaTest::ExpectCounterChanged(
    const std::string& counter_regex,
    const std::unordered_set<std::string>* ignore_set) {
  MakeEndSnapshot();

  // Transitional: an operator may already have been renamed to
  // `opName_symint` while the tests still pass the old name, so the
  // `_symint`-suffixed pattern is queried as well. To be removed when the
  // symint migration is finished and the tests are updated.
  const std::string symint_regex = counter_regex + "_symint";

  const auto changed =
      start_msnap_->CounterChanged(counter_regex, *end_msnap_, ignore_set);
  const auto changed_symint =
      start_msnap_->CounterChanged(symint_regex, *end_msnap_, ignore_set);

  // At least one of the two queries must report a change...
  EXPECT_TRUE(!changed.empty() || !changed_symint.empty());
  // ...and exactly one of them: a change under both the plain and the symint
  // name fails here, as does the case where neither reports anything.
  EXPECT_TRUE(changed.empty() != changed_symint.empty());
}
void XlaTest::ResetCounters() {
start_msnap_ = std::move(end_msnap_);
end_msnap_ = nullptr;
}
// Captures the end-of-test metrics snapshot unless one is already held.
// Repeated calls keep the first snapshot until end_msnap_ is cleared again.
void XlaTest::MakeEndSnapshot() {
  if (end_msnap_ != nullptr) {
    return;
  }
  end_msnap_ = absl::make_unique<MetricsSnapshot>();
}
// Shared setup used by the suite-level hooks below: selects the HIGHEST
// matmul precision through XlaHelpers.
void XlaTest::CommonSetup() {
  const auto precision = xla::PrecisionConfig::HIGHEST;
  XlaHelpers::set_mat_mul_precision(precision);
}
// Suite-level hook for TorchXlaTest; simply delegates to CommonSetup().
// NOTE(review): presumably invoked by googletest once per test suite — confirm
// against the fixture declaration in the header.
void TorchXlaTest::SetUpTestCase() {
  CommonSetup();
}
// Suite-level hook for AtenXlaTensorTestBase; simply delegates to
// CommonSetup().
// NOTE(review): presumably invoked by googletest once per test suite — confirm
// against the fixture declaration in the header.
void AtenXlaTensorTestBase::SetUpTestCase() {
  CommonSetup();
}
} // namespace cpp_test
} // namespace torch_xla