Unverified commit c15e53d6, authored by 周周周, committed by GitHub

commit (#54339)

Parent cb2476cf
...@@ -21,6 +21,122 @@ namespace paddle {
namespace inference {
namespace tensorrt {
class ExprWrapper {
public:
ExprWrapper() {}
ExprWrapper(const nvinfer1::IDimensionExpr* expr,
nvinfer1::IExprBuilder* expr_builder) {
this->expr = expr;
this->expr_builder = expr_builder;
}
ExprWrapper(int value, nvinfer1::IExprBuilder* expr_builder) {
this->expr = expr_builder->constant(value);
this->expr_builder = expr_builder;
}
const nvinfer1::IDimensionExpr* extract_expr() const { return expr; }
public:
friend ExprWrapper BinaryOp(const ExprWrapper& a,
const ExprWrapper& b,
nvinfer1::DimensionOperation op) {
ExprWrapper result;
if (a.expr_builder) {
result.expr_builder = a.expr_builder;
}
if (b.expr_builder) {
result.expr_builder = b.expr_builder;
}
assert(result.expr_builder);  // at least one operand must carry a builder
result.expr = result.expr_builder->operation(op, *a.expr, *b.expr);
return result;
}
friend ExprWrapper BinaryOp(const ExprWrapper& a,
int b_value,
nvinfer1::DimensionOperation op) {
assert(a.expr_builder);
ExprWrapper b;
b.expr_builder = a.expr_builder;
b.expr = b.expr_builder->constant(b_value);
return BinaryOp(a, b, op);
}
friend ExprWrapper operator+(const ExprWrapper& a, const ExprWrapper& b) {
return BinaryOp(a, b, nvinfer1::DimensionOperation::kSUM);
}
friend ExprWrapper operator+(const ExprWrapper& a, int b_value) {
return BinaryOp(a, b_value, nvinfer1::DimensionOperation::kSUM);
}
friend ExprWrapper operator+(int a_value, const ExprWrapper& b) {
return b + a_value;  // delegate to operator+(ExprWrapper, int); "a_value + b" would recurse forever
}
friend ExprWrapper operator-(const ExprWrapper& a, const ExprWrapper& b) {
return BinaryOp(a, b, nvinfer1::DimensionOperation::kSUB);
}
friend ExprWrapper operator-(const ExprWrapper& a, int b_value) {
return BinaryOp(a, b_value, nvinfer1::DimensionOperation::kSUB);
}
friend ExprWrapper operator*(const ExprWrapper& a, const ExprWrapper& b) {
return BinaryOp(a, b, nvinfer1::DimensionOperation::kPROD);
}
friend ExprWrapper operator*(const ExprWrapper& a, int b_value) {
return BinaryOp(a, b_value, nvinfer1::DimensionOperation::kPROD);
}
friend ExprWrapper operator*(int a_value, const ExprWrapper& b) {
return b * a_value;
}
friend ExprWrapper operator/(const ExprWrapper& a, const ExprWrapper& b) {
return BinaryOp(a, b, nvinfer1::DimensionOperation::kFLOOR_DIV);
}
friend ExprWrapper operator/(const ExprWrapper& a, int b_value) {
return BinaryOp(a, b_value, nvinfer1::DimensionOperation::kFLOOR_DIV);
}
friend ExprWrapper max(const ExprWrapper& a, const ExprWrapper& b) {
return BinaryOp(a, b, nvinfer1::DimensionOperation::kMAX);
}
friend ExprWrapper max(const ExprWrapper& a, int b_value) {
return BinaryOp(a, b_value, nvinfer1::DimensionOperation::kMAX);
}
public:
const nvinfer1::IDimensionExpr* expr = nullptr;
nvinfer1::IExprBuilder* expr_builder = nullptr;
};
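A minimal usage sketch of the wrapper, assuming expr_builder is the nvinfer1::IExprBuilder& handed to an infer-meta function and x_dims is one of its nvinfer1::DimsExprs inputs (the names are illustrative, not from the diff):

// Compute floor((h + 1) / 2) + 1 symbolically over a dynamic H dimension.
ExprWrapper h(x_dims.d[2], &expr_builder);  // wrap the dynamic dim
ExprWrapper out_h = (h + 1) / 2 + 1;        // kSUM and kFLOOR_DIV underneath
const nvinfer1::IDimensionExpr* raw = out_h.extract_expr();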
static std::vector<ExprWrapper> DimsExprs2VecExprWrapper(
const nvinfer1::DimsExprs& x_dims,
nvinfer1::IExprBuilder& expr_builder // NOLINT
) {
std::vector<ExprWrapper> x_dims_wrap;
for (int i = 0; i < x_dims.nbDims; i++) {
x_dims_wrap.push_back(ExprWrapper(x_dims.d[i], &expr_builder));
}
return x_dims_wrap;
}
static nvinfer1::DimsExprs VecExprWrapper2DimsExprs(
const std::vector<ExprWrapper>& output_dims_wrapper) {
nvinfer1::DimsExprs output_dims;
output_dims.nbDims = output_dims_wrapper.size();
for (int i = 0; i < output_dims.nbDims; i++) {
output_dims.d[i] = output_dims_wrapper[i].extract_expr();
}
return output_dims;
}
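Together these two helpers bracket a typical infer-meta body: unwrap the inputs, do arithmetic on ExprWrapper values, re-wrap the result. A hypothetical shape-preserving sketch (IdentityInferMeta is illustrative only, not a function in this commit):

nvinfer1::DimsExprs IdentityInferMeta(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder,  // NOLINT
    const framework::OpDesc& op_desc) {
  auto dims_wrap = DimsExprs2VecExprWrapper(inputs[0], expr_builder);
  // ... transform dims_wrap with ExprWrapper arithmetic here ...
  return VecExprWrapper2DimsExprs(dims_wrap);
}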
nvinfer1::DimsExprs GatherNdInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
...@@ -417,6 +533,148 @@ nvinfer1::DimsExprs GridSamplerInferMeta(
return output;
}
inline void UpdatePaddingAndDilation(
std::vector<ExprWrapper>* paddings_wrap,
std::vector<int>* dilation,
const std::string& padding_algorithm,
const std::vector<ExprWrapper>& hw_dims,
const std::vector<int>& strides,
const std::vector<ExprWrapper>& k_dims,
nvinfer1::IExprBuilder& expr_builder  // NOLINT
) {
if (paddings_wrap->size() == hw_dims.size()) {
// one value per spatial dim: duplicate it into (pad_begin, pad_end)
for (size_t i = 0; i < hw_dims.size(); ++i) {
auto copy_pad = *(paddings_wrap->begin() + 2 * i);
paddings_wrap->insert(paddings_wrap->begin() + 2 * i + 1, copy_pad);
}
} else {
CHECK_EQ(paddings_wrap->size() == 2 * hw_dims.size(), true);
}
// when padding_algorithm is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < hw_dims.size(); ++i) {
auto out_size = (hw_dims[i] + strides[i] - 1) / strides[i];
auto pad_sum =
max((out_size - 1) * strides[i] + k_dims[i] - hw_dims[i], 0);
auto pad_0 = pad_sum / 2;
auto pad_1 = pad_sum - pad_0;
*(paddings_wrap->begin() + i * 2) = pad_0;
*(paddings_wrap->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilation->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto it = paddings_wrap->begin(); it != paddings_wrap->end(); it++) {
*it = ExprWrapper(0, &expr_builder);
}
}
}
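A worked instance of the SAME branch above, with assumed values ih = 7, stride = 2, kh = 3: out_size = (7 + 2 - 1) / 2 = 4, pad_sum = max((4 - 1) * 2 + 3 - 7, 0) = 2, hence pad_0 = 1, pad_1 = 1, and the dilation for that axis is reset to 1.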
// All examples below use h (height); the same formulas apply to w (width).
inline ExprWrapper ConvOutputSize(ExprWrapper ih,
ExprWrapper kh,
int dilation_h,
ExprWrapper pad_h0,
ExprWrapper pad_h1,
int stride_h) {
ExprWrapper oh =
(ih + pad_h0 + pad_h1 - dilation_h * (kh - 1) - 1) / stride_h + 1;
return oh;
}
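Plugging assumed numbers into ConvOutputSize as a sanity check: ih = 224, kh = 3, dilation_h = 1, pad_h0 = pad_h1 = 1, stride_h = 2 gives oh = (224 + 1 + 1 - 1 * (3 - 1) - 1) / 2 + 1 = 223 / 2 + 1 = 112, the expected stride-2 halving.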
nvinfer1::DimsExprs Conv2dFusionInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc) {
// we may update dilations.
std::vector<int> dilations =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
const std::vector<int> strides =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
std::vector<int> paddings =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm"))
padding_algorithm =
PADDLE_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "VALID") {
for (size_t i = 0; i < paddings.size(); i++) {
paddings[i] = 0;
}
}
// TODO(zhangjun): nhwc support
bool channel_last = false;
// conv_fusion: input, filter, bias
const nvinfer1::DimsExprs input_dims = inputs[0];
const nvinfer1::DimsExprs filter_dims = inputs[1];
auto input_dims_wrap = DimsExprs2VecExprWrapper(input_dims, expr_builder);
auto filter_dims_wrap = DimsExprs2VecExprWrapper(filter_dims, expr_builder);
std::vector<ExprWrapper> hw_dims_wrap; // d, h, w
if (channel_last) {
for (int i = 1; i < input_dims.nbDims - 1; ++i) {
hw_dims_wrap.emplace_back(input_dims_wrap[i]);
}
} else {
for (int i = 2; i < input_dims.nbDims; ++i) {
hw_dims_wrap.emplace_back(input_dims_wrap[i]);
}
}
std::vector<ExprWrapper> filter_hw_dims_wrap; // filter_h, filter_w
if (channel_last) {
for (int i = 1; i < filter_dims.nbDims - 1; ++i) {
filter_hw_dims_wrap.emplace_back(filter_dims_wrap[i]);
}
} else {
for (int i = 2; i < filter_dims.nbDims; ++i) {
filter_hw_dims_wrap.emplace_back(filter_dims_wrap[i]);
}
}
std::vector<ExprWrapper> paddings_wrap;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings_wrap.emplace_back(ExprWrapper(paddings[i], &expr_builder));
}
UpdatePaddingAndDilation(&paddings_wrap,
&dilations,
padding_algorithm,
hw_dims_wrap,
strides,
filter_hw_dims_wrap,
expr_builder);
std::vector<ExprWrapper> output_dims_wrap(input_dims.nbDims);
int out_idx = 0;
output_dims_wrap[out_idx++] = input_dims_wrap[0];
if (!channel_last) {
output_dims_wrap[out_idx++] = filter_dims_wrap[0];
}
for (size_t i = 0; i < hw_dims_wrap.size(); ++i) {
output_dims_wrap[out_idx++] = ConvOutputSize(hw_dims_wrap[i],
filter_hw_dims_wrap[i],
dilations[i],
paddings_wrap[2 * i],
paddings_wrap[2 * i + 1],
strides[i]);
}
if (channel_last) {
output_dims_wrap[out_idx++] = filter_dims_wrap[0];
}
return VecExprWrapper2DimsExprs(output_dims_wrap);
}
nvinfer1::DimsExprs LookupTableV2InferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
...@@ -435,6 +693,85 @@ nvinfer1::DimsExprs LookupTableV2InferMeta(
return output;
}
nvinfer1::DimsExprs Conv2dTransposeInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc) {
auto x_dims = inputs[0];
auto filter_dims = inputs[1];
std::vector<ExprWrapper> x_dims_wrap =
DimsExprs2VecExprWrapper(x_dims, expr_builder);
std::vector<ExprWrapper> filter_dims_wrap =
DimsExprs2VecExprWrapper(filter_dims, expr_builder);
const std::vector<int> dilations =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
const std::vector<int> strides =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
std::vector<int> paddings =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::vector<int> output_size =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("output_size"));
std::vector<int> output_padding =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("output_padding"));
auto data_format =
PADDLE_GET_CONST(std::string, op_desc.GetAttr("data_format"));
int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm")) {
padding_algorithm =
PADDLE_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
}
CHECK_EQ(padding_algorithm == "EXPLICIT", true);
CHECK_EQ(data_format == "NCHW", true);
CHECK_EQ(output_size.size() == 0, true);
CHECK_EQ(paddings.size() == 2, true);
CHECK_EQ(x_dims.nbDims == 4, true);
CHECK_EQ(x_dims.nbDims == filter_dims.nbDims, true);
CHECK_EQ(output_padding.size() == 0, true);
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
CHECK_EQ(strides[i] > 0, true);
}
int in_sub_stride_size = x_dims.nbDims - stride_size;
CHECK_EQ(in_sub_stride_size == 2, true);
if (output_size.size()) {
CHECK_EQ(output_size.size() == strides.size(), true);
}
if (output_padding.size()) {
CHECK_EQ(strides.size() == output_padding.size(), true);
}
std::vector<ExprWrapper> output_dims_wrap(x_dims.nbDims);
output_dims_wrap[0] = x_dims_wrap[0];
output_dims_wrap[1] = filter_dims_wrap[1] * groups;
auto ih = x_dims_wrap[2];
auto iw = x_dims_wrap[3];
auto kh = filter_dims_wrap[2];
auto kw = filter_dims_wrap[3];
int pad_h0 = paddings[0];
int pad_h1 = paddings[0];
int pad_w0 = paddings[1];
int pad_w1 = paddings[1];
output_dims_wrap[2] =
(ih - 1) * strides[0] - pad_h0 - pad_h1 + (kh - 1) * dilations[0] + 1;
output_dims_wrap[3] =
(iw - 1) * strides[1] - pad_w0 - pad_w1 + (kw - 1) * dilations[1] + 1;
return VecExprWrapper2DimsExprs(output_dims_wrap);
}
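Checking the transpose formulas with assumed values ih = iw = 56, strides = {2, 2}, paddings = {1, 1}, kh = kw = 4, dilations = {1, 1}: oh = (56 - 1) * 2 - 1 - 1 + (4 - 1) * 1 + 1 = 112, i.e. a stride-2 upsampling of a 56x56 feature map.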
PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
...@@ -444,6 +781,9 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(conv2d_fusion, Conv2dFusionInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(conv2d, Conv2dFusionInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(conv2d_transpose, Conv2dTransposeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(p_norm, PNormInferMeta);
}  // namespace tensorrt
......
...@@ -28,6 +28,9 @@ USE_TRT_DYNAMIC_INFER_META_FN(scatter_nd_add);
USE_TRT_DYNAMIC_INFER_META_FN(pad3d);
USE_TRT_DYNAMIC_INFER_META_FN(inverse);
USE_TRT_DYNAMIC_INFER_META_FN(grid_sampler);
USE_TRT_DYNAMIC_INFER_META_FN(conv2d_fusion);
USE_TRT_DYNAMIC_INFER_META_FN(conv2d);
USE_TRT_DYNAMIC_INFER_META_FN(conv2d_transpose);
USE_TRT_DYNAMIC_INFER_META_FN(p_norm);
}  // namespace tensorrt
}  // namespace inference
......