Skip to content

Commit 6beb3d4

Browse files
authored
modify ROIAlign (#1540) (#1589)
1 parent f5ce619 commit 6beb3d4

File tree

2 files changed

+157
-156
lines changed

2 files changed

+157
-156
lines changed

csrc/cpu/aten/ROIAlign.h

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ namespace {
8080

8181
template <typename T>
8282
struct PreCalc {
83-
int pos1;
84-
int pos2;
85-
int pos3;
86-
int pos4;
83+
int64_t pos1;
84+
int64_t pos2;
85+
int64_t pos3;
86+
int64_t pos4;
8787
T w1;
8888
T w2;
8989
T w3;
@@ -94,27 +94,27 @@ template <typename T, typename ACC_T>
9494
inline void roi_align_single_framework_forward(
9595
const T* input,
9696
const ACC_T count,
97-
int channels,
98-
int height,
99-
int width,
100-
int pooled_height,
101-
int pooled_width,
102-
int roi_bin_grid_h,
103-
int roi_bin_grid_w,
97+
int64_t channels,
98+
int64_t height,
99+
int64_t width,
100+
int64_t pooled_height,
101+
int64_t pooled_width,
102+
int64_t roi_bin_grid_h,
103+
int64_t roi_bin_grid_w,
104104
const std::vector<PreCalc<ACC_T>>& pre_calc,
105105
T* output);
106106

107107
template <typename T, typename ACC_T>
108108
inline void roi_align_single_framework_channels_last_forward(
109109
const T* input,
110110
const ACC_T count,
111-
int channels,
112-
int height,
113-
int width,
114-
int pooled_height,
115-
int pooled_width,
116-
int roi_bin_grid_h,
117-
int roi_bin_grid_w,
111+
int64_t channels,
112+
int64_t height,
113+
int64_t width,
114+
int64_t pooled_height,
115+
int64_t pooled_width,
116+
int64_t roi_bin_grid_h,
117+
int64_t roi_bin_grid_w,
118118
const std::vector<PreCalc<ACC_T>>& pre_calc,
119119
T* output);
120120

@@ -124,13 +124,13 @@ inline void roi_align_single_framework_channels_last_forward<
124124
float>(
125125
const at::BFloat16* input,
126126
const float count,
127-
int channels,
128-
int height,
129-
int width,
130-
int pooled_height,
131-
int pooled_width,
132-
int roi_bin_grid_h,
133-
int roi_bin_grid_w,
127+
int64_t channels,
128+
int64_t height,
129+
int64_t width,
130+
int64_t pooled_height,
131+
int64_t pooled_width,
132+
int64_t roi_bin_grid_h,
133+
int64_t roi_bin_grid_w,
134134
const std::vector<PreCalc<float>>& pre_calc,
135135
at::BFloat16* output);
136136

@@ -141,57 +141,57 @@ template <typename T, typename ACC_T>
141141
inline void roi_align_single_framework_backward(
142142
const T* grad_output,
143143
const ACC_T count,
144-
int channels,
145-
int height,
146-
int width,
147-
int pooled_height,
148-
int pooled_width,
149-
int roi_bin_grid_h,
150-
int roi_bin_grid_w,
144+
int64_t channels,
145+
int64_t height,
146+
int64_t width,
147+
int64_t pooled_height,
148+
int64_t pooled_width,
149+
int64_t roi_bin_grid_h,
150+
int64_t roi_bin_grid_w,
151151
const std::vector<PreCalc<ACC_T>>& pre_calc,
152152
T* grad_input);
153153

154154
template <typename T, typename ACC_T>
155155
inline void roi_align_single_framework_channels_last_backward(
156156
const T* grad_output,
157157
const ACC_T count,
158-
int channels,
159-
int height,
160-
int width,
161-
int pooled_height,
162-
int pooled_width,
163-
int roi_bin_grid_h,
164-
int roi_bin_grid_w,
158+
int64_t channels,
159+
int64_t height,
160+
int64_t width,
161+
int64_t pooled_height,
162+
int64_t pooled_width,
163+
int64_t roi_bin_grid_h,
164+
int64_t roi_bin_grid_w,
165165
const std::vector<PreCalc<ACC_T>>& pre_calc,
166166
T* grad_input);
167167

168168
template <typename T, typename ACC_T>
169169
void roi_align_forward_kernel_body(
170-
int n_rois,
170+
int64_t n_rois,
171171
const T* input,
172172
const ACC_T& spatial_scale,
173-
int channels,
174-
int height,
175-
int width,
176-
int pooled_height,
177-
int pooled_width,
178-
int sampling_ratio,
173+
int64_t channels,
174+
int64_t height,
175+
int64_t width,
176+
int64_t pooled_height,
177+
int64_t pooled_width,
178+
int64_t sampling_ratio,
179179
bool aligned,
180180
const ACC_T* rois,
181181
T* output,
182182
bool is_channels_last);
183183

184184
template <typename T, typename ACC_T>
185185
void roi_align_backward_kernel_body(
186-
int n_rois,
186+
int64_t n_rois,
187187
const T* grad_output,
188188
const ACC_T& spatial_scale,
189-
int channels,
190-
int height,
191-
int width,
192-
int pooled_height,
193-
int pooled_width,
194-
int sampling_ratio,
189+
int64_t channels,
190+
int64_t height,
191+
int64_t width,
192+
int64_t pooled_height,
193+
int64_t pooled_width,
194+
int64_t sampling_ratio,
195195
bool aligned,
196196
T* grad_input,
197197
const ACC_T* rois,

0 commit comments

Comments
 (0)