diff --git a/paddle/fluid/operators/roi_align_op_npu.cc b/paddle/fluid/operators/roi_align_op_npu.cc
index c1ba046ca6af1a95165a0bf78458a1be56e29c0e..c26db2500fd6613d914bd4b9ebcb548578fb73b5 100644
--- a/paddle/fluid/operators/roi_align_op_npu.cc
+++ b/paddle/fluid/operators/roi_align_op_npu.cc
@@ -90,6 +90,94 @@ class ROIAlignNPUKernel : public framework::OpKernel {
   }
 };
 
+template  // NOTE(review): template argument lists look stripped in this copy (see also `static_cast(`, `ctx.Attr(`) - confirm against upstream
+class ROIAlignNPUGradKernel : public framework::OpKernel {  // NPU kernel registered below for the roi_align_grad op
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {  // fills X@GRAD from Out@GRAD and ROIs via the Ascend ROIAlignGrad op
+    auto* in = ctx.Input("X");  // only its dims are read (passed as xdiff_shape)
+    auto* rois = ctx.Input("ROIs");
+    auto* out_grad =
+        ctx.Input(framework::GradVarName("Out"));
+    auto* in_grad = ctx.Output(framework::GradVarName("X"));
+
+    auto pooled_height = ctx.Attr("pooled_height");
+    auto pooled_width = ctx.Attr("pooled_width");
+    auto spatial_scale = ctx.Attr("spatial_scale");
+    auto sample_num = ctx.Attr("sampling_ratio");  // forwarded to the NPU op under the name "sample_num"
+    auto in_dims = in->dims();
+    auto aligned = ctx.Attr("aligned");
+
+    int rois_num = rois->dims()[0];  // first dimension of ROIs, used as the row count of ROIs_N5
+
+    auto place = ctx.GetPlace();
+    auto stream =
+        ctx.template device_context()
+            .stream();
+
+    if (!in_grad) {  // no X@GRAD output variable: nothing to do
+      return;
+    }
+    in_grad->mutable_data(place);  // NOTE(review): allocated before the argument checks below run
+
+    PADDLE_ENFORCE_EQ(
+        aligned, false,
+        platform::errors::InvalidArgument(
+            "ROIAlignGradNPU only support Aligned attribute equaled to False"));
+    PADDLE_ENFORCE_EQ(
+        ctx.HasInput("RoisNum"), true,
+        platform::errors::NotFound("Input(RoisNum) of ROIAlignGradOp "
+                                   "is not found while using NPU."));
+    PADDLE_ENFORCE_EQ(
+        rois->type(), framework::proto::VarType::FP32,
+        platform::errors::InvalidArgument(
+            "ROIAlignGradNPU only support ROIs type equaled to FP32."));
+
+    // Cast RoisNum to an fp32 tensor so it can be concatenated with ROIs
+    auto* RoisNum = ctx.Input("RoisNum");
+    Tensor ROIs_N5;
+    ROIs_N5.mutable_data({rois_num, 5}, place);
+    Tensor ROIsNum_fp;
+    ROIsNum_fp.mutable_data(RoisNum->dims(), place);  // expected shape = [rois_num]; not validated here - TODO confirm
+    int nputype_fp32 =
+        static_cast(ConvertToNpuDtype(framework::proto::VarType::FP32));
+    const auto& runner_cast = NpuOpRunner("Cast", {*RoisNum}, {ROIsNum_fp},
+                                          {{"dst_type", nputype_fp32}});
+    runner_cast.Run(stream);
+    ROIsNum_fp.Resize({rois_num, 1});  // column shape so it can be concatenated along dim 1
+
+    // Concatenate RoisNum (first) with ROIs along dim 1 into the [rois_num, 5] tensor ROIs_N5
+    std::vector x_list;
+    x_list.push_back(ROIsNum_fp);
+    x_list.push_back(*rois);
+    const auto& runner_concat = NpuOpRunner("ConcatD", {x_list}, {ROIs_N5},
+                                            {{"N", 2}, {"concat_dim", 1}});
+    runner_concat.Run(stream);
+
+    // By analysis, in order to match the CPU grad version, 1 must be
+    // subtracted from rois[:,3:5] before calling the Ascend grad function.
+    std::vector vec_dlt = {0, 0, 0, -1.0f, -1.0f};  // per-column delta; presumably broadcast over rows by AddV2
+    Tensor tsr_dlt;
+    tsr_dlt.mutable_data({5}, place);
+    framework::TensorFromVector(vec_dlt, ctx.device_context(), &tsr_dlt);
+    ctx.template device_context().Wait();  // presumably ensures the copy of vec_dlt has finished before it is used
+    const auto& runner_add =
+        NpuOpRunner("AddV2", {ROIs_N5, tsr_dlt}, {ROIs_N5}, {});  // result is written back into ROIs_N5
+    runner_add.Run(stream);
+
+    // Call the Ascend ROIAlignGrad op, which writes into *in_grad
+    int roi_end_mode = 0;  // NOTE(review): meaning of 0 is defined by the Ascend op - confirm
+    const auto& runner_roi_align_grad =
+        NpuOpRunner("ROIAlignGrad", {*out_grad, ROIs_N5}, {*in_grad},
+                    {{"xdiff_shape", framework::vectorize(in_dims)},
+                     {"pooled_width", pooled_width},
+                     {"pooled_height", pooled_height},
+                     {"spatial_scale", spatial_scale},
+                     {"sample_num", sample_num},
+                     {"roi_end_mode", roi_end_mode}});
+    runner_roi_align_grad.Run(stream);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -99,3 +187,7 @@ REGISTER_OP_NPU_KERNEL(
     ops::ROIAlignNPUKernel,
     ops::ROIAlignNPUKernel,
     ops::ROIAlignNPUKernel);
+// Register ROIAlignNPUGradKernel as the NPU kernel of the roi_align_grad op.
+REGISTER_OP_NPU_KERNEL(roi_align_grad, ops::ROIAlignNPUGradKernel,
+                       ops::ROIAlignNPUGradKernel,
+                       ops::ROIAlignNPUGradKernel);