未验证 提交 00f747f2 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Paddle Inference] add generic plugin for p_norm (#53278)

上级 f6f48780
......@@ -14,7 +14,7 @@
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
namespace paddle {
......@@ -322,20 +322,51 @@ nvinfer1::DimsExprs PNormInferMeta(
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc) {
const nvinfer1::DimsExprs x_dim = inputs[0];
std::vector<const nvinfer1::IDimensionExpr*> reduce_dims;
std::vector<const nvinfer1::IDimensionExpr*> keep_dims;
bool asvector = PADDLE_GET_CONST(bool, op_desc.GetAttr("asvector"));
bool keepdim = PADDLE_GET_CONST(bool, op_desc.GetAttr("keepdim"));
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
auto x_dim = inputs[0];
auto x_rank = x_dim.nbDims;
PADDLE_ENFORCE_GE(axis,
-x_rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim.d));
PADDLE_ENFORCE_LT(axis,
x_rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim.d));
// TODO(liuyuanle): support asvector = True
PADDLE_ENFORCE_EQ(
asvector,
false,
phi::errors::InvalidArgument(
"p_norm only support asvector=false, but received asvector: %d.",
asvector));
std::vector<const nvinfer1::IDimensionExpr*> reduce_dims;
if (asvector) {
reduce_dims.emplace_back(expr_builder.constant(1));
keep_dims.emplace_back(expr_builder.constant(1));
if (keepdim) {
for (int i = 1; i < x_dim.nbDims; ++i) {
keep_dims.emplace_back(expr_builder.constant(1));
reduce_dims.emplace_back(expr_builder.constant(1));
}
x_dim.nbDims = reduce_dims.size();
for (size_t i = 0; i < reduce_dims.size(); i++) {
x_dim.d[i] = reduce_dims[i];
}
}
} else {
......@@ -347,12 +378,11 @@ nvinfer1::DimsExprs PNormInferMeta(
reduce_dims.emplace_back(expr_builder.constant(1));
}
}
keep_dims[axis] = expr_builder.constant(1);
x_dim.d[axis] = expr_builder.constant(1);
nvinfer1::DimsExprs output;
if (keepdim) {
output.nbDims = keep_dims.size();
for (int i = 0; i < output.nbDims; i++) output.d[i] = keep_dims[i];
output = x_dim;
} else {
output.nbDims = reduce_dims.size();
for (int i = 0; i < output.nbDims; i++) output.d[i] = reduce_dims[i];
......@@ -396,6 +426,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(p_norm, PNormInferMeta);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -28,6 +28,7 @@ USE_TRT_DYNAMIC_INFER_META_FN(scatter_nd_add);
USE_TRT_DYNAMIC_INFER_META_FN(pad3d);
USE_TRT_DYNAMIC_INFER_META_FN(inverse);
USE_TRT_DYNAMIC_INFER_META_FN(grid_sampler);
USE_TRT_DYNAMIC_INFER_META_FN(p_norm);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -472,7 +472,7 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
cudaStream_t stream) TRT_NOEXCEPT {
platform::CUDAPlace place(platform::GetCurrentDeviceId());
// [TODO]now generic plugin do not support FP16 and INT8 precision
// [TODO]now generic plugin do not support INT8 precision
auto protoType2PhiType =
[&](int proto_type,
nvinfer1::DataType nv_dtype) -> std::pair<phi::DataType, int> {
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import Any, Dict, List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertCeluTest(TrtLayerAutoScanTest):
    """Auto-scan TensorRT conversion test for a single ``p_norm`` op.

    NOTE(review): the class name says "Celu", but every program generated
    below contains only a ``p_norm`` op -- this looks like a copy-paste
    leftover; confirm before renaming.
    """

    def sample_program_configs(self):
        """Yield one ``ProgramConfig`` per attribute combination.

        The sweep covers input ranks 2-4, ``keepdim`` in {False, True} and
        ``porder`` in {0, 1, 2, 3}; ``asvector`` is pinned to False and
        ``axis`` to -1. ``self.dims`` is refreshed before every yield so
        that ``sample_predictor_configs`` can read the rank of the program
        currently being produced.
        """

        def make_input(dims, attrs: List[Dict[str, Any]]):
            # All-ones float32 tensor; any rank other than 1-3 falls back to
            # the 4-D shape. ``attrs`` is accepted but not used.
            shape_by_rank = {
                1: [3],
                2: [3, 64],
                3: [3, 64, 64],
            }
            shape = shape_by_rank.get(dims, [1, 3, 64, 64])
            return np.ones(shape).astype(np.float32)

        for dims in [2, 3, 4]:
            # TODO(liuyuanle): support asvector = True
            for asvector in [False]:
                for keepdim in [False, True]:
                    for porder in [0, 1, 2, 3]:
                        for axis in [-1]:
                            self.dims = dims

                            op_attrs = {
                                "asvector": asvector,
                                "keepdim": keepdim,
                                "axis": axis,
                                "porder": porder,
                            }
                            attr_list = [op_attrs]

                            p_norm_desc = {
                                "op_type": "p_norm",
                                "op_inputs": {
                                    "X": ["input_data"],
                                },
                                "op_outputs": {"Out": ["output_data"]},
                                "op_attrs": attr_list[0],
                            }
                            ops = self.generate_op_config([p_norm_desc])

                            input_config = TensorConfig(
                                data_gen=partial(make_input, dims, attr_list)
                            )
                            yield ProgramConfig(
                                ops=ops,
                                weights={},
                                inputs={"input_data": input_config},
                                outputs=["output_data"],
                            )

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield (inference config, expected TRT node counts, tolerance).

        Two dynamic-shape configurations are produced for the given program:
        first Float32 with tolerance 1e-5, then Half with tolerance
        (1e-3, 1e-3).
        """

        def apply_dynamic_shape(attrs):
            # (min, max, opt) shapes for "input_data", keyed by input rank;
            # ranks other than 1-3 use the 4-D profile.
            profile_by_rank = {
                1: ([1], [128], [64]),
                2: ([1, 32], [4, 64], [3, 64]),
                3: ([1, 32, 32], [10, 64, 64], [3, 64, 64]),
            }
            default_profile = (
                [1, 3, 32, 32],
                [4, 3, 64, 64],
                [1, 3, 64, 64],
            )
            min_shape, max_shape, opt_shape = profile_by_rank.get(
                self.dims, default_profile
            )
            self.dynamic_shape.min_input_shape = {"input_data": min_shape}
            self.dynamic_shape.max_input_shape = {"input_data": max_shape}
            self.dynamic_shape.opt_input_shape = {"input_data": opt_shape}

        def expected_trt_nodes(attrs, dynamic_shape):
            # Same expectation for every attribute combination.
            return 1, 2

        attrs = [op.attrs for op in program_config.ops]

        # for dynamic_shape mode
        apply_dynamic_shape(attrs)

        precision_runs = (
            (paddle_infer.PrecisionType.Float32, 1e-5),
            (paddle_infer.PrecisionType.Half, (1e-3, 1e-3)),
        )
        for precision, tolerance in precision_runs:
            self.trt_param.precision = precision
            yield self.create_inference_config(), expected_trt_nodes(
                attrs, True
            ), tolerance

    def test(self):
        """Run the auto-scan harness over all sampled configurations."""
        self.run_test()
# Run the test case through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册