diff --git a/paddle/framework/var_type_inference_test.cc b/paddle/framework/var_type_inference_test.cc index e3f4893f1a68d255c0b557dcaa334ef05ca4a4a9..97b8c647485c4520dea895f220077ec6a6eedd70 100644 --- a/paddle/framework/var_type_inference_test.cc +++ b/paddle/framework/var_type_inference_test.cc @@ -35,14 +35,17 @@ class SumOpVarTypeInference : public VarTypeInference { public: void operator()(const OpDescBind &op_desc, BlockDescBind *block) const override { - auto default_var_type = VarDesc::LOD_TENSOR; - for (auto &in_var_name : op_desc.Input("X")) { - auto in_var_type = block->Var(in_var_name)->GetType(); - if (in_var_type != default_var_type) { - default_var_type = in_var_type; - break; - } + auto &inputs = op_desc.Input("X"); + auto default_var_type = VarDesc::SELECTED_ROWS; + + bool any_input_is_lod_tensor = std::any_of( + inputs.begin(), inputs.end(), [block](const std::string &name) { + return block->Var(name)->GetType() == VarDesc::LOD_TENSOR; + }); + if (any_input_is_lod_tensor) { + default_var_type = VarDesc::LOD_TENSOR; } + auto out_var_name = op_desc.Output("Out").front(); block->Var(out_var_name)->SetType(default_var_type); } @@ -65,20 +68,18 @@ TEST(InferVarType, sum_op) { op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetOutput("Out", {"test_out"}); - prog.Block(0)->NewVar("test_a")->SetType(VarDesc_VarType_LOD_TENSOR); - prog.Block(0)->NewVar("test_b")->SetType(VarDesc_VarType_LOD_TENSOR); - prog.Block(0)->NewVar("test_c")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test_a")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test_b")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test_c")->SetType(VarDesc::SELECTED_ROWS); prog.Block(0)->NewVar("test_out"); op->InferVarType(prog.Block(0)); - ASSERT_EQ(VarDesc_VarType_LOD_TENSOR, - prog.Block(0)->Var("test_out")->GetType()); + ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType()); - prog.Block(0)->Var("test_b")->SetType(VarDesc_VarType_SELECTED_ROWS); + prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR); op->InferVarType(prog.Block(0)); - ASSERT_EQ(VarDesc_VarType_SELECTED_ROWS, - prog.Block(0)->Var("test_out")->GetType()); + ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType()); } TEST(InferVarType, sum_op_without_infer_var_type) { @@ -88,9 +89,9 @@ TEST(InferVarType, sum_op_without_infer_var_type) { op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetOutput("Out", {"test2_out"}); - prog.Block(0)->NewVar("test2_a")->SetType(VarDesc_VarType_LOD_TENSOR); - prog.Block(0)->NewVar("test2_b")->SetType(VarDesc_VarType_SELECTED_ROWS); - prog.Block(0)->NewVar("test2_c")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test2_a")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test2_b")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test2_c")->SetType(VarDesc::SELECTED_ROWS); prog.Block(0)->NewVar("test2_out"); op->InferVarType(prog.Block(0));