未验证 提交 8f156fd7 编写于 作者: G gaoziyuan 提交者: GitHub

[Hackathon NO.74] 为 Paddle-TRT 添加 grid_sampler 算子 (#50934)

上级 51331098
......@@ -2539,6 +2539,9 @@ USE_TRT_CONVERTER(preln_groupnorm_act)
USE_TRT_CONVERTER(flash_multihead_matmul)
USE_TRT_CONVERTER(cross_multihead_matmul)
#endif
#if IS_TRT_VERSION_GE(8510)
USE_TRT_CONVERTER(grid_sampler)
#endif
#if IS_TRT_VERSION_GE(8200)
USE_TRT_CONVERTER(set_value)
#endif
......
......@@ -27,6 +27,7 @@ list(
multihead_matmul_roformer_op.cc
flash_multihead_matmul_op.cc
cross_multihead_matmul_op.cc
grid_sampler_op.cc
shuffle_channel_op.cc
fill_any_like_op.cc
where_op.cc
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* GridSampler Op
*/
class GridSamplerOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(8510)
VLOG(3) << "convert a fluid grid_sampler op to tensorrt GridSample layer";
framework::OpDesc op_desc(op, nullptr);
std::string input_x_name = op_desc.Input("X").front();
std::string input_grid_name = op_desc.Input("Grid").front();
std::string output_name = op_desc.Output("Output").front();
auto* input_x_tensor = engine_->GetITensor(input_x_name);
auto* input_grid_tensor = engine_->GetITensor(input_grid_name);
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, GridSample, *input_x_tensor, *input_grid_tensor);
const std::string mode =
PADDLE_GET_CONST(std::string, op_desc.GetAttr("mode"));
const std::string padding_mode =
PADDLE_GET_CONST(std::string, op_desc.GetAttr("padding_mode"));
const bool align_corners =
PADDLE_GET_CONST(bool, op_desc.GetAttr("align_corners"));
nvinfer1::InterpolationMode interpolationMode{
nvinfer1::InterpolationMode::kNEAREST};
if (mode == "nearest") {
interpolationMode = nvinfer1::ResizeMode::kNEAREST;
} else if (mode == "bilinear") {
interpolationMode = nvinfer1::ResizeMode::kLINEAR;
}
nvinfer1::SampleMode sampleMode{nvinfer1::SampleMode::kFILL};
if (padding_mode == "zeros") {
sampleMode = nvinfer1::SampleMode::kFILL;
} else if (padding_mode == "border") {
sampleMode = nvinfer1::SampleMode::kCLAMP;
} else if (padding_mode == "reflection") {
sampleMode = nvinfer1::SampleMode::kREFLECT;
}
layer->setInterpolationMode(interpolationMode);
layer->setSampleMode(sampleMode);
layer->setAlignCorners(align_corners);
RreplenishLayerAndOutput(layer, "grid_sampler", {output_name}, test_mode);
#else
VLOG(3) << "grid_sampler is not supported when TensorRT < 8.5.1";
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(grid_sampler, GridSamplerOpConverter);
......@@ -377,7 +377,7 @@ nvinfer1::DimsExprs GridSamplerInferMeta(
output.d[2] = grid_dims.d[1];
output.d[3] = grid_dims.d[2];
} else {
output.nbDims = 4;
output.nbDims = 5;
output.d[0] = x_dims.d[0];
output.d[1] = x_dims.d[1];
output.d[2] = grid_dims.d[1];
......
......@@ -2542,6 +2542,48 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if (op_type == "grid_sampler") {
#if !IS_TRT_VERSION_GE(8510)
VLOG(3) << "grid_sampler is not supported when TensorRT < 8.5.1";
return false;
#else
if (!with_dynamic_shape) {
VLOG(3) << "the grid_sampler does not support "
"static shape yet";
return false;
}
if (!desc.HasAttr("mode") || !desc.HasAttr("padding_mode") ||
!desc.HasAttr("align_corners")) {
VLOG(3) << "grid_sampler need attributes : mode, padding_mode, "
"align_corners";
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto input_name = desc.Input("X")[0];
auto* input_desc = block->FindVar(input_name);
const auto input_shape = input_desc->GetShape();
auto grid_name = desc.Input("Grid")[0];
auto* grid_desc = block->FindVar(grid_name);
const auto grid_shape = grid_desc->GetShape();
if (input_shape.size() != 4 || grid_shape.size() != 4) {
VLOG(3) << "The input and grid tensors must be shape tensors of rank 4 "
"using TRT GridSample layer.";
return false;
}
#endif
}
if (use_no_calib_int8) {
return int8_teller_set.count(op_type);
} else {
......@@ -2701,7 +2743,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"expand_v2",
"fuse_eleadd_transpose",
"skip_groupnorm_act",
"preln_groupnorm_act"};
"preln_groupnorm_act",
"grid_sampler"};
std::unordered_set<std::string> teller_set{
"mul",
......@@ -2853,7 +2896,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"expand_v2",
"fuse_eleadd_transpose",
"skip_groupnorm_act",
"preln_groupnorm_act"};
"preln_groupnorm_act",
"grid_sampler"};
};
struct GenericPluginTeller : public Teller {
......
......@@ -25,61 +25,103 @@ import paddle.inference as paddle_infer
class TrtConvertGridSampler(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
self.trt_param.workspace_size = 1073741824
return True
def sample_program_configs(self):
def generate_input1():
return np.random.random([1, 3, 32, 32]).astype(np.float32)
if self.dims == 4:
self.input_shape = [1, 3, 32, 32]
return np.random.random([1, 3, 32, 32]).astype(np.float32)
elif self.dims == 5:
self.input_shape = [1, 3, 32, 32, 64]
return np.random.random([1, 3, 32, 32, 64]).astype(np.float32)
def generate_input2():
return np.random.random([1, 3, 3, 2]).astype(np.float32)
ops_config = [
{
"op_type": "grid_sampler",
"op_inputs": {
"X": ["input_data"],
"Grid": ["grid_data"],
},
"op_outputs": {"Output": ["output_data"]},
"op_attrs": {},
}
]
ops = self.generate_op_config(ops_config)
for i in range(10):
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1)
),
"grid_data": TensorConfig(
data_gen=partial(generate_input2)
),
},
outputs=["output_data"],
)
yield program_config
if self.dims == 4:
self.input_shape = [1, 3, 3, 2]
return np.random.random([1, 3, 3, 2]).astype(np.float32)
elif self.dims == 5:
self.input_shape = [1, 3, 3, 2, 3]
return np.random.random([1, 3, 3, 2, 3]).astype(np.float32)
mode = ["bilinear", "nearest"]
padding_mode = ["zeros", "reflection", "border"]
align_corners = [True, False]
descs = []
for m in mode:
for p in padding_mode:
for a in align_corners:
descs.append(
{
"mode": m,
"padding_mode": p,
"align_corners": a,
}
)
for dims in [4, 5]:
for desc in descs:
self.dims = dims
ops_config = [
{
"op_type": "grid_sampler",
"op_inputs": {
"X": ["input_data"],
"Grid": ["grid_data"],
},
"op_outputs": {"Output": ["output_data"]},
"op_attrs": desc,
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1)
),
"grid_data": TensorConfig(
data_gen=partial(generate_input2)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 3, 32, 32],
"grid_data": [1, 3, 3, 2],
}
self.dynamic_shape.max_input_shape = {
"input_data": [1, 3, 64, 64],
"grid_data": [1, 3, 4, 4],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 3, 32, 32],
"grid_data": [1, 3, 3, 2],
}
def generate_dynamic_shape():
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 3, 32, 32],
"grid_data": [1, 3, 3, 2],
}
self.dynamic_shape.max_input_shape = {
"input_data": [1, 3, 64, 64],
"grid_data": [1, 3, 6, 2],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 3, 32, 32],
"grid_data": [1, 3, 3, 2],
}
elif self.dims == 5:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 3, 32, 32, 64],
"grid_data": [1, 3, 3, 2, 3],
}
self.dynamic_shape.max_input_shape = {
"input_data": [1, 3, 64, 64, 128],
"grid_data": [1, 3, 3, 6, 3],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 3, 32, 32, 64],
"grid_data": [1, 3, 3, 2, 3],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
......@@ -92,13 +134,9 @@ class TrtConvertGridSampler(TrtLayerAutoScanTest):
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 4), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-3
# for dynamic_shape
generate_dynamic_shape(attrs)
generate_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册