Unverified commit 5a44bf7e, authored by Wilber, committed by GitHub

Add trt pow converter. (#53462)

* Add trt pow converter.

* update to use AddConstantLayer

* add dims=0 ut
Parent a4997311
@@ -319,6 +319,37 @@ class ElementwiseTensorModOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorModOpConverter() { op_type_ = "mod"; }
};

// The difference between `pow` and `elementwise_pow` is explained in:
// https://github.com/PaddlePaddle/Paddle/blob/release/2.4/python/paddle/tensor/math.py#L420
class PowOpConverter : public OpConverter {
 public:
  PowOpConverter() {}
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(3) << "Convert a pow op to TensorRT IElementWiseLayer";
    framework::OpDesc op_desc(op, nullptr);
    auto* X = engine_->GetITensor(op_desc.Input("X").front());
    float factor = PADDLE_GET_CONST(float, op_desc.GetAttr("factor"));
    nvinfer1::Dims dims_x = X->getDimensions();
    auto output_name = op_desc.Output("Out")[0];

    // Build a constant with the same rank as X and every dimension equal to
    // 1, holding `factor`, so the elementwise layer can broadcast it over X.
    nvinfer1::Dims trt_dims_y;
    trt_dims_y.nbDims = dims_x.nbDims;
    for (int i = 0; i < trt_dims_y.nbDims; i++) {
      trt_dims_y.d[i] = 1;
    }
    std::vector<float> w_data{factor};
    auto* Y = AddConstantLayer(w_data.data(), trt_dims_y);

    auto* layer = TRT_ENGINE_ADD_LAYER(
        engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kPOW);
    RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
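What the new converter computes, as a minimal NumPy sketch (function and variable names here are illustrative, not Paddle API): `pow(X, factor)` becomes an elementwise `kPOW` between X and a constant that has the same rank as X with every dimension set to 1, so the scalar factor broadcasts across X.

import numpy as np

def pow_via_elementwise(x: np.ndarray, factor: float) -> np.ndarray:
    # Mirrors trt_dims_y: rank == x.ndim, every dim == 1, filled with factor.
    c = np.full((1,) * x.ndim, factor, dtype=x.dtype)
    # ElementWiseOperation::kPOW broadcasts the constant against X.
    return np.power(x, c)

x = np.random.rand(4, 32).astype(np.float32)
assert np.allclose(pow_via_elementwise(x, 2.0), x ** 2)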
@@ -369,3 +400,5 @@ REGISTER_TRT_OP_CONVERTER(logical_and, ElementwiseTensorLogicalAndOpConverter);
REGISTER_TRT_OP_CONVERTER(less_equal, ElementwiseTensorLessEqualOpConverter);
REGISTER_TRT_OP_CONVERTER(greater_equal,
                          ElementwiseTensorGreaterEqualOpConverter);
REGISTER_TRT_OP_CONVERTER(pow, PowOpConverter);
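For context on the math.py link in the converter comment: with a Python scalar exponent, `paddle.pow` maps to the `pow` op carrying a `factor` attribute (which this converter now handles), while a Tensor exponent maps to `elementwise_pow`. A short sketch of the API-level difference, assuming the Paddle 2.4-era behavior described at that link:

import paddle

x = paddle.rand([2, 32])
y_scalar = paddle.pow(x, 2.0)                    # traces to the `pow` op (factor attr)
y_tensor = paddle.pow(x, paddle.to_tensor(2.0))  # traces to the `elementwise_pow` op
assert paddle.allclose(y_scalar, y_tensor).item()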
@@ -1498,6 +1498,31 @@ struct SimpleOpTypeSetTeller : public Teller {
      }
    }
if (op_type == "pow") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
VLOG(3) << op_type
<< " op does not support input's dim is 1 or 0 in tensorrt "
"static shape mode.";
return false;
}
// the same as `elementwise_pow`.
if (x_var_desc->GetDataType() ==
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "These operations (pow) do not support int32 "
"datatype.";
return false;
}
}
if (op_type == "stack") {
if (!with_dynamic_shape) {
VLOG(3)
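The pow-specific checks above reduce to a small predicate; restated in Python for readability (illustrative only, not Paddle API):

def pow_is_trt_supported(x_rank: int, x_dtype: str, with_dynamic_shape: bool) -> bool:
    # 0-D and 1-D inputs are rejected in static-shape mode, where Paddle-TRT
    # reserves the leading axis as the implicit batch dimension.
    if not with_dynamic_shape and x_rank in (0, 1):
        return False
    # int32 is rejected, matching the elementwise_pow restriction.
    return x_dtype != "int32"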
@@ -2885,6 +2910,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"pow",
"elementwise_min",
"elementwise_max",
"elementwise_floordiv",
@@ -1092,5 +1092,127 @@ class TrtConvertElementwiseTestTwoInputSkipCase(TrtLayerAutoScanTest):
        self.run_test()


class TrtConvertPowOp(TrtLayerAutoScanTest):
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        return True

    def sample_program_configs(self):
        def generate_input(shape):
            if len(shape) == 0:
                return np.random.random([]).astype(np.float32)
            return np.random.random(shape).astype(np.float32)

        for batch in [1, 4]:
            for shape in [
                [],
                [32],
                [batch, 32],
                [batch, 32, 32],
                [batch, 32, 16, 32],
            ]:
                for factor in [1.0, 2.0, -1.0, 0.5, -2]:
                    self.dims = len(shape)
                    dics = [{"factor": factor}]
                    ops_config = [
                        {
                            "op_type": "pow",
                            "op_inputs": {
                                "X": ["input_data"],
                            },
                            "op_outputs": {"Out": ["output_data"]},
                            "op_attrs": dics[0],
                            "outputs_dtype": {"output_data": np.float32},
                        }
                    ]
                    ops = self.generate_op_config(ops_config)

                    program_config = ProgramConfig(
                        ops=ops,
                        weights={},
                        inputs={
                            "input_data": TensorConfig(
                                data_gen=partial(generate_input, shape)
                            ),
                        },
                        outputs=["output_data"],
                    )

                    yield program_config

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        def generate_dynamic_shape(attrs):
            if self.dims == 0:
                self.dynamic_shape.min_input_shape = {"input_data": []}
                self.dynamic_shape.max_input_shape = {"input_data": []}
                self.dynamic_shape.opt_input_shape = {"input_data": []}
            elif self.dims == 1:
                self.dynamic_shape.min_input_shape = {"input_data": [4]}
                self.dynamic_shape.max_input_shape = {"input_data": [32]}
                self.dynamic_shape.opt_input_shape = {"input_data": [16]}
            elif self.dims == 2:
                self.dynamic_shape.min_input_shape = {"input_data": [1, 32]}
                self.dynamic_shape.max_input_shape = {"input_data": [4, 32]}
                self.dynamic_shape.opt_input_shape = {"input_data": [2, 32]}
            elif self.dims == 3:
                self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 4]}
                self.dynamic_shape.max_input_shape = {"input_data": [4, 32, 32]}
                self.dynamic_shape.opt_input_shape = {"input_data": [2, 32, 32]}
            elif self.dims == 4:
                self.dynamic_shape.min_input_shape = {
                    "input_data": [1, 32, 4, 4]
                }
                self.dynamic_shape.max_input_shape = {
                    "input_data": [4, 32, 32, 32]
                }
                self.dynamic_shape.opt_input_shape = {
                    "input_data": [4, 32, 16, 32]
                }

        def clear_dynamic_shape():
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            if (self.dims == 1 or self.dims == 0) and not dynamic_shape:
                return 0, 3
            return 1, 2

        attrs = [
            program_config.ops[i].attrs for i in range(len(program_config.ops))
        ]

        # for static_shape
        clear_dynamic_shape()
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False
        ), (1e-5, 1e-5)
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False
        ), (1e-3, 1e-3)

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), (1e-5, 1e-5)
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), (1e-3, 1e-3)

    def add_skip_trt_case(self):
        pass

    def test(self):
        self.add_skip_trt_case()
        self.run_test()


if __name__ == "__main__":
    unittest.main()
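Outside the autoscan harness, the dynamic-shape profile the test builds corresponds roughly to the plain inference API below (model paths and shape ranges are illustrative, mirroring the 2-D case of generate_dynamic_shape):

import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # illustrative paths
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=4,
    min_subgraph_size=1,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
    {"input_data": [1, 32]},  # min
    {"input_data": [4, 32]},  # max
    {"input_data": [2, 32]},  # opt
)
predictor = paddle_infer.create_predictor(config)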