diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 52ad19c385e97dd528aef251590a69c46337dd3a..ec7c820a1440e7d07a73428bd0d7da216bc2705f 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -72,15 +72,18 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; void Apply() override { - paddle::Tensor out_grad = paddle::Tensor( - std::make_shared(this->SingleOutputGrad("Out"))); - paddle::Tensor x_grad = paddle::Tensor( - std::make_shared(this->SingleInputGrad("X"))); - auto dx_ptr = this->GetOutputPtr(&x_grad); - std::string dx_name = this->GetOutputName(x_grad); - auto dtype = phi::TransToPhiDataType((this->Attr("in_dtype"))); - prim::cast_grad(out_grad, dtype, dx_ptr); - this->RecoverOutputName(x_grad, dx_name); + paddle::Tensor x = this->GetSingleForwardInput("X"); + paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); + + // get outputs + paddle::Tensor x_grad_t = this->GetSingleInputGrad("X"); + paddle::Tensor *x_grad = this->GetOutputPtr(&x_grad_t); + std::string x_grad_name = this->GetOutputName(x_grad_t); + + VLOG(6) << "Runing cast_grad composite func"; + prim::cast_grad(x, out_grad, x_grad); + + this->RecoverOutputName(x_grad_t, x_grad_name); } }; diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 8ad54079ca087b5feb839f6efca7246682ec3f40..979be562db93a519d93275bfde43d08a8b7c6bc1 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -110,9 +110,9 @@ void softmax_grad(const Tensor& out, } template -void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) { +void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { - auto res = cast(out_grad, dtype); + auto res = cast(out_grad, x.dtype()); set_output(res, x_grad); } }