提交 7d7e2995 编写于 作者: W whs 提交者: qingqing01

Fix bp of roi perspective transform op. (#17216)

上级 7bd1d03e
...@@ -466,6 +466,10 @@ class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> { ...@@ -466,6 +466,10 @@ class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace()); T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
const int* out2in_idx_data = out2in_idx->data<int>(); const int* out2in_idx_data = out2in_idx->data<int>();
const T* out2in_w_data = out2in_w->data<T>(); const T* out2in_w_data = out2in_w->data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册