Unverified commit 8573ca54, authored by fwenguang, committed by GitHub

[MLU] fix bn_grad and hard_sigmoid_grad error (#44919)

Parent 713c4d0d
@@ -370,7 +370,7 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* out = ctx.Input<Tensor>("Out");
+    auto* x = ctx.Input<Tensor>("X");
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     float slope = ctx.Attr<float>("slope");
     float offset = ctx.Attr<float>("offset");
@@ -381,7 +381,7 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
                                  1.0f /*sliced_dim useless*/,
                                  slope,
                                  offset);
-    MLUCnnlTensorDesc out_desc(*out);
+    MLUCnnlTensorDesc x_desc(*x);
     MLUCnnlTensorDesc dout_desc(*dout);
     MLUCnnlTensorDesc dx_desc(*dx);
     MLUCnnl::ActiveGrad(ctx,
@@ -392,8 +392,8 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
                         nullptr,
                         dout_desc.get(),
                         GetBasePtr(dout),
-                        out_desc.get(),
-                        GetBasePtr(out),
+                        x_desc.get(),
+                        GetBasePtr(x),
                         dx_desc.get(),
                         GetBasePtr(dx));
   }
......
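hard_sigmoid computes out = clip(slope * x + offset, 0, 1), so its gradient is slope wherever the pre-clip value lies strictly between 0 and 1 and zero elsewhere. That mask can be computed from either the forward input x or the forward output out; this hunk rewires the MLU ActiveGrad call to consume X instead of Out. Below is a minimal, standalone C++ sketch (not the MLU API) contrasting the two equivalent formulations:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Reference hard_sigmoid forward: out = clip(slope * x + offset, 0, 1).
float HardSigmoid(float x, float slope, float offset) {
  return std::min(1.0f, std::max(0.0f, slope * x + offset));
}

// Gradient written against the forward input x (what the MLU kernel now
// reads): dx = slope * dout wherever 0 < slope * x + offset < 1.
float GradFromX(float x, float dout, float slope, float offset) {
  const float pre = slope * x + offset;
  return (pre > 0.0f && pre < 1.0f) ? slope * dout : 0.0f;
}

// Equivalent gradient written against the forward output out (what the
// kernel read before): dx = slope * dout wherever 0 < out < 1.
float GradFromOut(float out, float dout, float slope) {
  return (out > 0.0f && out < 1.0f) ? slope * dout : 0.0f;
}

int main() {
  const float slope = 0.2f, offset = 0.5f;
  for (float x : std::vector<float>{-5.0f, -1.0f, 0.0f, 1.0f, 5.0f}) {
    const float out = HardSigmoid(x, slope, offset);
    std::printf("x=%5.1f  grad_from_x=%4.2f  grad_from_out=%4.2f\n", x,
                GradFromX(x, 1.0f, slope, offset),
                GradFromOut(out, 1.0f, slope));
  }
  return 0;
}
```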
@@ -273,7 +273,7 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
       const auto *running_mean = ctx.Input<Tensor>("Mean");
       const auto *running_variance = ctx.Input<Tensor>("Variance");
       MLUCnnl::FusedBatchNormGrad(ctx,
-                                  true /*is_training*/,
+                                  false /*is_training*/,
                                   transformed_desc.get(),
                                   GetBasePtr(&transformed_d_y),
                                   transformed_desc.get(),
......
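This FusedBatchNormGrad call sits next to the running_mean / running_variance inputs, i.e. the branch that differentiates the inference-style formula y = gamma * (x - mean) / sqrt(var + eps) + beta with the statistics held fixed, so is_training should be false there. A small numerical sketch of that fixed-statistics gradient (plain C++, independent of the cnnl API; shapes and values are illustrative):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Inference-mode batch norm: mean/var are fixed running statistics, so the
// input gradient reduces to an elementwise scale (no batch-statistics terms):
//   y      = gamma * (x - mean) / sqrt(var + eps) + beta
//   dx     = dy * gamma / sqrt(var + eps)
//   dgamma = sum(dy * (x - mean) / sqrt(var + eps)),  dbeta = sum(dy)
int main() {
  const float gamma = 2.0f, mean = 1.0f, var = 4.0f, eps = 1e-5f;
  const std::vector<float> x = {0.0f, 1.0f, 3.0f};
  const std::vector<float> dy = {1.0f, 1.0f, 1.0f};

  const float inv_std = 1.0f / std::sqrt(var + eps);
  float dgamma = 0.0f, dbeta = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    const float dx = dy[i] * gamma * inv_std;  // statistics treated as constants
    dgamma += dy[i] * (x[i] - mean) * inv_std;
    dbeta += dy[i];
    std::printf("dx[%zu] = %f\n", i, dx);
  }
  std::printf("dgamma = %f, dbeta = %f\n", dgamma, dbeta);
  return 0;
}
```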
@@ -271,26 +271,18 @@ class Conv2DTransposeGradMLUKernel : public framework::OpKernel<T> {
                                 data_layout_mlu,
                                 ToCnnlDataType(input_grad_tensor.dtype()));
-      cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
-      cnnlDataType_t dt_onchip = ToCnnlDataType<T>();
-      MLUCnnl::Conv2D(ctx,
-                      conv_desc.get(),
-                      tensor_dtype,
-                      dt_onchip,
-                      nullptr /* input_position */,
-                      nullptr /* input_scale */,
-                      nullptr /* input_offset */,
-                      nullptr /* filter_position */,
-                      nullptr /* filter_scale */,
-                      nullptr /* filter_offset */,
-                      output_grad_desc.get(),
-                      GetBasePtr(&output_grad_tensor),
-                      trans_filter_desc.get(),
-                      GetBasePtr(&trans_filter),
-                      nullptr /* bias_desc*/,
-                      nullptr /* bias */,
-                      input_grad_desc.get(),
-                      GetBasePtr(&input_grad_tensor));
+      MLUCnnl::ConvolutionForward(ctx,
+                                  conv_desc.get(),
+                                  nullptr /*alpha*/,
+                                  nullptr /*beta*/,
+                                  nullptr /*bias_desc*/,
+                                  nullptr /*bias_ptr*/,
+                                  output_grad_desc.get(),
+                                  GetBasePtr(&output_grad_tensor),
+                                  trans_filter_desc.get(),
+                                  GetBasePtr(&trans_filter),
+                                  input_grad_desc.get(),
+                                  GetBasePtr(&input_grad_tensor));
       if (!channel_last) {
         // transpose output from NHWC to NCHW
         const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
......
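Mathematically, the input gradient of a transposed convolution is an ordinary forward convolution of output_grad with the filter, which is what lets this hunk route the computation through ConvolutionForward. A minimal 1-D sketch of that identity (stride 1, no padding; plain C++ with illustrative sizes, not the MLU API):

```cpp
#include <cstdio>
#include <vector>

// 1-D conv_transpose (stride 1, no padding): y[i + k] += x[i] * w[k].
std::vector<float> ConvTranspose1D(const std::vector<float>& x,
                                   const std::vector<float>& w) {
  std::vector<float> y(x.size() + w.size() - 1, 0.0f);
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t k = 0; k < w.size(); ++k) y[i + k] += x[i] * w[k];
  return y;
}

// Input gradient of conv_transpose: dx[i] = sum_k dy[i + k] * w[k].
// This is exactly a "valid" forward convolution (cross-correlation) of dy
// with w, mirroring the ConvolutionForward call in the hunk above.
std::vector<float> InputGradAsForwardConv(const std::vector<float>& dy,
                                          const std::vector<float>& w,
                                          size_t x_size) {
  std::vector<float> dx(x_size, 0.0f);
  for (size_t i = 0; i < x_size; ++i)
    for (size_t k = 0; k < w.size(); ++k) dx[i] += dy[i + k] * w[k];
  return dx;
}

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 3.0f};
  const std::vector<float> w = {0.5f, -1.0f};
  const std::vector<float> y = ConvTranspose1D(x, w);
  const std::vector<float> dy(y.size(), 1.0f);  // upstream grad of all ones
  for (float v : InputGradAsForwardConv(dy, w, x.size())) std::printf("%f ", v);
  std::printf("\n");
  return 0;
}
```

With dy set to all ones, every dx entry equals the sum of the filter taps (-0.5 here), matching direct differentiation of the transposed convolution.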
@@ -1604,7 +1604,11 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
+#ifdef PADDLE_WITH_MLU
+    return ActBwdOpFwdDeps::kDepX;
+#else
     return ActBwdOpFwdDeps::kDepOut;
+#endif
   }
 };
......
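FwdDeps() declares which forward tensors the activation's grad kernel needs so the framework keeps them alive for the backward pass; since the MLU kernel above now reads X, the dependency switches to kDepX under PADDLE_WITH_MLU while other backends keep kDepOut. A hedged sketch of how such a flag might be consumed (the enum values and the RetainForBackward helper are illustrative, not Paddle's actual API):

```cpp
#include <cstdio>

// Mirrors the dependency flag used by the functor above (values illustrative).
enum class ActBwdOpFwdDeps { kNoDeps = 0, kDepX = 1, kDepOut = 2 };

// Hypothetical helper showing how a framework might use the flag to decide
// which forward tensors to retain for the backward pass.
void RetainForBackward(ActBwdOpFwdDeps deps) {
  if (deps == ActBwdOpFwdDeps::kDepX) {
    std::printf("retain forward input X for the grad kernel\n");
  } else if (deps == ActBwdOpFwdDeps::kDepOut) {
    std::printf("retain forward output Out for the grad kernel\n");
  } else {
    std::printf("grad kernel needs neither X nor Out\n");
  }
}

int main() {
#ifdef PADDLE_WITH_MLU
  RetainForBackward(ActBwdOpFwdDeps::kDepX);    // MLU path after this commit
#else
  RetainForBackward(ActBwdOpFwdDeps::kDepOut);  // other backends unchanged
#endif
  return 0;
}
```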