Unverified commit 822ea0f9, authored by Sanbu, committed by GitHub

Add not_equal trt converter (#49393)

Parent c5137b22
@@ -2396,6 +2396,7 @@ USE_TRT_CONVERTER(cast)
 USE_TRT_CONVERTER(recover_padding)
 USE_TRT_CONVERTER(remove_padding)
 USE_TRT_CONVERTER(equal);
+USE_TRT_CONVERTER(not_equal);
 USE_TRT_CONVERTER(top_k)
 USE_TRT_CONVERTER(top_k_v2)
 USE_TRT_CONVERTER(range)
...
@@ -35,7 +35,6 @@ class EqualOpConverter : public OpConverter {
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope,
                   bool test_mode) override {
-#if IS_TRT_VERSION_GE(8000)
     framework::OpDesc op_desc(op, nullptr);
     nvinfer1::ILayer* layer = nullptr;
@@ -79,11 +78,62 @@ class EqualOpConverter : public OpConverter {
     layer = TRT_ENGINE_ADD_LAYER(
         engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kEQUAL);
     RreplenishLayerAndOutput(layer, "equal", {output_name}, test_mode);
-#else
-    PADDLE_THROW(
-        platform::errors::Fatal("ElementWise Equal Operation is only supported "
-                                "on TRT 8 or higher version."));
-#endif
   }
 };
+
+class NotEqualOpConverter : public OpConverter {
+ public:
+  NotEqualOpConverter() {}
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+    nvinfer1::ILayer* layer = nullptr;
+
+    auto* X = engine_->GetITensor(op_desc.Input("X").front());
+    auto* Y = engine_->GetITensor(op_desc.Input("Y").front());
+    nvinfer1::Dims dims_x = X->getDimensions();
+    nvinfer1::Dims dims_y = Y->getDimensions();
+    int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
+    if (axis < 0) {
+      axis = std::abs(dims_x.nbDims - dims_y.nbDims);
+    }
+    auto output_name = op_desc.Output("Out")[0];
+
+    nvinfer1::IShuffleLayer* expand_layer = nullptr;
+    if (dims_x.nbDims > dims_y.nbDims) {
+      nvinfer1::Dims expand_shape;
+      expand_shape.nbDims = dims_x.nbDims;
+      for (int i = 0; i < expand_shape.nbDims; i++) {
+        expand_shape.d[i] = 1;
+      }
+      for (int i = 0; i < dims_y.nbDims; i++) {
+        expand_shape.d[i + axis] = dims_y.d[i];
+      }
+      expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *Y);
+      expand_layer->setReshapeDimensions(expand_shape);
+      Y = expand_layer->getOutput(0);
+    } else if (dims_x.nbDims < dims_y.nbDims) {
+      nvinfer1::Dims expand_shape;
+      expand_shape.nbDims = dims_y.nbDims;
+      for (int i = 0; i < expand_shape.nbDims; i++) {
+        expand_shape.d[i] = 1;
+      }
+      for (int i = 0; i < dims_x.nbDims; i++) {
+        expand_shape.d[i + axis] = dims_x.d[i];
+      }
+      expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
+      expand_layer->setReshapeDimensions(expand_shape);
+      X = expand_layer->getOutput(0);
+    }
+    layer = TRT_ENGINE_ADD_LAYER(
+        engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kEQUAL);
+    layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Unary, *layer->getOutput(0), nvinfer1::UnaryOperation::kNOT);
+    RreplenishLayerAndOutput(layer, "not_equal", {output_name}, test_mode);
+  }
+};
@@ -92,3 +142,4 @@ class EqualOpConverter : public OpConverter {
 }  // namespace paddle
 REGISTER_TRT_OP_CONVERTER(equal, EqualOpConverter);
+REGISTER_TRT_OP_CONVERTER(not_equal, NotEqualOpConverter);
...
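In short, `NotEqualOpConverter` is the existing `equal` converter with one extra step: after broadcasting the lower-rank input up to the higher rank (a reshape padded with leading 1s, offset by `axis`), it emits TensorRT's elementwise `kEQUAL` and negates the result with the `kNOT` unary op, since TensorRT exposes no dedicated not-equal elementwise operation. The following NumPy sketch mirrors the intended semantics; it is an illustration with hypothetical names (`not_equal_ref`), not Paddle or TensorRT code:

```python
import numpy as np

def not_equal_ref(x: np.ndarray, y: np.ndarray, axis: int = -1) -> np.ndarray:
    """Reference semantics of the converter: rank expansion, then NOT(EQUAL)."""
    if axis < 0:
        axis = abs(x.ndim - y.ndim)
    if x.ndim > y.ndim:
        # Expand y to x's rank: all 1s, with y's dims placed starting at `axis`.
        shape = [1] * x.ndim
        shape[axis:axis + y.ndim] = y.shape
        y = y.reshape(shape)
    elif x.ndim < y.ndim:
        shape = [1] * y.ndim
        shape[axis:axis + x.ndim] = x.shape
        x = x.reshape(shape)
    return np.logical_not(np.equal(x, y))  # kEQUAL followed by kNOT

# Example: compare a [2, 3] tensor against a [3] tensor along axis 1.
print(not_equal_ref(np.zeros((2, 3)), np.array([0.0, 1.0, 0.0]), axis=1))
```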
@@ -2341,7 +2341,7 @@ struct SimpleOpTypeSetTeller : public Teller {
     }
 #endif
-    if (op_type == "equal") {
+    if (op_type == "equal" || op_type == "not_equal") {
 #if !IS_TRT_VERSION_GE(8000)
       VLOG(3) << "compare is not supported when TensorRT < 8.0";
       return false;
@@ -2493,6 +2493,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "elementwise_max",
       "elementwise_floordiv",
       "equal",
+      "not_equal",
       "less_than",
       "greater_than",
       "logical_or",
@@ -2639,6 +2640,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "elementwise_max",
       "elementwise_floordiv",
       "equal",
+      "not_equal",
       "less_than",
       "greater_than",
       "logical_or",
...
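The op-teller change is mechanical: `not_equal` reuses the TensorRT >= 8.0 gate already in place for `equal` (both lower to `kEQUAL`, which the existing check ties to TRT 8+), and is appended to both allow-lists of supported op types. A hypothetical Python sketch of how the gate behaves, not the actual C++ teller code (`8000` corresponds to TensorRT 8.0 in the `IS_TRT_VERSION_GE` encoding):

```python
def trt_converter_supported(op_type: str, trt_version: int) -> bool:
    # Comparison ops need ElementWiseOperation::kEQUAL, available from TRT 8.0.
    if op_type in ("equal", "not_equal") and trt_version < 8000:
        return False
    return True

assert trt_converter_supported("not_equal", 8400)      # TRT 8.4: lowered to TRT
assert not trt_converter_supported("not_equal", 7200)  # TRT 7.2: falls back to Paddle
```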
@@ -39,45 +39,46 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
         def generate_input(shape):
             return np.random.random(shape).astype(np.float32)
 
-        for batch in [1, 2, 4]:
-            for shape in [[batch, 1], [batch, 1, 32], [batch, 1, 16, 32]]:
-                for axis in [-1 if len(shape) == 1 else 1]:
-                    self.dims = len(shape)
-                    dics = [{"axis": axis}, {"in_dtype": 0, "out_dtype": 5}]
-                    ops_config = [
-                        {
-                            "op_type": "equal",
-                            "op_inputs": {
-                                "X": ["input_data1"],
-                                "Y": ["input_data2"],
-                            },
-                            "op_outputs": {"Out": ["compare_output_data"]},
-                            "op_attrs": dics[0],
-                        },
-                        {
-                            "op_type": "cast",
-                            "op_inputs": {"X": ["compare_output_data"]},
-                            "op_outputs": {"Out": ["output_data"]},
-                            "op_attrs": dics[1],
-                        },
-                    ]
-                    ops = self.generate_op_config(ops_config)
-
-                    program_config = ProgramConfig(
-                        ops=ops,
-                        weights={},
-                        inputs={
-                            "input_data1": TensorConfig(
-                                data_gen=partial(generate_input, shape)
-                            ),
-                            "input_data2": TensorConfig(
-                                data_gen=partial(generate_input, shape)
-                            ),
-                        },
-                        outputs=["output_data"],
-                    )
-
-                    yield program_config
+        for op_type in ["equal", "not_equal"]:
+            for batch in [1, 2, 4]:
+                for shape in [[batch, 1], [batch, 1, 32], [batch, 1, 16, 32]]:
+                    for axis in [-1 if len(shape) == 1 else 1]:
+                        self.dims = len(shape)
+                        dics = [{"axis": axis}, {"in_dtype": 0, "out_dtype": 5}]
+                        ops_config = [
+                            {
+                                "op_type": op_type,
+                                "op_inputs": {
+                                    "X": ["input_data1"],
+                                    "Y": ["input_data2"],
+                                },
+                                "op_outputs": {"Out": ["compare_output_data"]},
+                                "op_attrs": dics[0],
+                            },
+                            {
+                                "op_type": "cast",
+                                "op_inputs": {"X": ["compare_output_data"]},
+                                "op_outputs": {"Out": ["output_data"]},
+                                "op_attrs": dics[1],
+                            },
+                        ]
+                        ops = self.generate_op_config(ops_config)
+
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={},
+                            inputs={
+                                "input_data1": TensorConfig(
+                                    data_gen=partial(generate_input, shape)
+                                ),
+                                "input_data2": TensorConfig(
+                                    data_gen=partial(generate_input, shape)
+                                ),
+                            },
+                            outputs=["output_data"],
+                        )
+
+                        yield program_config
 
     def sample_predictor_configs(
         self, program_config
...
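The test wraps the existing shape/axis grid in an outer loop over both op types, so `not_equal` exercises exactly the broadcast corner cases already covered for `equal`; the trailing `cast` op (attrs `in_dtype: 0`, `out_dtype: 5`, i.e. bool to float32 in Paddle's VarType encoding) makes the boolean output comparable as float. A quick eager-mode check of the semantics the converter must reproduce (illustration only; assumes a working PaddlePaddle install):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.random([2, 1, 32]).astype(np.float32))
y = paddle.to_tensor(np.random.random([2, 1, 32]).astype(np.float32))

# Boolean result cast to float32, mirroring the cast op appended in the test.
out = paddle.cast(paddle.not_equal(x, y), "float32")
print(float(out.mean()))  # ~1.0: independent uniform floats almost never collide
```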