未验证 提交 8c5c03c2 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】modify the signature of cast_grad to keep it consistent with the yaml config. (#53498)

* modify concat_grad; add sum composite rule

* modify cast
上级 aa887717
......@@ -72,15 +72,18 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
  // Composite (prim) backward rule for cast.
  // Per the yaml-config signature, cast_grad takes the forward input x so it
  // can recover the original dtype: dx = cast(dout, x.dtype()).
  // get inputs
  paddle::Tensor x = this->GetSingleForwardInput("X");
  paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
  // get outputs
  paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
  // x_grad aliases x_grad_t's underlying desc; cast_grad writes through it.
  paddle::Tensor *x_grad = this->GetOutputPtr(&x_grad_t);
  std::string x_grad_name = this->GetOutputName(x_grad_t);
  VLOG(6) << "Running cast_grad composite func";
  prim::cast_grad<prim::DescTensor>(x, out_grad, x_grad);
  // Restore the framework-assigned output name after the composite call.
  this->RecoverOutputName(x_grad_t, x_grad_name);
}
};
......
......@@ -110,9 +110,9 @@ void softmax_grad(const Tensor& out,
}
template <typename T>
void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto res = cast<T>(out_grad, dtype);
auto res = cast<T>(out_grad, x.dtype());
set_output<T>(res, x_grad);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册