未验证 提交 45b93325 编写于 作者: F feng_shuai 提交者: GitHub

fix:gather op (#46779)

* fix:gather op

* add ut
上级 140f3b24
...@@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -534,6 +534,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass."; "the pass.";
return false; return false;
} }
auto index_var_name = desc.Input("Index")[0];
auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype.
if (index_var_desc->GetDataType() !=
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "gather op Index input data type must be int32";
return false;
}
#if !IS_TRT_VERSION_GE(7000) #if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]); auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
......
...@@ -42,6 +42,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -42,6 +42,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
def generate_input2(index): def generate_input2(index):
return np.array(index).astype(np.int32) return np.array(index).astype(np.int32)
def generate_input4(index):
return np.array(index).astype(np.int64)
def generate_input3(axis): def generate_input3(axis):
return np.array([axis]).astype(np.int32) return np.array([axis]).astype(np.int32)
...@@ -57,44 +60,48 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -57,44 +60,48 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
"Index": ["index_data"], "Index": ["index_data"],
"Axis": ["axis_data"] "Axis": ["axis_data"]
}]: }]:
self.shape = shape for index_type_int32 in [True, False]:
self.axis = axis self.shape = shape
self.input_num = len(input) self.axis = axis
dics = [{"overwrite": overwrite, "axis": axis}] self.input_num = len(input)
ops_config = [{ self.index_type_int32 = index_type_int32
"op_type": "gather", dics = [{"overwrite": overwrite, "axis": axis}]
"op_inputs": input, ops_config = [{
"op_outputs": { "op_type": "gather",
"Out": ["output_data"] "op_inputs": input,
}, "op_outputs": {
"op_attrs": dics[0] "Out": ["output_data"]
}] },
ops = self.generate_op_config(ops_config) "op_attrs": dics[0]
}]
program_config = ProgramConfig( ops = self.generate_op_config(ops_config)
ops=ops,
weights={}, program_config = ProgramConfig(
inputs={ ops=ops,
"input_data": weights={},
TensorConfig(data_gen=partial( inputs={
generate_input1, shape)), "input_data":
"index_data": TensorConfig(data_gen=partial(
TensorConfig(data_gen=partial( generate_input1, shape)),
generate_input2, index)), "index_data":
} if len(input) == 2 else { TensorConfig(data_gen=partial(
"input_data": generate_input2
TensorConfig(data_gen=partial( if index_type_int32 ==
generate_input1, shape)), True else generate_input4, index)),
"index_data": } if len(input) == 2 else {
TensorConfig(data_gen=partial( "input_data":
generate_input2, index)), TensorConfig(data_gen=partial(
"axis_data": generate_input1, shape)),
TensorConfig(data_gen=partial( "index_data":
generate_input3, axis)), TensorConfig(data_gen=partial(
}, generate_input2, index)),
outputs=["output_data"]) "axis_data":
TensorConfig(data_gen=partial(
yield program_config generate_input3, axis)),
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs( def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float): self, program_config) -> (paddle_infer.Config, List[int], float):
...@@ -162,7 +169,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -162,7 +169,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
if self.input_num == 3: if self.input_num == 3:
return 0, 5 return 0, 5
else: else:
if dynamic_shape: if dynamic_shape and self.index_type_int32 == True:
return 1, 3 return 1, 3
else: else:
return 0, 4 return 0, 4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册