未验证 提交 45b93325 编写于 作者: F feng_shuai 提交者: GitHub

fix:gather op (#46779)

* fix:gather op

* add ut
上级 140f3b24
......@@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}
auto index_var_name = desc.Input("Index")[0];
auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype.
if (index_var_desc->GetDataType() !=
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "gather op Index input data type must be int32";
return false;
}
#if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
......
......@@ -42,6 +42,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
def generate_input2(index):
return np.array(index).astype(np.int32)
def generate_input4(index):
return np.array(index).astype(np.int64)
def generate_input3(axis):
return np.array([axis]).astype(np.int32)
......@@ -57,9 +60,11 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
"Index": ["index_data"],
"Axis": ["axis_data"]
}]:
for index_type_int32 in [True, False]:
self.shape = shape
self.axis = axis
self.input_num = len(input)
self.index_type_int32 = index_type_int32
dics = [{"overwrite": overwrite, "axis": axis}]
ops_config = [{
"op_type": "gather",
......@@ -80,7 +85,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
generate_input1, shape)),
"index_data":
TensorConfig(data_gen=partial(
generate_input2, index)),
generate_input2
if index_type_int32 ==
True else generate_input4, index)),
} if len(input) == 2 else {
"input_data":
TensorConfig(data_gen=partial(
......@@ -162,7 +169,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
if self.input_num == 3:
return 0, 5
else:
if dynamic_shape:
if dynamic_shape and self.index_type_int32 == True:
return 1, 3
else:
return 0, 4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册