From 8c5c03c2d6d09ce9e5a99d804a5be02e50524b58 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Fri, 5 May 2023 11:18:32 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90prim=E3=80=91modify=20the=20signature?= =?UTF-8?q?=20of=20cast=5Fgrad=20for=20keeping=20consistent=20with=20yaml?= =?UTF-8?q?=20config.=20(#53498)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * modify concat_grad add sum comp rule * modify cast --- paddle/fluid/operators/cast_op.cc | 21 +++++++++++-------- .../composite_backward_api.h | 4 ++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 52ad19c385e..ec7c820a144 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -72,15 +72,18 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; void Apply() override { - paddle::Tensor out_grad = paddle::Tensor( - std::make_shared(this->SingleOutputGrad("Out"))); - paddle::Tensor x_grad = paddle::Tensor( - std::make_shared(this->SingleInputGrad("X"))); - auto dx_ptr = this->GetOutputPtr(&x_grad); - std::string dx_name = this->GetOutputName(x_grad); - auto dtype = phi::TransToPhiDataType((this->Attr("in_dtype"))); - prim::cast_grad(out_grad, dtype, dx_ptr); - this->RecoverOutputName(x_grad, dx_name); + paddle::Tensor x = this->GetSingleForwardInput("X"); + paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); + + // get outputs + paddle::Tensor x_grad_t = this->GetSingleInputGrad("X"); + paddle::Tensor *x_grad = this->GetOutputPtr(&x_grad_t); + std::string x_grad_name = this->GetOutputName(x_grad_t); + + VLOG(6) << "Runing cast_grad composite func"; + prim::cast_grad(x, out_grad, x_grad); + + this->RecoverOutputName(x_grad_t, x_grad_name); } }; diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 8ad54079ca0..979be562db9 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -110,9 +110,9 @@ void softmax_grad(const Tensor& out, } template -void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) { +void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { - auto res = cast(out_grad, dtype); + auto res = cast(out_grad, x.dtype()); set_output(res, x_grad); } } -- GitLab