未验证 提交 8d3457f6 编写于 作者: X Xiaoxu Chen 提交者: GitHub

[prim] fix cast prim api dtype mapping error between phi and fluid (#51134)

上级 d3352b99
......@@ -79,8 +79,7 @@ class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
std::make_shared<prim::DescTensor>(this->SingleInputGrad("X")));
auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad);
auto dtype = static_cast<paddle::experimental::DataType>(
this->Attr<int>("in_dtype"));
auto dtype = phi::TransToPhiDataType((this->Attr<int>("in_dtype")));
prim::cast_grad<prim::DescTensor>(out_grad, dtype, dx_ptr);
this->RecoverOutputName(x_grad, dx_name);
}
......
......@@ -160,8 +160,8 @@ Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("in_dtype", static_cast<int>(x.dtype()));
op->SetAttr("out_dtype", static_cast<int>(dtype));
op->SetAttr("in_dtype", paddle::framework::TransToProtoVarType(x.dtype()));
op->SetAttr("out_dtype", paddle::framework::TransToProtoVarType(dtype));
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册