Unverified commit cf2d38b0, authored by mapingshuo, committed by GitHub

rename elementwise_sub_grad (#3260)

* rename elementwise_sub_grad, test=develop

* rename Grad to GRAD, test=develop

* deal with case that Y@GRAD is empty, test=develop
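
For context on the third point: the kernel and operator changes below only touch a gradient output when the graph actually requests it. Below is a minimal, self-contained sketch of that guard pattern together with the same-shape sub gradient (dX = dOut, dY = -dOut); the types and names are hypothetical stand-ins, not the Paddle-Lite API.

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for lite::Tensor; only the null-check pattern
// mirrors the diff below.
struct Tensor {
  std::vector<float> data;
};

struct SubGradParam {
  const Tensor* out_grad = nullptr;  // Out@GRAD, always present
  Tensor* x_grad = nullptr;          // X@GRAD, may be absent (pruned)
  Tensor* y_grad = nullptr;          // Y@GRAD, may be absent (pruned)
};

void RunSubGrad(const SubGradParam& p) {
  assert(p.x_grad || p.y_grad);  // mirrors CHECK(param.XGrad || param.YGrad)
  const std::vector<float>& dout = p.out_grad->data;
  if (p.x_grad) {                // dX = dOut for Out = X - Y
    p.x_grad->data = dout;
  }
  if (p.y_grad) {                // dY = -dOut
    p.y_grad->data.resize(dout.size());
    for (std::size_t i = 0; i < dout.size(); ++i) {
      p.y_grad->data[i] = -dout[i];
    }
  }
}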
Parent 2386c034
@@ -76,12 +76,32 @@ void ElementwiseAddGradCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_grad_data = param.OutGrad->data<float>();
float* x_grad_data = param.XGrad->mutable_data<float>();
float* y_grad_data = param.YGrad->mutable_data<float>();
float* x_grad_data = nullptr;
float* y_grad_data = nullptr;
if (param.XGrad) {
x_grad_data = param.XGrad->mutable_data<float>();
}
if (param.YGrad) {
y_grad_data = param.YGrad->mutable_data<float>();
}
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (!param.XGrad) {
CHECK(param.YGrad);
lite::arm::math::elementwise_add_grad(
out_grad_data, y_grad_data, y_dims.production());
return;
}
if (!param.YGrad) {
CHECK(param.XGrad);
lite::arm::math::elementwise_add_grad(
out_grad_data, x_grad_data, x_dims.production());
return;
}
if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_add_grad_broadcast(
@@ -102,14 +122,28 @@ void ElementwiseSubGradCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_data = param.OutGrad->data<float>();
float* x_grad_data = param.XGrad->mutable_data<float>();
float* y_grad_data = param.YGrad->mutable_data<float>();
float* x_grad_data = nullptr;
float* y_grad_data = nullptr;
if (param.XGrad) {
x_grad_data = param.XGrad->mutable_data<float>();
}
if (param.YGrad) {
y_grad_data = param.YGrad->mutable_data<float>();
}
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (!param.XGrad || !param.YGrad) {
CHECK(param.XGrad || param.YGrad);
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, y_dims.production());
return;
}
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div grad don't support x_dims size < y_dims size";
LOG(FATAL) << "elewise sub grad don't support x_dims size < y_dims size";
}
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_grad_broadcast(
@@ -148,10 +182,11 @@ REGISTER_LITE_KERNEL(elementwise_add_grad,
kNCHW,
paddle::lite::kernels::arm::ElementwiseAddGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub_grad,
@@ -160,10 +195,11 @@ REGISTER_LITE_KERNEL(elementwise_sub_grad,
kNCHW,
paddle::lite::kernels::arm::ElementwiseSubGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_div_grad,
@@ -172,18 +208,20 @@ REGISTER_LITE_KERNEL(elementwise_div_grad,
kNCHW,
paddle::lite::kernels::arm::ElementwiseDivGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
elementwise_mul_grad, kARM, kFloat, kNCHW, elementwise_mul_grad_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_max_grad,
@@ -192,8 +230,9 @@ REGISTER_LITE_KERNEL(elementwise_max_grad,
kNCHW,
paddle::lite::kernels::arm::ElementwiseMaxGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
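
The broadcast branches above call into lite::arm::math helpers that are not shown in this diff. As a rough reference for what a broadcast gradient has to do (an assumption about those helpers, not their actual code or signature): when a Y of length n was broadcast over an output laid out as (pre, n, post), dX simply copies dOut while dY sums dOut over the broadcast axes.

#include <cstring>

// Reference-only sketch; the real lite::arm::math::elementwise_add_grad_broadcast
// is NEON-optimized and may take different arguments.
void add_grad_broadcast_ref(const float* dout, float* dx, float* dy,
                            int pre, int n, int post) {
  if (dx) std::memcpy(dx, dout, sizeof(float) * pre * n * post);  // dX = dOut
  if (dy) {
    for (int j = 0; j < n; ++j) dy[j] = 0.f;
    for (int i = 0; i < pre; ++i)        // reduce over the pre axis
      for (int j = 0; j < n; ++j)        // keep the broadcast axis
        for (int k = 0; k < post; ++k)   // reduce over the post axis
          dy[j] += dout[(i * n + j) * post + k];
  }
}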
@@ -21,8 +21,7 @@ namespace lite {
namespace operators {
bool ElementwiseGradOp::CheckShape() const {
CHECK_OR_FALSE(param_.XGrad);
CHECK_OR_FALSE(param_.YGrad);
CHECK_OR_FALSE(param_.XGrad || param_.YGrad);
CHECK_OR_FALSE(param_.OutGrad);
return true;
}
@@ -30,8 +29,12 @@ bool ElementwiseGradOp::CheckShape() const {
bool ElementwiseGradOp::InferShape() const {
auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims();
param_.XGrad->Resize(x_dim);
param_.YGrad->Resize(y_dim);
if (param_.XGrad) {
param_.XGrad->Resize(x_dim);
}
if (param_.YGrad) {
param_.YGrad->Resize(y_dim);
}
return true;
}
@@ -39,14 +42,21 @@ bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto Y_name = opdesc.Input("Y").front();
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Input("Out@Grad").front();
auto x_grad_name = opdesc.Output("X@Grad").front();
auto y_grad_name = opdesc.Output("Y@Grad").front();
auto Out_name = opdesc.Input("Out@GRAD").front();
CHECK(!opdesc.Output("X@GRAD").empty() || !opdesc.Output("Y@GRAD").empty())
<< "at least one of 'X@GRAD' and 'Y@GRAD' is not empty";
if (!opdesc.Output("X@GRAD").empty()) {
auto x_grad_name = opdesc.Output("X@GRAD").front();
param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
}
if (!opdesc.Output("Y@GRAD").empty()) {
auto y_grad_name = opdesc.Output("Y@GRAD").front();
param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
}
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Y = GetVar<lite::Tensor>(scope, Y_name);
param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
param_.OutGrad = GetVar<lite::Tensor>(scope, Out_name);
param_.axis = opdesc.GetAttr<int>("axis");
return true;
@@ -56,9 +66,9 @@ bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(elementwise_grad_sub,
REGISTER_LITE_OP(elementwise_sub_grad,
paddle::lite::operators::ElementwiseGradOp);
REGISTER_LITE_OP(elementwise_grad_add,
REGISTER_LITE_OP(elementwise_add_grad,
paddle::lite::operators::ElementwiseGradOp);
REGISTER_LITE_OP(elementwise_grad_mul,