未验证 提交 8d3457f6 编写于 作者: X Xiaoxu Chen 提交者: GitHub

[prim] fix cast prim api dtype mapping error between phi and fluid (#51134)

上级 d3352b99
...@@ -79,8 +79,7 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { ...@@ -79,8 +79,7 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
std::make_shared<prim::DescTensor>(this->SingleInputGrad("X"))); std::make_shared<prim::DescTensor>(this->SingleInputGrad("X")));
auto dx_ptr = this->GetOutputPtr(&x_grad); auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad); std::string dx_name = this->GetOutputName(x_grad);
auto dtype = static_cast<paddle::experimental::DataType>( auto dtype = phi::TransToPhiDataType((this->Attr<int>("in_dtype")));
this->Attr<int>("in_dtype"));
prim::cast_grad<prim::DescTensor>(out_grad, dtype, dx_ptr); prim::cast_grad<prim::DescTensor>(out_grad, dtype, dx_ptr);
this->RecoverOutputName(x_grad, dx_name); this->RecoverOutputName(x_grad, dx_name);
} }
......
...@@ -160,8 +160,8 @@ Tensor cast<DescTensor>(const Tensor& x, DataType dtype) { ...@@ -160,8 +160,8 @@ Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()}); {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput( op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("in_dtype", static_cast<int>(x.dtype())); op->SetAttr("in_dtype", paddle::framework::TransToProtoVarType(x.dtype()));
op->SetAttr("out_dtype", static_cast<int>(dtype)); op->SetAttr("out_dtype", paddle::framework::TransToProtoVarType(dtype));
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
op->InferShape(*block); op->InferShape(*block);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册