Unverified · Commit cf2d38b0 authored by mapingshuo, committed by GitHub

rename elementwise_sub_grad (#3260)

* rename elementwise_sub_grad, test=develop

* rename Grad to GRAD, test=develop

* deal with case that Y@GRAD is empty, test=develop
Parent commit: 2386c034
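The third bullet is the behavioral change: a gradient kernel may now be asked for only one of X@GRAD / Y@GRAD. The sketch below is a minimal, self-contained illustration of that guarding pattern, not the Paddle-Lite API (Tensor, add_grad and run_add_grad are simplified stand-ins): allocate a gradient buffer only when the corresponding output tensor exists, and require that at least one was requested.

```cpp
// Minimal sketch of the "optional gradient output" pattern this commit adds.
// Tensor, add_grad and run_add_grad are simplified stand-ins, not Paddle-Lite APIs.
#include <cassert>
#include <cstddef>
#include <vector>

struct Tensor {
  std::vector<float> buf;
  float* mutable_data(std::size_t n) { buf.resize(n); return buf.data(); }
};

// Backward of elementwise add (no broadcast): dX = dOut and dY = dOut.
void add_grad(const float* dout, float* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = dout[i];
}

// x_grad or y_grad may be nullptr when the framework does not need that gradient.
void run_add_grad(const float* dout, std::size_t n, Tensor* x_grad, Tensor* y_grad) {
  assert(x_grad || y_grad);  // at least one gradient output must be requested
  if (x_grad) add_grad(dout, x_grad->mutable_data(n), n);
  if (y_grad) add_grad(dout, y_grad->mutable_data(n), n);
}
```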
```diff
@@ -76,12 +76,32 @@ void ElementwiseAddGradCompute::Run() {
   const float* x_data = param.X->data<float>();
   const float* y_data = param.Y->data<float>();
   const float* out_grad_data = param.OutGrad->data<float>();
-  float* x_grad_data = param.XGrad->mutable_data<float>();
-  float* y_grad_data = param.YGrad->mutable_data<float>();
+  float* x_grad_data;
+  float* y_grad_data;
+  if (param.XGrad) {
+    x_grad_data = param.XGrad->mutable_data<float>();
+  }
+  if (param.YGrad) {
+    y_grad_data = param.YGrad->mutable_data<float>();
+  }
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
+  if (!param.XGrad) {
+    CHECK(param.YGrad);
+    lite::arm::math::elementwise_add_grad(
+        out_grad_data, y_grad_data, y_dims.production());
+    return;
+  }
+  if (!param.YGrad) {
+    CHECK(param.XGrad);
+    lite::arm::math::elementwise_add_grad(
+        out_grad_data, x_grad_data, x_dims.production());
+    return;
+  }
   if (x_dims.size() < y_dims.size() &&
       is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
     lite::arm::math::elementwise_add_grad_broadcast(
@@ -102,14 +122,28 @@ void ElementwiseSubGradCompute::Run() {
   const float* x_data = param.X->data<float>();
   const float* y_data = param.Y->data<float>();
   const float* out_data = param.OutGrad->data<float>();
-  float* x_grad_data = param.XGrad->mutable_data<float>();
-  float* y_grad_data = param.YGrad->mutable_data<float>();
+  float* x_grad_data;
+  float* y_grad_data;
+  if (param.XGrad) {
+    x_grad_data = param.XGrad->mutable_data<float>();
+  }
+  if (param.YGrad) {
+    y_grad_data = param.YGrad->mutable_data<float>();
+  }
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
+  if (!param.XGrad || !param.YGrad) {
+    CHECK(param.XGrad || param.YGrad);
+    lite::arm::math::elementwise_sub_grad(
+        out_data, x_grad_data, y_grad_data, y_dims.production());
+    return;
+  }
   if (x_dims.size() < y_dims.size()) {
-    LOG(FATAL) << "elewise div grad don't support x_dims size < y_dims size";
+    LOG(FATAL) << "elewise sub grad don't support x_dims size < y_dims size";
   }
   if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
     lite::arm::math::elementwise_sub_grad_broadcast(
@@ -148,10 +182,11 @@ REGISTER_LITE_KERNEL(elementwise_add_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseAddGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 REGISTER_LITE_KERNEL(elementwise_sub_grad,
@@ -160,10 +195,11 @@ REGISTER_LITE_KERNEL(elementwise_sub_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseSubGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 REGISTER_LITE_KERNEL(elementwise_div_grad,
@@ -172,18 +208,20 @@ REGISTER_LITE_KERNEL(elementwise_div_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseDivGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 REGISTER_LITE_KERNEL(
     elementwise_mul_grad, kARM, kFloat, kNCHW, elementwise_mul_grad_float, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 REGISTER_LITE_KERNEL(elementwise_max_grad,
@@ -192,8 +230,9 @@ REGISTER_LITE_KERNEL(elementwise_max_grad,
                      kNCHW,
                      paddle::lite::kernels::arm::ElementwiseMaxGradCompute,
                      def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
```
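The kernel registrations above also rename the bound argument slots from "X@Grad"/"Y@Grad"/"Out@Grad" to the upper-case "@GRAD" spelling and add the missing "X" input binding. The toy below illustrates why the exact string matters; GradName is a hypothetical stand-in for however the framework derives gradient argument names, under the assumption that it appends an upper-case "@GRAD" suffix, so bindings are matched purely by string and a lower-case suffix never resolves.

```cpp
// Toy illustration (not a Paddle-Lite helper) of string-keyed argument matching.
#include <iostream>
#include <string>

// Assumed convention: gradient argument names get an upper-case "@GRAD" suffix.
std::string GradName(const std::string& var) { return var + "@GRAD"; }

int main() {
  std::cout << GradName("Out") << "\n";                                      // Out@GRAD
  std::cout << (GradName("X") == "X@Grad" ? "match" : "no match") << "\n";   // no match
  std::cout << (GradName("X") == "X@GRAD" ? "match" : "no match") << "\n";   // match
}
```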
```diff
@@ -21,8 +21,7 @@ namespace lite {
 namespace operators {
 bool ElementwiseGradOp::CheckShape() const {
-  CHECK_OR_FALSE(param_.XGrad);
-  CHECK_OR_FALSE(param_.YGrad);
+  CHECK_OR_FALSE(param_.XGrad || param_.YGrad);
   CHECK_OR_FALSE(param_.OutGrad);
   return true;
 }
@@ -30,8 +29,12 @@ bool ElementwiseGradOp::CheckShape() const {
 bool ElementwiseGradOp::InferShape() const {
   auto x_dim = param_.X->dims();
   auto y_dim = param_.Y->dims();
-  param_.XGrad->Resize(x_dim);
-  param_.YGrad->Resize(y_dim);
+  if (param_.XGrad) {
+    param_.XGrad->Resize(x_dim);
+  }
+  if (param_.YGrad) {
+    param_.YGrad->Resize(y_dim);
+  }
   return true;
 }
@@ -39,14 +42,21 @@ bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
                                    lite::Scope* scope) {
   auto Y_name = opdesc.Input("Y").front();
   auto X_name = opdesc.Input("X").front();
-  auto Out_name = opdesc.Input("Out@Grad").front();
-  auto x_grad_name = opdesc.Output("X@Grad").front();
-  auto y_grad_name = opdesc.Output("Y@Grad").front();
+  auto Out_name = opdesc.Input("Out@GRAD").front();
+  CHECK(!opdesc.Output("X@GRAD").empty() || !opdesc.Output("Y@GRAD").empty())
+      << "at least one of 'X@GRAD' and 'Y@GRAD' is not empty";
+  if (!opdesc.Output("X@GRAD").empty()) {
+    auto x_grad_name = opdesc.Output("X@GRAD").front();
+    param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
+  }
+  if (!opdesc.Output("Y@GRAD").empty()) {
+    auto y_grad_name = opdesc.Output("Y@GRAD").front();
+    param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
+  }
   param_.X = GetVar<lite::Tensor>(scope, X_name);
   param_.Y = GetVar<lite::Tensor>(scope, Y_name);
-  param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
-  param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
   param_.OutGrad = GetVar<lite::Tensor>(scope, Out_name);
   param_.axis = opdesc.GetAttr<int>("axis");
   return true;
@@ -56,9 +66,9 @@ bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
 }  // namespace lite
 }  // namespace paddle
-REGISTER_LITE_OP(elementwise_grad_sub,
+REGISTER_LITE_OP(elementwise_sub_grad,
                  paddle::lite::operators::ElementwiseGradOp);
-REGISTER_LITE_OP(elementwise_grad_add,
+REGISTER_LITE_OP(elementwise_add_grad,
                  paddle::lite::operators::ElementwiseGradOp);
 REGISTER_LITE_OP(elementwise_grad_mul,
 …
```
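The final hunk is the rename the commit title refers to: the operators were registered as elementwise_grad_sub / elementwise_grad_add, while the ARM kernels in the first file are registered under elementwise_sub_grad / elementwise_add_grad. The toy registry below (hypothetical, not the real REGISTER_LITE_OP / REGISTER_LITE_KERNEL machinery) shows why the two names must agree: lookup is by the operator's string key, so a mismatched spelling never finds a kernel.

```cpp
// Toy string-keyed registry illustrating why the op registration name must
// match the kernel registration name exactly (not the real Paddle-Lite registry).
#include <functional>
#include <iostream>
#include <map>
#include <string>

std::map<std::string, std::function<void()>>& registry() {
  static std::map<std::string, std::function<void()>> r;
  return r;
}

void register_kernel(const std::string& op, std::function<void()> kernel) {
  registry()[op] = std::move(kernel);
}

void run_op(const std::string& op) {
  auto it = registry().find(op);
  if (it == registry().end()) {
    std::cout << "no kernel found for '" << op << "'\n";  // lookup miss
    return;
  }
  it->second();
}

int main() {
  register_kernel("elementwise_sub_grad", [] { std::cout << "sub grad kernel\n"; });
  run_op("elementwise_sub_grad");  // resolves
  run_op("elementwise_grad_sub");  // old spelling: never resolves
}
```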