提交 29819ba7 编写于 作者: Y Yu Yang

Fix unittest

上级 acc54c7b
......@@ -35,14 +35,17 @@ class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDescBind &op_desc,
BlockDescBind *block) const override {
auto default_var_type = VarDesc::LOD_TENSOR;
for (auto &in_var_name : op_desc.Input("X")) {
auto in_var_type = block->Var(in_var_name)->GetType();
if (in_var_type != default_var_type) {
default_var_type = in_var_type;
break;
}
auto &inputs = op_desc.Input("X");
auto default_var_type = VarDesc::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == VarDesc::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = VarDesc::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
}
......@@ -65,20 +68,18 @@ TEST(InferVarType, sum_op) {
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.Block(0)->NewVar("test_a")->SetType(VarDesc_VarType_LOD_TENSOR);
prog.Block(0)->NewVar("test_b")->SetType(VarDesc_VarType_LOD_TENSOR);
prog.Block(0)->NewVar("test_c")->SetType(VarDesc_VarType_LOD_TENSOR);
prog.Block(0)->NewVar("test_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test_out");
op->InferVarType(prog.Block(0));
ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
prog.Block(0)->Var("test_out")->GetType());
ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType());
prog.Block(0)->Var("test_b")->SetType(VarDesc_VarType_SELECTED_ROWS);
prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
op->InferVarType(prog.Block(0));
ASSERT_EQ(VarDesc_VarType_SELECTED_ROWS,
prog.Block(0)->Var("test_out")->GetType());
ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType());
}
TEST(InferVarType, sum_op_without_infer_var_type) {
......@@ -88,9 +89,9 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
prog.Block(0)->NewVar("test2_a")->SetType(VarDesc_VarType_LOD_TENSOR);
prog.Block(0)->NewVar("test2_b")->SetType(VarDesc_VarType_SELECTED_ROWS);
prog.Block(0)->NewVar("test2_c")->SetType(VarDesc_VarType_LOD_TENSOR);
prog.Block(0)->NewVar("test2_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test2_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test2_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->NewVar("test2_out");
op->InferVarType(prog.Block(0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册