diff --git a/lite/kernels/arm/elementwise_grad_compute.cc b/lite/kernels/arm/elementwise_grad_compute.cc
index 53971e26706f15032dae244e7bd0493a49376cd6..93bc5853459005137ef4f948f3a5892d76441b7c 100644
--- a/lite/kernels/arm/elementwise_grad_compute.cc
+++ b/lite/kernels/arm/elementwise_grad_compute.cc
@@ -76,12 +76,32 @@ void ElementwiseAddGradCompute::Run() {
   const float* x_data = param.X->data<float>();
   const float* y_data = param.Y->data<float>();
   const float* out_grad_data = param.OutGrad->data<float>();
-  float* x_grad_data = param.XGrad->mutable_data<float>();
-  float* y_grad_data = param.YGrad->mutable_data<float>();
+  float* x_grad_data;
+  float* y_grad_data;
+  if (param.XGrad) {
+    x_grad_data = param.XGrad->mutable_data<float>();
+  }
+  if (param.YGrad) {
+    y_grad_data = param.YGrad->mutable_data<float>();
+  }
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
+  if (!param.XGrad) {
+    CHECK(param.YGrad);
+    lite::arm::math::elementwise_add_grad(
+        out_grad_data, y_grad_data, y_dims.production());
+    return;
+  }
+
+  if (!param.YGrad) {
+    CHECK(param.XGrad);
+    lite::arm::math::elementwise_add_grad(
+        out_grad_data, x_grad_data, x_dims.production());
+    return;
+  }
+
   if (x_dims.size() < y_dims.size() &&
       is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
     lite::arm::math::elementwise_add_grad_broadcast(
@@ -102,14 +122,28 @@ void ElementwiseSubGradCompute::Run() {
   const float* x_data = param.X->data<float>();
   const float* y_data = param.Y->data<float>();
   const float* out_data = param.OutGrad->data<float>();
-  float* x_grad_data = param.XGrad->mutable_data<float>();
-  float* y_grad_data = param.YGrad->mutable_data<float>();
+  float* x_grad_data;
+  float* y_grad_data;
+  if (param.XGrad) {
+    x_grad_data = param.XGrad->mutable_data<float>();
+  }
+  if (param.YGrad) {
+    y_grad_data = param.YGrad->mutable_data<float>();
+  }
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
+
+  if (!param.XGrad || !param.YGrad) {
+    CHECK(param.XGrad || param.YGrad);
+    lite::arm::math::elementwise_sub_grad(
+        out_data, x_grad_data, y_grad_data, y_dims.production());
+    return;
+  }
+
   if (x_dims.size() < y_dims.size()) {
-    LOG(FATAL) << "elewise div grad don't support x_dims size < y_dims size";
+    LOG(FATAL) << "elewise sub grad don't support x_dims size < y_dims size";
   }
   if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
     lite::arm::math::elementwise_sub_grad_broadcast(
@@ -148,10 +182,11 @@ REGISTER_LITE_KERNEL(elementwise_add_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseAddGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(elementwise_sub_grad,
@@ -160,10 +195,11 @@ REGISTER_LITE_KERNEL(elementwise_sub_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseSubGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(elementwise_div_grad,
@@ -172,18 +208,20 @@ REGISTER_LITE_KERNEL(elementwise_div_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseDivGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(
     elementwise_mul_grad, kARM, kFloat, kNCHW, elementwise_mul_grad_float, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(elementwise_max_grad,
@@ -192,8 +230,9 @@ REGISTER_LITE_KERNEL(elementwise_max_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseMaxGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/operators/elementwise_grad_ops.cc b/lite/operators/elementwise_grad_ops.cc
index 8d9e1040976a98d890dc8c841cb4f70d81453d61..9d964bf9e36889f2bc72b2656d23bf4022cc121c 100644
--- a/lite/operators/elementwise_grad_ops.cc
+++ b/lite/operators/elementwise_grad_ops.cc
@@ -21,8 +21,7 @@ namespace lite {
 namespace operators {
 
 bool ElementwiseGradOp::CheckShape() const {
-  CHECK_OR_FALSE(param_.XGrad);
-  CHECK_OR_FALSE(param_.YGrad);
+  CHECK_OR_FALSE(param_.XGrad || param_.YGrad);
   CHECK_OR_FALSE(param_.OutGrad);
   return true;
 }
@@ -30,8 +29,12 @@ bool ElementwiseGradOp::CheckShape() const {
 bool ElementwiseGradOp::InferShape() const {
   auto x_dim = param_.X->dims();
   auto y_dim = param_.Y->dims();
-  param_.XGrad->Resize(x_dim);
-  param_.YGrad->Resize(y_dim);
+  if (param_.XGrad) {
+    param_.XGrad->Resize(x_dim);
+  }
+  if (param_.YGrad) {
+    param_.YGrad->Resize(y_dim);
+  }
   return true;
 }
 
@@ -39,14 +42,21 @@ bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
                                    lite::Scope* scope) {
   auto Y_name = opdesc.Input("Y").front();
   auto X_name = opdesc.Input("X").front();
-  auto Out_name = opdesc.Input("Out@Grad").front();
-  auto x_grad_name = opdesc.Output("X@Grad").front();
-  auto y_grad_name = opdesc.Output("Y@Grad").front();
+  auto Out_name = opdesc.Input("Out@GRAD").front();
+  CHECK(!opdesc.Output("X@GRAD").empty() || !opdesc.Output("Y@GRAD").empty())
+      << "at least one of 'X@GRAD' and 'Y@GRAD' is not empty";
+
+  if (!opdesc.Output("X@GRAD").empty()) {
+    auto x_grad_name = opdesc.Output("X@GRAD").front();
+    param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
+  }
+  if (!opdesc.Output("Y@GRAD").empty()) {
+    auto y_grad_name = opdesc.Output("Y@GRAD").front();
+    param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
+  }
 
   param_.X = GetVar<lite::Tensor>(scope, X_name);
   param_.Y = GetVar<lite::Tensor>(scope, Y_name);
-  param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
-  param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
   param_.OutGrad = GetVar<lite::Tensor>(scope, Out_name);
   param_.axis = opdesc.GetAttr<int>("axis");
   return true;
@@ -56,9 +66,9 @@ bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_OP(elementwise_grad_sub,
+REGISTER_LITE_OP(elementwise_sub_grad,
                  paddle::lite::operators::ElementwiseGradOp);
-REGISTER_LITE_OP(elementwise_grad_add,
+REGISTER_LITE_OP(elementwise_add_grad,
                  paddle::lite::operators::ElementwiseGradOp);
 REGISTER_LITE_OP(elementwise_grad_mul,