未验证 提交 6d2deedf 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #10814 from guoshengCS/fix-ElementwiseOpInferVarType

Fix ElementwiseOpInferVarType in elementwise_op
...@@ -46,9 +46,11 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference { ...@@ -46,9 +46,11 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
auto x_var = op_desc.Input("X")[0]; auto x_name = op_desc.Input("X")[0];
auto out_var = op_desc.Output("Out")[0]; auto out_name = op_desc.Output("Out")[0];
block->Var(out_var)->SetType(block->Var(x_var)->GetType()); auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册