提交 29819ba7 编写于 作者: Y Yu Yang

Fix unittest

上级 acc54c7b
...@@ -35,14 +35,17 @@ class SumOpVarTypeInference : public VarTypeInference { ...@@ -35,14 +35,17 @@ class SumOpVarTypeInference : public VarTypeInference {
public: public:
void operator()(const OpDescBind &op_desc, void operator()(const OpDescBind &op_desc,
BlockDescBind *block) const override { BlockDescBind *block) const override {
auto default_var_type = VarDesc::LOD_TENSOR; auto &inputs = op_desc.Input("X");
for (auto &in_var_name : op_desc.Input("X")) { auto default_var_type = VarDesc::SELECTED_ROWS;
auto in_var_type = block->Var(in_var_name)->GetType();
if (in_var_type != default_var_type) { bool any_input_is_lod_tensor = std::any_of(
default_var_type = in_var_type; inputs.begin(), inputs.end(), [block](const std::string &name) {
break; return block->Var(name)->GetType() == VarDesc::LOD_TENSOR;
} });
if (any_input_is_lod_tensor) {
default_var_type = VarDesc::LOD_TENSOR;
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type); block->Var(out_var_name)->SetType(default_var_type);
} }
...@@ -65,20 +68,18 @@ TEST(InferVarType, sum_op) { ...@@ -65,20 +68,18 @@ TEST(InferVarType, sum_op) {
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
prog.Block(0)->NewVar("test_a")->SetType(VarDesc_VarType_LOD_TENSOR); prog.Block(0)->NewVar("test_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test_b")->SetType(VarDesc_VarType_LOD_TENSOR); prog.Block(0)->NewVar("test_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test_c")->SetType(VarDesc_VarType_LOD_TENSOR); prog.Block(0)->NewVar("test_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test_out"); prog.Block(0)->NewVar("test_out");
op->InferVarType(prog.Block(0)); op->InferVarType(prog.Block(0));
ASSERT_EQ(VarDesc_VarType_LOD_TENSOR, ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType());
prog.Block(0)->Var("test_out")->GetType());
prog.Block(0)->Var("test_b")->SetType(VarDesc_VarType_SELECTED_ROWS); prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
op->InferVarType(prog.Block(0)); op->InferVarType(prog.Block(0));
ASSERT_EQ(VarDesc_VarType_SELECTED_ROWS, ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType());
prog.Block(0)->Var("test_out")->GetType());
} }
TEST(InferVarType, sum_op_without_infer_var_type) { TEST(InferVarType, sum_op_without_infer_var_type) {
...@@ -88,9 +89,9 @@ TEST(InferVarType, sum_op_without_infer_var_type) { ...@@ -88,9 +89,9 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"}); op->SetOutput("Out", {"test2_out"});
prog.Block(0)->NewVar("test2_a")->SetType(VarDesc_VarType_LOD_TENSOR); prog.Block(0)->NewVar("test2_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test2_b")->SetType(VarDesc_VarType_SELECTED_ROWS); prog.Block(0)->NewVar("test2_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test2_c")->SetType(VarDesc_VarType_LOD_TENSOR); prog.Block(0)->NewVar("test2_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test2_out"); prog.Block(0)->NewVar("test2_out");
op->InferVarType(prog.Block(0)); op->InferVarType(prog.Block(0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册