@@ -80,10 +80,10 @@ namespace {
80
80
81
81
template <typename T>
82
82
struct PreCalc {
83
- int pos1;
84
- int pos2;
85
- int pos3;
86
- int pos4;
83
+ int64_t pos1;
84
+ int64_t pos2;
85
+ int64_t pos3;
86
+ int64_t pos4;
87
87
T w1;
88
88
T w2;
89
89
T w3;
@@ -94,27 +94,27 @@ template <typename T, typename ACC_T>
94
94
inline void roi_align_single_framework_forward (
95
95
const T* input,
96
96
const ACC_T count,
97
- int channels,
98
- int height,
99
- int width,
100
- int pooled_height,
101
- int pooled_width,
102
- int roi_bin_grid_h,
103
- int roi_bin_grid_w,
97
+ int64_t channels,
98
+ int64_t height,
99
+ int64_t width,
100
+ int64_t pooled_height,
101
+ int64_t pooled_width,
102
+ int64_t roi_bin_grid_h,
103
+ int64_t roi_bin_grid_w,
104
104
const std::vector<PreCalc<ACC_T>>& pre_calc,
105
105
T* output);
106
106
107
107
template <typename T, typename ACC_T>
108
108
inline void roi_align_single_framework_channels_last_forward (
109
109
const T* input,
110
110
const ACC_T count,
111
- int channels,
112
- int height,
113
- int width,
114
- int pooled_height,
115
- int pooled_width,
116
- int roi_bin_grid_h,
117
- int roi_bin_grid_w,
111
+ int64_t channels,
112
+ int64_t height,
113
+ int64_t width,
114
+ int64_t pooled_height,
115
+ int64_t pooled_width,
116
+ int64_t roi_bin_grid_h,
117
+ int64_t roi_bin_grid_w,
118
118
const std::vector<PreCalc<ACC_T>>& pre_calc,
119
119
T* output);
120
120
@@ -124,13 +124,13 @@ inline void roi_align_single_framework_channels_last_forward<
124
124
float >(
125
125
const at::BFloat16* input,
126
126
const float count,
127
- int channels,
128
- int height,
129
- int width,
130
- int pooled_height,
131
- int pooled_width,
132
- int roi_bin_grid_h,
133
- int roi_bin_grid_w,
127
+ int64_t channels,
128
+ int64_t height,
129
+ int64_t width,
130
+ int64_t pooled_height,
131
+ int64_t pooled_width,
132
+ int64_t roi_bin_grid_h,
133
+ int64_t roi_bin_grid_w,
134
134
const std::vector<PreCalc<float >>& pre_calc,
135
135
at::BFloat16* output);
136
136
@@ -141,57 +141,57 @@ template <typename T, typename ACC_T>
141
141
inline void roi_align_single_framework_backward (
142
142
const T* grad_output,
143
143
const ACC_T count,
144
- int channels,
145
- int height,
146
- int width,
147
- int pooled_height,
148
- int pooled_width,
149
- int roi_bin_grid_h,
150
- int roi_bin_grid_w,
144
+ int64_t channels,
145
+ int64_t height,
146
+ int64_t width,
147
+ int64_t pooled_height,
148
+ int64_t pooled_width,
149
+ int64_t roi_bin_grid_h,
150
+ int64_t roi_bin_grid_w,
151
151
const std::vector<PreCalc<ACC_T>>& pre_calc,
152
152
T* grad_input);
153
153
154
154
template <typename T, typename ACC_T>
155
155
inline void roi_align_single_framework_channels_last_backward (
156
156
const T* grad_output,
157
157
const ACC_T count,
158
- int channels,
159
- int height,
160
- int width,
161
- int pooled_height,
162
- int pooled_width,
163
- int roi_bin_grid_h,
164
- int roi_bin_grid_w,
158
+ int64_t channels,
159
+ int64_t height,
160
+ int64_t width,
161
+ int64_t pooled_height,
162
+ int64_t pooled_width,
163
+ int64_t roi_bin_grid_h,
164
+ int64_t roi_bin_grid_w,
165
165
const std::vector<PreCalc<ACC_T>>& pre_calc,
166
166
T* grad_input);
167
167
168
168
template <typename T, typename ACC_T>
169
169
void roi_align_forward_kernel_body (
170
- int n_rois,
170
+ int64_t n_rois,
171
171
const T* input,
172
172
const ACC_T& spatial_scale,
173
- int channels,
174
- int height,
175
- int width,
176
- int pooled_height,
177
- int pooled_width,
178
- int sampling_ratio,
173
+ int64_t channels,
174
+ int64_t height,
175
+ int64_t width,
176
+ int64_t pooled_height,
177
+ int64_t pooled_width,
178
+ int64_t sampling_ratio,
179
179
bool aligned,
180
180
const ACC_T* rois,
181
181
T* output,
182
182
bool is_channels_last);
183
183
184
184
template <typename T, typename ACC_T>
185
185
void roi_align_backward_kernel_body (
186
- int n_rois,
186
+ int64_t n_rois,
187
187
const T* grad_output,
188
188
const ACC_T& spatial_scale,
189
- int channels,
190
- int height,
191
- int width,
192
- int pooled_height,
193
- int pooled_width,
194
- int sampling_ratio,
189
+ int64_t channels,
190
+ int64_t height,
191
+ int64_t width,
192
+ int64_t pooled_height,
193
+ int64_t pooled_width,
194
+ int64_t sampling_ratio,
195
195
bool aligned,
196
196
T* grad_input,
197
197
const ACC_T* rois,
0 commit comments