未验证 提交 8c5c03c2 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】Modify the signature of cast_grad to keep it consistent with the yaml config. (#53498)

* modify concat_grad add sum comp rule

* modify cast
上级 aa887717
...@@ -72,15 +72,18 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { ...@@ -72,15 +72,18 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override { void Apply() override {
paddle::Tensor out_grad = paddle::Tensor( paddle::Tensor x = this->GetSingleForwardInput("X");
std::make_shared<prim::DescTensor>(this->SingleOutputGrad("Out"))); paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor x_grad = paddle::Tensor(
std::make_shared<prim::DescTensor>(this->SingleInputGrad("X"))); // get outputs
auto dx_ptr = this->GetOutputPtr(&x_grad); paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
std::string dx_name = this->GetOutputName(x_grad); paddle::Tensor *x_grad = this->GetOutputPtr(&x_grad_t);
auto dtype = phi::TransToPhiDataType((this->Attr<int>("in_dtype"))); std::string x_grad_name = this->GetOutputName(x_grad_t);
prim::cast_grad<prim::DescTensor>(out_grad, dtype, dx_ptr);
this->RecoverOutputName(x_grad, dx_name); VLOG(6) << "Runing cast_grad composite func";
prim::cast_grad<prim::DescTensor>(x, out_grad, x_grad);
this->RecoverOutputName(x_grad_t, x_grad_name);
} }
}; };
......
...@@ -110,9 +110,9 @@ void softmax_grad(const Tensor& out, ...@@ -110,9 +110,9 @@ void softmax_grad(const Tensor& out,
} }
template <typename T> template <typename T>
void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) { void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) { if (x_grad) {
auto res = cast<T>(out_grad, dtype); auto res = cast<T>(out_grad, x.dtype());
set_output<T>(res, x_grad); set_output<T>(res, x_grad);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册