未验证 提交 2c12abd7 编写于 作者: W Wilber 提交者: GitHub
上级 c50f5fa4
......@@ -542,16 +542,6 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}
auto index_var_name = desc.Input("Index")[0];
auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype.
if (index_var_desc->GetDataType() !=
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "gather op Index input data type must be int32";
return false;
}
#if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
......
......@@ -182,7 +182,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
if self.input_num == 3:
return 0, 5
else:
if dynamic_shape and self.index_type_int32:
if dynamic_shape:
return 1, 3
else:
return 0, 4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册