未验证 提交 8573ca54 编写于 作者: F fwenguang 提交者: GitHub

[MLU] fix bn_grad and hard_sigmoid_grad error (#44919)

上级 713c4d0d
......@@ -370,7 +370,7 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* out = ctx.Input<Tensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float slope = ctx.Attr<float>("slope");
float offset = ctx.Attr<float>("offset");
......@@ -381,7 +381,7 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
1.0f /*sliced_dim useless*/,
slope,
offset);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnl::ActiveGrad(ctx,
......@@ -392,8 +392,8 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
nullptr,
dout_desc.get(),
GetBasePtr(dout),
out_desc.get(),
GetBasePtr(out),
x_desc.get(),
GetBasePtr(x),
dx_desc.get(),
GetBasePtr(dx));
}
......
......@@ -273,7 +273,7 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
MLUCnnl::FusedBatchNormGrad(ctx,
true /*is_training*/,
false /*is_training*/,
transformed_desc.get(),
GetBasePtr(&transformed_d_y),
transformed_desc.get(),
......
......@@ -271,24 +271,16 @@ class Conv2DTransposeGradMLUKernel : public framework::OpKernel<T> {
data_layout_mlu,
ToCnnlDataType(input_grad_tensor.dtype()));
cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
cnnlDataType_t dt_onchip = ToCnnlDataType<T>();
MLUCnnl::Conv2D(ctx,
MLUCnnl::ConvolutionForward(ctx,
conv_desc.get(),
tensor_dtype,
dt_onchip,
nullptr /* input_position */,
nullptr /* input_scale */,
nullptr /* input_offset */,
nullptr /* filter_position */,
nullptr /* filter_scale */,
nullptr /* filter_offset */,
nullptr /*alpha*/,
nullptr /*beta*/,
nullptr /*bias_desc*/,
nullptr /*bias_ptr*/,
output_grad_desc.get(),
GetBasePtr(&output_grad_tensor),
trans_filter_desc.get(),
GetBasePtr(&trans_filter),
nullptr /* bias_desc*/,
nullptr /* bias */,
input_grad_desc.get(),
GetBasePtr(&input_grad_tensor));
if (!channel_last) {
......
......@@ -1604,7 +1604,11 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
#ifdef PADDLE_WITH_MLU
return ActBwdOpFwdDeps::kDepX;
#else
return ActBwdOpFwdDeps::kDepOut;
#endif
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册