Unverified commit 236ed94d, authored by Z zhulei, committed by GitHub

Add roi_align grad (#36724)

Parent 87fbbd36
...@@ -90,6 +90,94 @@ class ROIAlignNPUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class ROIAlignNPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
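// Fetch the forward inputs, the upstream gradient of Out, and the
// gradient of X that this kernel fills in.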
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sample_num = ctx.Attr<int>("sampling_ratio");
auto in_dims = in->dims();
auto aligned = ctx.Attr<bool>("aligned");
int rois_num = rois->dims()[0];
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
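// Nothing to do if the gradient of X is not requested.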
if (!in_grad) {
return;
}
in_grad->mutable_data<T>(place);
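// The NPU kernel only covers a subset of the operator's configurations;
// validate the attributes and inputs before launching any NPU ops.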
PADDLE_ENFORCE_EQ(
aligned, false,
platform::errors::InvalidArgument(
"ROIAlignGradNPU only supports the Aligned attribute set to false"));
PADDLE_ENFORCE_EQ(
ctx.HasInput("RoisNum"), true,
platform::errors::NotFound("Input(RoisNum) of ROIAlignGradOp "
"is not found while using NPU."));
PADDLE_ENFORCE_EQ(
rois->type(), framework::proto::VarType::FP32,
platform::errors::InvalidArgument(
"ROIAlignGradNPU only support ROIs type equaled to FP32."));
// Cast RoisNum to fp32 tensor
auto* RoisNum = ctx.Input<framework::Tensor>("RoisNum");
Tensor ROIs_N5;
ROIs_N5.mutable_data<float>({rois_num, 5}, place);
Tensor ROIsNum_fp;
ROIsNum_fp.mutable_data<T>(RoisNum->dims(), place); // shape = [rois_num]
int nputype_fp32 =
static_cast<int>(ConvertToNpuDtype(framework::proto::VarType::FP32));
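// Launch the Ascend Cast op; dst_type carries the NPU dtype code
// computed above.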
const auto& runner_cast = NpuOpRunner("Cast", {*RoisNum}, {ROIsNum_fp},
{{"dst_type", nputype_fp32}});
runner_cast.Run(stream);
ROIsNum_fp.Resize({rois_num, 1});
// Combine *ROIsNum with ROIs to get new ROIs
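// ConcatD below places the casted RoisNum column in front of the four box
// coordinates, giving ROIs_N5 the [rois_num, 5] layout that is passed to
// the Ascend ROIAlignGrad op further down.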
std::vector<paddle::framework::Tensor> x_list;
x_list.push_back(ROIsNum_fp);
x_list.push_back(*rois);
const auto& runner_concat = NpuOpRunner("ConcatD", {x_list}, {ROIs_N5},
{{"N", 2}, {"concat_dim", 1}});
runner_concat.Run(stream);
// To match the CPU grad implementation, 1 must be subtracted from
// rois[:, 3:5] before calling the Ascend grad function.
std::vector<float> vec_dlt = {0, 0, 0, -1.0f, -1.0f};
Tensor tsr_dlt;
tsr_dlt.mutable_data<float>({5}, place);
framework::TensorFromVector<float>(vec_dlt, ctx.device_context(), &tsr_dlt);
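// Wait until the host-to-device copy of tsr_dlt has finished before the
// tensor is consumed by AddV2 on the stream.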
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
const auto& runner_add =
NpuOpRunner("AddV2", {ROIs_N5, tsr_dlt}, {ROIs_N5}, {});
runner_add.Run(stream);
// Call ascend RoiAlignGrad function
int roi_end_mode = 0;
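// xdiff_shape is the shape of the gradient w.r.t. X, i.e. the dims of the
// forward input.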
const auto& runner_roi_align_grad =
NpuOpRunner("ROIAlignGrad", {*out_grad, ROIs_N5}, {*in_grad},
{{"xdiff_shape", framework::vectorize<int>(in_dims)},
{"pooled_width", pooled_width},
{"pooled_height", pooled_height},
{"spatial_scale", spatial_scale},
{"sample_num", sample_num},
{"roi_end_mode", roi_end_mode}});
runner_roi_align_grad.Run(stream);
}
};
}  // namespace operators
}  // namespace paddle
...@@ -99,3 +187,7 @@ REGISTER_OP_NPU_KERNEL(
ops::ROIAlignNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ROIAlignNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::ROIAlignNPUKernel<paddle::platform::NPUDeviceContext, int>);
REGISTER_OP_NPU_KERNEL(roi_align_grad, ops::ROIAlignNPUGradKernel<float>,
ops::ROIAlignNPUGradKernel<double>,
ops::ROIAlignNPUGradKernel<int>);