Unverified Commit a0566010 Authored by: W weishengying Committed by: GitHub

Add symbolic shape deduction function for general Plugin mechanism (#46179)

Parent 707d838b
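This change teaches the generic TensorRT plugin to deduce output shapes symbolically: each covered op registers an InferMeta-style function that builds its output DimsExprs from the input DimsExprs through TensorRT's IExprBuilder, so ops handled by the generic plugin can now be converted under dynamic shape (the test changes below flip from expecting no TRT node under dynamic shape to expecting one). As a rough orientation, a dynamic-shape plugin's getOutputDimensions can reduce to a registry lookup. The sketch below uses a hypothetical DynamicMetaFn alias and registry helper, not Paddle's actual dispatch code (which additionally threads the op's OpDesc through for attribute access):

#include <NvInfer.h>

#include <functional>
#include <map>
#include <string>

// Signature mirroring the InferMeta functions in this diff (OpDesc omitted).
using DynamicMetaFn = std::function<nvinfer1::DimsExprs(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder)>;

// Hypothetical registry keyed by op type; the PD_REGISTER_* macro below
// presumably populates something similar at static-initialization time.
std::map<std::string, DynamicMetaFn>& MetaFnRegistry() {
  static std::map<std::string, DynamicMetaFn> registry;
  return registry;
}

// Inside the plugin, getOutputDimensions becomes a lookup plus a call.
nvinfer1::DimsExprs DeduceOutputDims(const std::string& op_type,
                                     int output_index,
                                     const nvinfer1::DimsExprs* inputs,
                                     int nb_inputs,
                                     nvinfer1::IExprBuilder& builder) {
  return MetaFnRegistry().at(op_type)(output_index, inputs, nb_inputs, builder);
}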
......@@ -54,7 +54,61 @@ nvinfer1::DimsExprs GatherNdInferMeta(
  }
  return output;
}
nvinfer1::DimsExprs YoloBoxInferMeta(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder,  // NOLINT
    const framework::OpDesc& op_desc) {
  PADDLE_ENFORCE_EQ(
      nb_inputs,
      2,
      phi::errors::InvalidArgument(
          "The number of inputs of yolo_box should be 2, but received %d.",
          nb_inputs));
  const nvinfer1::DimsExprs dim_x = inputs[0];
  auto anchors = PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("anchors"));
  int anchor_num = anchors.size() / 2;
  // box_num = dim_x[2] * dim_x[3] * anchor_num
  const nvinfer1::IDimensionExpr* box_num = expr_builder.operation(
      nvinfer1::DimensionOperation::kPROD,
      *expr_builder.operation(
          nvinfer1::DimensionOperation::kPROD, *dim_x.d[2], *dim_x.d[3]),
      *expr_builder.constant(anchor_num));
  nvinfer1::DimsExprs output;
  output.nbDims = 3;
  if (output_index == 0) {
    // Output 0: decoded boxes, shape [batch, box_num, 4].
    output.d[0] = dim_x.d[0];
    output.d[1] = box_num;
    output.d[2] = expr_builder.constant(4);
  } else {
    // Output 1: per-box class scores, shape [batch, box_num, class_num].
    auto class_num = PADDLE_GET_CONST(int, op_desc.GetAttr("class_num"));
    output.d[0] = dim_x.d[0];
    output.d[1] = box_num;
    output.d[2] = expr_builder.constant(class_num);
  }
  return output;
}
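To make the symbolic arithmetic above concrete, here is the same box_num computation with plain integers (the anchor values and class_num are made-up example numbers, not op defaults):

#include <cassert>
#include <vector>

int main() {
  // 3 anchor (w, h) pairs -> anchor_num = 6 / 2 = 3.
  std::vector<int> anchors = {10, 13, 16, 30, 33, 23};
  int anchor_num = static_cast<int>(anchors.size()) / 2;
  // A 13x13 feature map: box_num = 13 * 13 * 3 = 507.
  int h = 13, w = 13;
  int box_num = h * w * anchor_num;
  assert(box_num == 507);
  // Output 0 (boxes):  [batch, 507, 4]
  // Output 1 (scores): [batch, 507, class_num], e.g. [batch, 507, 80]
  return 0;
}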
nvinfer1::DimsExprs InstanceNormInferMeta(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder,  // NOLINT
    const framework::OpDesc& op_desc) {
  // instance_norm normalizes within each (N, C) slice, so the output shape
  // is identical to the input shape.
  nvinfer1::DimsExprs x_dims = inputs[0];
  return x_dims;
}
PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -21,6 +21,8 @@ namespace inference {
namespace tensorrt {
USE_TRT_DYNAMIC_INFER_META_FN(gather_nd);
USE_TRT_DYNAMIC_INFER_META_FN(yolo_box);
USE_TRT_DYNAMIC_INFER_META_FN(instance_norm);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
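These USE_TRT_DYNAMIC_INFER_META_FN declarations exist so the translation unit carrying the PD_REGISTER_DYNAMIC_INFER_META_FN side effects is actually linked into the final binary. The macro bodies are not shown in this diff; a minimal sketch of the touch-a-symbol pattern such register/use pairs commonly expand to (hypothetical names):

// Expanded in the registering translation unit: defines a symbol whose
// presence ties the registration object file to anyone who references it.
#define SKETCH_REGISTER_FN(op) \
  int TouchMetaFn_##op() { return 0; }

// Expanded in the consuming translation unit: referencing the extern symbol
// forces the linker to keep the registering object file (and with it the
// static initializer that filled the registry).
#define SKETCH_USE_FN(op)        \
  extern int TouchMetaFn_##op(); \
  static int use_meta_fn_##op = TouchMetaFn_##op()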
......@@ -216,6 +216,7 @@ void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
      }
    }
  }
  // Every attribute named in the kernel signature must have been emplaced
  // above; a size mismatch means an attribute was silently skipped.
  CHECK_EQ(attr_names.size(), kernel_context->AttrsSize());
}
GenericPlugin::GenericPlugin(
......@@ -333,12 +334,16 @@ int GenericPlugin::initialize() TRT_NOEXCEPT {
  platform::CUDAPlace place(platform::GetCurrentDeviceId());
  auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(place));
  phi_kernel_context_ = new phi::KernelContext(dev_ctx);
  dense_tensor_inputs_ = new std::vector<phi::DenseTensor>(getNbInputs());
  dense_tensor_outputs_ = new std::vector<phi::DenseTensor>(getNbOutputs());
  if (!phi_kernel_context_) {
    phi_kernel_context_ = new phi::KernelContext(dev_ctx);
    BuildPhiKernelContextAttr(
        op_desc_, phi_kernel_context_, phi_kernel_signature, phi_kernel);
  }
  if (!dense_tensor_inputs_)
    dense_tensor_inputs_ = new std::vector<phi::DenseTensor>(getNbInputs());
  if (!dense_tensor_outputs_)
    dense_tensor_outputs_ = new std::vector<phi::DenseTensor>(getNbOutputs());
  BuildPhiKernelContextAttr(
      op_desc_, phi_kernel_context_, phi_kernel_signature, phi_kernel);
  return 0;
}
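The null checks above make initialize() idempotent: TensorRT may call initialize() more than once over a plugin's lifetime (for example across clones and repeated engine builds), and the previous unconditional news leaked the earlier allocations and rebuilt the attribute context on every call. A stripped-down sketch of the guard pattern, with simplified types:

#include <vector>

struct LazyBuffers {
  std::vector<float>* inputs_{nullptr};   // default to nullptr, mirroring the
  std::vector<float>* outputs_{nullptr};  // member initializers added below

  void initialize(int nb_inputs, int nb_outputs) {
    // Allocate only on the first call; later calls are no-ops.
    if (!inputs_) inputs_ = new std::vector<float>(nb_inputs);
    if (!outputs_) outputs_ = new std::vector<float>(nb_outputs);
  }

  void terminate() {
    delete inputs_;   // delete on nullptr is a no-op, so this pairs safely
    delete outputs_;  // with any number of initialize() calls
    inputs_ = nullptr;
    outputs_ = nullptr;
  }
};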
......@@ -387,26 +392,28 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
  platform::CUDAPlace place(platform::GetCurrentDeviceId());
  // TODO: the generic plugin does not support FP16 or INT8 precision yet.
  auto protoType2PhiType = [](int proto_type) -> phi::DataType {
  auto protoType2PhiType = [](int proto_type) -> std::pair<phi::DataType, int> {
    if (proto_type ==
        static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP32))
      return phi::DataType::FLOAT32;
      return {phi::DataType::FLOAT32, sizeof(float)};
    else if (proto_type ==
                 static_cast<int>(
                     framework::proto::VarType_Type::VarType_Type_INT64) ||
             proto_type ==
                 static_cast<int>(
                     framework::proto::VarType_Type::VarType_Type_INT32))
      return phi::DataType::INT32;
      return {phi::DataType::INT32, sizeof(int32_t)};
    else if (proto_type ==
             static_cast<int>(
                 framework::proto::VarType_Type::VarType_Type_BOOL))
      return phi::DataType::BOOL;
      return {phi::DataType::BOOL, sizeof(bool)};
    else
      CHECK(false) << "precision is not supported";
  };
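Returning the element width together with the phi::DataType fixes the byte counts handed to phi::Allocation below: the old code hard-coded sizeof(int32_t) for inputs and sizeof(float) for outputs regardless of the tensor's actual dtype. A self-contained sketch of why the pairing matters, using a simplified enum rather than phi's types:

#include <cstddef>
#include <utility>

enum class DType { kFloat32, kInt32, kBool };

// The byte count of an allocation must use the element width of the real
// dtype; a hard-coded 4-byte width would describe a bool tensor as 4x its
// true size.
size_t AllocationBytes(int numel, std::pair<DType, int> type_and_size) {
  return static_cast<size_t>(numel) * type_and_size.second;
}

// e.g. AllocationBytes(6, {DType::kBool, 1}) == 6, not 24.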
  // Bind the inputs for this run; clear tensors left over from the previous
  // enqueue so the context sizes match the CHECK_EQ checks below.
  phi_kernel_context_->ClearInputOutput();
  for (int i = 0; i < getNbInputs(); i++) {
    auto const& input_dims = input_desc[i].dims;
......@@ -417,11 +424,12 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
    int input_numel = 1;
    for (int k = 0; k < input_shape.size(); k++) input_numel *= input_shape[k];
    phi::DenseTensorMeta input_meta(protoType2PhiType(inputs_data_type_[i]),
    auto data_type_and_size = protoType2PhiType(inputs_data_type_[i]);
    phi::DenseTensorMeta input_meta(data_type_and_size.first,
                                    phi::make_ddim(input_shape));
    std::shared_ptr<phi::Allocation> input_alloc(
        new phi::Allocation((void*)(inputs[i]),  // NOLINT
                            input_numel * sizeof(int32_t),
                            input_numel * data_type_and_size.second,
                            place));
    (*dense_tensor_inputs_)[i] =
        std::move(phi::DenseTensor(input_alloc, input_meta));
......@@ -440,11 +448,12 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
    for (int k = 0; k < output_shape.size(); k++)
      output_numel *= output_shape[k];
    phi::DenseTensorMeta output_meta(protoType2PhiType(outputs_data_type_[i]),
    auto data_type_and_size = protoType2PhiType(outputs_data_type_[i]);
    phi::DenseTensorMeta output_meta(data_type_and_size.first,
                                     phi::make_ddim(output_shape));
    std::shared_ptr<phi::Allocation> output_alloc(
        new phi::Allocation(reinterpret_cast<void*>(outputs[i]),
                            output_numel * sizeof(float),
                            output_numel * data_type_and_size.second,
                            place));
    phi::DenseTensor output_densetensor(output_alloc, output_meta);
    (*dense_tensor_outputs_)[i] = std::move(output_densetensor);
......@@ -452,6 +461,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
    phi_kernel_context_->EmplaceBackOutput(&((*dense_tensor_outputs_)[i]));
  }

  // Sanity check: every TensorRT binding must have been emplaced above.
  CHECK_EQ(phi_kernel_context_->InputsSize(), getNbInputs());
  CHECK_EQ(phi_kernel_context_->OutputsSize(), getNbOutputs());

  (*phi_kernel_)(phi_kernel_context_);

  return cudaGetLastError() != cudaSuccess;
......
......@@ -128,10 +128,11 @@ class GenericPlugin : public DynamicPluginTensorRT {
  framework::OpDesc op_desc_;

 private:
  phi::KernelContext* phi_kernel_context_;
  const phi::Kernel* phi_kernel_;
  std::vector<phi::DenseTensor>* dense_tensor_inputs_;
  std::vector<phi::DenseTensor>* dense_tensor_outputs_;
  const phi::Kernel* phi_kernel_{nullptr};
  phi::KernelContext* phi_kernel_context_{nullptr};
  std::vector<phi::DenseTensor>* dense_tensor_inputs_{nullptr};
  std::vector<phi::DenseTensor>* dense_tensor_outputs_{nullptr};

 private:
  InputOutPutVarInfo in_out_info_;
......
......@@ -144,6 +144,13 @@ class KernelContext {
  size_t OutputsSize() const { return outputs_.size(); }
  size_t AttrsSize() const { return attrs_.size(); }

  void ClearInputOutput() {
    inputs_.clear();
    input_range_.clear();
    outputs_.clear();
    output_range_.clear();
  }

 private:
  DeviceContext* dev_ctx_;
......
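enqueue() runs once per inference while the KernelContext lives as long as the plugin, so the new ClearInputOutput() lets the same context be re-bound on every call; without it, the EmplaceBack* calls would keep appending and the CHECK_EQ size checks in enqueue() would fail on the second run. A minimal sketch of the reuse pattern, with simplified types:

#include <cassert>
#include <vector>

struct MiniKernelContext {
  std::vector<const void*> inputs_;
  std::vector<void*> outputs_;
  void EmplaceBackInput(const void* t) { inputs_.push_back(t); }
  void EmplaceBackOutput(void* t) { outputs_.push_back(t); }
  void ClearInputOutput() { inputs_.clear(); outputs_.clear(); }
};

int main() {
  MiniKernelContext ctx;
  float x = 0.f, y = 0.f;
  for (int run = 0; run < 2; ++run) {
    ctx.ClearInputOutput();  // re-bind this run's tensors from scratch
    ctx.EmplaceBackInput(&x);
    ctx.EmplaceBackOutput(&y);
    assert(ctx.inputs_.size() == 1 && ctx.outputs_.size() == 1);
  }
  return 0;
}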
......@@ -20,6 +20,7 @@ import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import os
class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
......@@ -113,7 +114,9 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            if dynamic_shape or self.in_dim != 4:
            if dynamic_shape:
                return 1, 2
            if self.in_dim != 4:
                return 0, 3
            return 1, 2
......@@ -139,7 +142,30 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True), (1e-3, 1e-3)

    def add_skip_trt_case(self):

        def teller1(program_config, predictor_config):
            if len(
                    self.dynamic_shape.min_input_shape
            ) != 0 and self.trt_param.precision == paddle_infer.PrecisionType.Half:
                return True
            return False

        self.add_skip_case(
            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
            "The output differs between GPU and TRT in dynamic-shape FP16 mode.")

        def teller2(program_config, predictor_config):
            if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
                return True
            return False

        self.add_skip_case(
            teller2, SkipReasons.TRT_NOT_SUPPORT,
            "The output differs between GPU and TRT on Windows.")

    def test(self):
        self.add_skip_trt_case()
        self.run_test()
......
......@@ -19,6 +19,7 @@ import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import os
class TrtConvertYoloBoxTest(TrtLayerAutoScanTest):
......@@ -139,10 +140,7 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest):
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            if dynamic_shape == True:
                return 0, 5
            else:
                return 1, 4
            return 1, 4

        attrs = [
            program_config.ops[i].attrs for i in range(len(program_config.ops))
......@@ -166,7 +164,26 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest):
            attrs, True), 1e-3

    def add_skip_trt_case(self):
        pass

        def teller1(program_config, predictor_config):
            if len(
                    self.dynamic_shape.min_input_shape
            ) != 0 and self.trt_param.precision == paddle_infer.PrecisionType.Half:
                return True
            return False

        self.add_skip_case(
            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
            "The output differs between GPU and TRT in dynamic-shape FP16 mode.")

        def teller2(program_config, predictor_config):
            if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
                return True
            return False

        self.add_skip_case(
            teller2, SkipReasons.TRT_NOT_SUPPORT,
            "The output differs between GPU and TRT on Windows.")

    def test(self):
        self.add_skip_trt_case()
......