From 8573ca5406132b472e3df071da7e10564199fc29 Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Mon, 8 Aug 2022 11:37:27 +0800
Subject: [PATCH] [MLU] fix bn_grad and hard_sigmoid_grad error (#44919)

---
 paddle/fluid/operators/activation_op_mlu.cc   |  8 ++---
 paddle/fluid/operators/batch_norm_op_mlu.cc   |  2 +-
 .../fluid/operators/conv_transpose_op_mlu.cc  | 32 +++++++------------
 paddle/phi/kernels/funcs/activation_functor.h |  4 +++
 4 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc
index 72e0e9ceac..6cfe4738d7 100644
--- a/paddle/fluid/operators/activation_op_mlu.cc
+++ b/paddle/fluid/operators/activation_op_mlu.cc
@@ -370,7 +370,7 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* out = ctx.Input<Tensor>("Out");
+    auto* x = ctx.Input<Tensor>("X");
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     float slope = ctx.Attr<float>("slope");
     float offset = ctx.Attr<float>("offset");
@@ -381,7 +381,7 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
                                    1.0f /*sliced_dim useless*/,
                                    slope,
                                    offset);
-    MLUCnnlTensorDesc out_desc(*out);
+    MLUCnnlTensorDesc x_desc(*x);
     MLUCnnlTensorDesc dout_desc(*dout);
     MLUCnnlTensorDesc dx_desc(*dx);
     MLUCnnl::ActiveGrad(ctx,
@@ -392,8 +392,8 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
                         nullptr,
                         dout_desc.get(),
                         GetBasePtr(dout),
-                        out_desc.get(),
-                        GetBasePtr(out),
+                        x_desc.get(),
+                        GetBasePtr(x),
                         dx_desc.get(),
                         GetBasePtr(dx));
   }
diff --git a/paddle/fluid/operators/batch_norm_op_mlu.cc b/paddle/fluid/operators/batch_norm_op_mlu.cc
index 199a9b95ec..1aa445bda3 100644
--- a/paddle/fluid/operators/batch_norm_op_mlu.cc
+++ b/paddle/fluid/operators/batch_norm_op_mlu.cc
@@ -273,7 +273,7 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
     const auto *running_mean = ctx.Input<Tensor>("Mean");
     const auto *running_variance = ctx.Input<Tensor>("Variance");
     MLUCnnl::FusedBatchNormGrad(ctx,
-                                true /*is_training*/,
+                                false /*is_training*/,
                                 transformed_desc.get(),
                                 GetBasePtr(&transformed_d_y),
                                 transformed_desc.get(),
diff --git a/paddle/fluid/operators/conv_transpose_op_mlu.cc b/paddle/fluid/operators/conv_transpose_op_mlu.cc
index 322328b1c2..f757898886 100644
--- a/paddle/fluid/operators/conv_transpose_op_mlu.cc
+++ b/paddle/fluid/operators/conv_transpose_op_mlu.cc
@@ -271,26 +271,18 @@ class Conv2DTransposeGradMLUKernel : public framework::OpKernel<T> {
         data_layout_mlu,
         ToCnnlDataType(input_grad_tensor.dtype()));
 
-    cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
-    cnnlDataType_t dt_onchip = ToCnnlDataType<T>();
-    MLUCnnl::Conv2D(ctx,
-                    conv_desc.get(),
-                    tensor_dtype,
-                    dt_onchip,
-                    nullptr /* input_position */,
-                    nullptr /* input_scale */,
-                    nullptr /* input_offset */,
-                    nullptr /* filter_position */,
-                    nullptr /* filter_scale */,
-                    nullptr /* filter_offset */,
-                    output_grad_desc.get(),
-                    GetBasePtr(&output_grad_tensor),
-                    trans_filter_desc.get(),
-                    GetBasePtr(&trans_filter),
-                    nullptr /* bias_desc*/,
-                    nullptr /* bias */,
-                    input_grad_desc.get(),
-                    GetBasePtr(&input_grad_tensor));
+    MLUCnnl::ConvolutionForward(ctx,
+                                conv_desc.get(),
+                                nullptr /*alpha*/,
+                                nullptr /*beta*/,
+                                nullptr /*bias_desc*/,
+                                nullptr /*bias_ptr*/,
+                                output_grad_desc.get(),
+                                GetBasePtr(&output_grad_tensor),
+                                trans_filter_desc.get(),
+                                GetBasePtr(&trans_filter),
+                                input_grad_desc.get(),
+                                GetBasePtr(&input_grad_tensor));
     if (!channel_last) {
       // transpose output from NHWC to NCHW
       const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 542c59bec1..318f2e8b6b 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1604,7 +1604,11 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() {
+#ifdef PADDLE_WITH_MLU
+    return ActBwdOpFwdDeps::kDepX;
+#else
     return ActBwdOpFwdDeps::kDepOut;
+#endif
   }
 };
 
--
GitLab
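
Note (not part of the patch): a minimal, hypothetical reference sketch of the hard-sigmoid backward math this fix relies on, assuming the usual definition out = clip(slope * x + offset, 0, 1). The helper names below are illustrative only, not Paddle APIs; the point is that the gradient can be evaluated from the forward input x alone, which is why the MLU kernel now reads X instead of Out and the functor reports kDepX under PADDLE_WITH_MLU.

#include <algorithm>
#include <cstdio>

// Hypothetical stand-alone helpers (not Paddle API).
// Forward: out = clip(slope * x + offset, 0, 1).
float HardSigmoidRef(float x, float slope, float offset) {
  return std::min(1.0f, std::max(0.0f, slope * x + offset));
}

// Backward: dx = dout * slope where slope * x + offset lies strictly in
// (0, 1), and 0 elsewhere; only x and dout are needed, not out.
float HardSigmoidGradRef(float x, float dout, float slope, float offset) {
  const float pre_clip = slope * x + offset;
  return (pre_clip > 0.0f && pre_clip < 1.0f) ? dout * slope : 0.0f;
}

int main() {
  const float x = 0.3f, dout = 1.0f, slope = 0.2f, offset = 0.5f;
  std::printf("out=%f dx=%f\n",
              HardSigmoidRef(x, slope, offset),
              HardSigmoidGradRef(x, dout, slope, offset));
  return 0;
}

Since out lies strictly inside (0, 1) exactly when slope * x + offset does, this x-based rule computes the same gradient as an Out-based rule dx = dout * slope * (0 < out < 1); the MLU path simply needs x as the saved forward tensor.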