未验证 提交 45b93325 编写于 作者: F feng_shuai 提交者: GitHub

fix:gather op (#46779)

* fix: gather op — the TensorRT op teller now rejects gather when the Index input is not int32

* add unit test (ut) covering an int64 Index input
上级 140f3b24
...@@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass."; "the pass.";
return false; return false;
} }
auto index_var_name = desc.Input("Index")[0];
auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype.
if (index_var_desc->GetDataType() !=
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "gather op Index input data type must be int32";
return false;
}
#if !IS_TRT_VERSION_GE(7000) #if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]); auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
......
...@@ -42,6 +42,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -42,6 +42,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
def generate_input2(index): def generate_input2(index):
return np.array(index).astype(np.int32) return np.array(index).astype(np.int32)
def generate_input4(index):
return np.array(index).astype(np.int64)
def generate_input3(axis): def generate_input3(axis):
return np.array([axis]).astype(np.int32) return np.array([axis]).astype(np.int32)
...@@ -57,9 +60,11 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -57,9 +60,11 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
"Index": ["index_data"], "Index": ["index_data"],
"Axis": ["axis_data"] "Axis": ["axis_data"]
}]: }]:
for index_type_int32 in [True, False]:
self.shape = shape self.shape = shape
self.axis = axis self.axis = axis
self.input_num = len(input) self.input_num = len(input)
self.index_type_int32 = index_type_int32
dics = [{"overwrite": overwrite, "axis": axis}] dics = [{"overwrite": overwrite, "axis": axis}]
ops_config = [{ ops_config = [{
"op_type": "gather", "op_type": "gather",
...@@ -80,7 +85,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -80,7 +85,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
generate_input1, shape)), generate_input1, shape)),
"index_data": "index_data":
TensorConfig(data_gen=partial( TensorConfig(data_gen=partial(
generate_input2, index)), generate_input2
if index_type_int32 ==
True else generate_input4, index)),
} if len(input) == 2 else { } if len(input) == 2 else {
"input_data": "input_data":
TensorConfig(data_gen=partial( TensorConfig(data_gen=partial(
...@@ -162,7 +169,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -162,7 +169,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
if self.input_num == 3: if self.input_num == 3:
return 0, 5 return 0, 5
else: else:
if dynamic_shape: if dynamic_shape and self.index_type_int32 == True:
return 1, 3 return 1, 3
else: else:
return 0, 4 return 0, 4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册