Unverified commit 1ef1cace, authored by 光明和真理, committed by GitHub

[MLU] add mlu kernel for reduce_max_grad (#45651)

Co-authored-by: liupeiyu <liupeiyu@cambricon.com>
Parent 8e9c719d
@@ -92,6 +92,112 @@ class ReduceMaxMLUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class ReduceMaxGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Input<Tensor>("Out");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto reduce_dims = context.Attr<std::vector<int>>("dim");
    bool reduce_all = context.Attr<bool>("reduce_all");
    int in_dtype = context.Attr<int>("in_dtype");
    // The MLU path does not support casting the gradient to another dtype.
    PADDLE_ENFORCE_EQ(
        in_dtype == -1,
        true,
        platform::errors::InvalidArgument(
            "MLU only supports in_dtype == -1 in reduce_max_grad op."));

    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    x_grad->mutable_data<T>(context.GetPlace());

    auto place = context.GetPlace();

    // broadcast: expand the reduced output and its gradient back to the
    // input shape so they can be compared with x elementwise
    auto x_dims_vec = phi::vectorize(x->dims());
    if (reduce_all) {
      reduce_dims.clear();
      for (size_t d = 0; d < x_dims_vec.size(); ++d) {
        reduce_dims.push_back(static_cast<int>(d));
      }
    }

    Tensor tmp_out, tmp_out_grad;
    auto tmp_out_dims_vec = x_dims_vec;
    for (auto d : reduce_dims) {
      // normalize negative axes, then keep the reduced axes as size 1
      if (d < 0) {
        d += x_dims_vec.size();
      }
      tmp_out_dims_vec[d] = 1;
    }

    tmp_out.ShareDataWith(*out);
    tmp_out.Resize(phi::make_ddim(tmp_out_dims_vec));
    tmp_out_grad.ShareDataWith(*out_grad);
    tmp_out_grad.Resize(phi::make_ddim(tmp_out_dims_vec));

    Tensor transformed_out(x->type());
    transformed_out.Resize(phi::make_ddim(x_dims_vec));
    transformed_out.mutable_data<T>(place);

    MLUCnnlTensorDesc tmp_out_desc(tmp_out);
    MLUCnnlTensorDesc transformed_out_desc(transformed_out);
    MLUCnnl::BroadcastTo(context,
                         tmp_out_desc.get(),
                         GetBasePtr(&tmp_out),
                         transformed_out_desc.get(),
                         GetBasePtr(&transformed_out));

    Tensor transformed_out_grad(x->type());
    transformed_out_grad.Resize(phi::make_ddim(x_dims_vec));
    transformed_out_grad.mutable_data<T>(place);
    MLUCnnlTensorDesc tmp_out_grad_desc(tmp_out_grad);
    MLUCnnlTensorDesc transformed_out_grad_desc(transformed_out_grad);
    MLUCnnl::BroadcastTo(context,
                         tmp_out_grad_desc.get(),
                         GetBasePtr(&tmp_out_grad),
                         transformed_out_grad_desc.get(),
                         GetBasePtr(&transformed_out_grad));

    // compare: mark the positions where x equals the (broadcast) maximum
    Tensor equal_cond;
    equal_cond.mutable_data<bool>(x_grad->dims(), place);

    MLUCnnlTensorDesc x_desc(*x);
    MLUCnnlTensorDesc equal_cond_desc(equal_cond);
    MLUCnnl::Logic(context,
                   CNNL_LOGIC_OP_EQ,
                   x_desc.get(),
                   GetBasePtr(x),
                   transformed_out_desc.get(),
                   GetBasePtr(&transformed_out),
                   equal_cond_desc.get(),
                   GetBasePtr(&equal_cond));

    // select: route the upstream gradient to the max positions, zero elsewhere
    Tensor t_zero;
    t_zero.mutable_data<T>(x_grad->dims(), place);
    FillMLUTensorWithHostValue<T>(context, static_cast<T>(0), &t_zero);
    t_zero.Resize(x_grad->dims());

    MLUCnnlTensorDesc t_zero_desc(t_zero);
    MLUCnnlTensorDesc x_grad_desc(*x_grad);
    MLUCnnl::Select(context,
                    equal_cond_desc.get(),
                    GetBasePtr(&equal_cond),
                    transformed_out_grad_desc.get(),
                    GetBasePtr(&transformed_out_grad),
                    t_zero_desc.get(),
                    GetBasePtr(&t_zero),
                    x_grad_desc.get(),
                    GetBasePtr(x_grad));
  }
};
} // namespace operators
} // namespace paddle
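In effect, the kernel implements the standard subgradient of a max reduction: broadcast the reduced maximum `out` and the upstream gradient `out_grad` back to the input shape, then pass the gradient through only where the input attains the maximum. With $\tilde{y}$ and $\tilde{g}$ denoting `out` and `out_grad` broadcast to the shape of `x`:

$$
\frac{\partial L}{\partial x_i} =
\begin{cases}
\tilde{g}_i, & x_i = \tilde{y}_i,\\
0, & \text{otherwise.}
\end{cases}
$$

Note that when several entries tie for the maximum along the reduced axes, the equality mask routes the full gradient to every maximal position rather than to a single winner.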
@@ -102,3 +208,7 @@ REGISTER_OP_MLU_KERNEL(reduce_max,
                       ops::ReduceMaxMLUKernel<float>,
                       ops::ReduceMaxMLUKernel<plat::float16>,
                       ops::ReduceMaxMLUKernel<int>);
REGISTER_OP_MLU_KERNEL(reduce_max_grad,
                       ops::ReduceMaxGradMLUKernel<float>,
                       ops::ReduceMaxGradMLUKernel<plat::float16>,
                       ops::ReduceMaxGradMLUKernel<int>);
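For intuition, here is a minimal host-side sketch of the same broadcast/compare/select scheme. It is a hypothetical CPU analogue, not Paddle or CNNL code: the names, the fixed 2x4 shape, and the reduction over axis 1 are all illustrative.

```cpp
#include <cstdio>

// Illustrative CPU analogue of the kernel above: reduce_max over axis 1 of a
// kRows x kCols matrix, then scatter the upstream gradient back through an
// equality mask (ties receive the gradient at every maximal position).
int main() {
  constexpr int kRows = 2, kCols = 4;
  float x[kRows][kCols] = {{1.f, 3.f, 2.f, 3.f}, {5.f, 4.f, 5.f, 0.f}};
  float out[kRows];                    // forward result: row-wise max
  float out_grad[kRows] = {1.f, 2.f};  // upstream gradient dL/dout
  float x_grad[kRows][kCols];

  // forward: out[r] = max over c of x[r][c]
  for (int r = 0; r < kRows; ++r) {
    out[r] = x[r][0];
    for (int c = 1; c < kCols; ++c) {
      if (x[r][c] > out[r]) out[r] = x[r][c];
    }
  }

  // backward: "broadcast" out/out_grad across the reduced axis, build the
  // equality mask, and select out_grad where the mask holds, zero elsewhere.
  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols; ++c) {
      bool equal_cond = (x[r][c] == out[r]);          // compare step
      x_grad[r][c] = equal_cond ? out_grad[r] : 0.f;  // select step
    }
  }

  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols; ++c) printf("%4.1f ", x_grad[r][c]);
    printf("\n");
  }
  // expected output (ties in each row both receive the gradient):
  //  0.0  1.0  0.0  1.0
  //  2.0  0.0  2.0  0.0
}
```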