未验证 提交 256bf6ff 编写于 作者: U USTCKAY 提交者: GitHub

fix roi_align_op_npu to pass the unittest (#45310)

上级 ea1f4702
...@@ -157,8 +157,10 @@ class ROIAlignNPUGradKernel : public framework::OpKernel<T> { ...@@ -157,8 +157,10 @@ class ROIAlignNPUGradKernel : public framework::OpKernel<T> {
"ConcatD", {x_list}, {ROIs_N5}, {{"N", 2}, {"concat_dim", 1}}); "ConcatD", {x_list}, {ROIs_N5}, {{"N", 2}, {"concat_dim", 1}});
runner_concat.Run(stream); runner_concat.Run(stream);
// By analysis, in order to match cpu grad version, // If CANN version code is less than 504, by analysis, in order to match
// rois[:,3:5] should subtract 1 before calling ascend grad function // cpu grad version, rois[:,3:5] should subtract 1 before calling ascend grad
// function
#if (CANN_VERSION_CODE < 504000)
std::vector<float> vec_dlt = {0, 0, 0, -1.0f, -1.0f}; std::vector<float> vec_dlt = {0, 0, 0, -1.0f, -1.0f};
Tensor tsr_dlt; Tensor tsr_dlt;
tsr_dlt.mutable_data<float>({5}, place); tsr_dlt.mutable_data<float>({5}, place);
...@@ -167,6 +169,7 @@ class ROIAlignNPUGradKernel : public framework::OpKernel<T> { ...@@ -167,6 +169,7 @@ class ROIAlignNPUGradKernel : public framework::OpKernel<T> {
const auto& runner_add = const auto& runner_add =
NpuOpRunner("AddV2", {ROIs_N5, tsr_dlt}, {ROIs_N5}, {}); NpuOpRunner("AddV2", {ROIs_N5, tsr_dlt}, {ROIs_N5}, {});
runner_add.Run(stream); runner_add.Run(stream);
#endif
// Call ascend RoiAlignGrad function // Call ascend RoiAlignGrad function
int roi_end_mode = 0; int roi_end_mode = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册