未验证 提交 ad349e77 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #14452 from NHZlX/fix_avg_pool_trt_bug

fix avg pool trt bug 
......@@ -18,7 +18,7 @@ nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine conv_op conv_transpose_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op SERIAL)
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op tensorrt_plugin SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
elementwise_add_op elementwise_mul_op SERIAL)
......
......@@ -13,25 +13,57 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
void DealCeilMode(const nvinfer1::Dims &input_shape, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
nvinfer1::DimsHW *pre_pad, nvinfer1::DimsHW *post_pad,
int input_dims) {
int input_height = input_shape.d[input_dims - 2];
int input_width = input_shape.d[input_dims - 1];
int floor_h_output_size =
(input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
int ceil_h_output_size =
(input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) /
strides[0] +
1;
int floor_w_output_size =
(input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
int ceil_w_output_size =
(input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) / strides[1] +
1;
if (floor_h_output_size != ceil_h_output_size) {
post_pad->h() = strides[0] - 1;
}
if (floor_w_output_size != ceil_w_output_size) {
post_pad->w() = strides[1] - 1;
}
}
/*
* Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights.
*/
class Pool2dOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3)
void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope, bool test_mode) override {
VLOG(40)
<< "convert a fluid pool2d op to tensorrt pool2d layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto *input1 = engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::Dims input_shape = input1->getDimensions();
int input_dims = input_shape.nbDims;
PADDLE_ENFORCE_EQ(input_dims, 3UL);
bool global_pooling = boost::get<bool>(op_desc.GetAttr("global_pooling"));
std::string pool_type =
......@@ -44,23 +76,6 @@ class Pool2dOpConverter : public OpConverter {
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
if (global_pooling == true) {
nv_ksize.d[0] = input_shape.d[nbDims - 2];
nv_ksize.d[1] = input_shape.d[nbDims - 1];
nv_strides.h() = 1;
nv_strides.w() = 1;
nv_paddings.h() = 0;
nv_paddings.w() = 0;
}
PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL);
nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX;
if (pool_type == "max") {
nv_pool_type = nvinfer1::PoolingType::kMAX;
......@@ -70,42 +85,63 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_THROW("TensorRT unsupported pooling type!");
}
if (ceil_mode) {
nvinfer1::DimsHW pre_pad(0, 0);
nvinfer1::DimsHW post_pad(0, 0);
int input_height = input_shape.d[nbDims - 2];
int input_width = input_shape.d[nbDims - 1];
int floor_h_output_size =
(input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
int ceil_h_output_size =
(input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) /
strides[0] +
1;
int floor_w_output_size =
(input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
int ceil_w_output_size =
(input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) /
strides[1] +
1;
if (floor_h_output_size != ceil_h_output_size) {
post_pad.h() = strides[0] - 1;
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
nvinfer1::ILayer *layer = nullptr;
if (global_pooling == true) {
nv_ksize.d[0] = input_shape.d[input_dims - 2];
nv_ksize.d[1] = input_shape.d[input_dims - 1];
auto *layer = TRT_ENGINE_ADD_LAYER(
engine_, Pooling, *const_cast<nvinfer1::ITensor *>(input1),
nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL(layer, "pool layer could not be created.");
auto output_name = op_desc.Output("Out")[0];
layer->setName(("pool2d (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine_->DeclareOutput(output_name);
}
return;
}
if (floor_w_output_size != ceil_w_output_size) {
post_pad.w() = strides[1] - 1;
if (pool_type == "max") {
nvinfer1::DimsHW pre_pad(paddings[0], paddings[1]);
nvinfer1::DimsHW post_pad(paddings[0], paddings[1]);
if (ceil_mode) {
// If ceil mode is true, we will pad the appropriate size to the input.
DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad, &post_pad,
input_dims);
auto *pad_layer = TRT_ENGINE_ADD_LAYER(
engine_, Padding, *const_cast<nvinfer1::ITensor *>(input1), pre_pad,
post_pad);
PADDLE_ENFORCE_NOT_NULL(
pad_layer, "pad layer in poolOp converter could not be created.");
input1 = pad_layer->getOutput(0);
}
auto *pool_layer = TRT_ENGINE_ADD_LAYER(
engine_, Pooling, *const_cast<nvinfer1::ITensor *>(input1),
nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL(pool_layer, "pool layer could not be created.");
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
layer = pool_layer;
} else {
// Average pooling needs to exclude the padding pixels from the average
// mean.
// It is not supported well by TRT, we use a plugin here.
std::vector<int> input_shape_v;
for (int i = 0; i < input_dims; i++) {
input_shape_v.push_back(input_shape.d[i]);
}
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Padding, *const_cast<nvinfer1::ITensor*>(input1), pre_pad,
post_pad);
input1 = layer->getOutput(0);
plugin::AvgPoolPlugin *plugin = new plugin::AvgPoolPlugin(
ceil_mode, ksize, strides, paddings, input_shape_v);
auto *avg_pool_layer = engine_->AddPlugin(&input1, 1, plugin);
layer = avg_pool_layer;
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling,
*const_cast<nvinfer1::ITensor*>(input1),
nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL(layer, "pool layer could not be created.");
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
auto output_name = op_desc.Output("Out")[0];
layer->setName(("pool2d (Output: " + output_name + ")").c_str());
......
......@@ -20,20 +20,21 @@ namespace paddle {
namespace inference {
namespace tensorrt {
void test_pool2d(bool global_pooling, bool ceil_mode) {
void test_pool2d(bool global_pooling, bool ceil_mode,
std::string pool_type = "max") {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 13, 14));
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 6, 7));
if (global_pooling)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
else if (ceil_mode)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 7));
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 3, 4));
else
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 6));
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 3, 3));
// Prepare Op description
framework::OpDesc desc;
......@@ -41,10 +42,10 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
desc.SetInput("X", {"pool2d-X"});
desc.SetOutput("Out", {"pool2d-Out"});
std::vector<int> ksize({3, 3});
std::vector<int> ksize({2, 2});
std::vector<int> strides({2, 2});
std::vector<int> paddings({0, 0});
std::string pooling_t = "max";
std::string pooling_t = pool_type;
desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize);
......@@ -63,7 +64,8 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
TEST(Pool2dOpConverter, test_ceil_mode) { test_pool2d(false, true); }
TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); }
TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); }
} // namespace tensorrt
} // namespace inference
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* inputDims, int nbInputs) {
assert(nbInputs == 1);
assert(index == 0);
assert(inputDims[0].nbDims == 3);
nvinfer1::Dims const& input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
output_dims.d[1] = output_shape_[1];
output_dims.d[2] = output_shape_[2];
return output_dims;
}
int AvgPoolPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace,
cudaStream_t stream) {
auto const& input_dims = this->getInputDims(0);
int input_size = 0;
float const* idata = reinterpret_cast<float const*>(inputs[0]);
float** odatas = reinterpret_cast<float**>(outputs);
paddle::operators::math::AvgPool<float> pool_process;
paddle::operators::math::Pool2dDirectCUDAFunctor<
paddle::operators::math::AvgPool<float>, float>
pool2d_forward;
std::vector<int> input_shape = input_shape_;
std::vector<int> output_shape = output_shape_;
input_shape.insert(input_shape.begin(), batchSize);
output_shape.insert(output_shape.begin(), batchSize);
pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_,
pool_process, true, odatas[0], stream);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class AvgPoolPlugin : public PluginTensorRT {
private:
bool ceil_mode_;
std::vector<int> ksize_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
protected:
size_t getSerializationSize() override {
return SerializedSize(ceil_mode_) + SerializedSize(ksize_) +
SerializedSize(strides_) + SerializedSize(paddings_) +
SerializedSize(input_shape_) + getBaseSerializationSize();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
serializeBase(buffer);
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, input_shape_);
}
public:
AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
std::vector<int> input_shape)
: ceil_mode_(ceil_mode),
ksize_(ksize),
strides_(strides),
paddings_(paddings),
input_shape_(input_shape) {
int output_h, output_w;
output_shape_ = input_shape_;
if (!ceil_mode_) {
output_h =
(input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1;
output_w =
(input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1;
} else {
output_h =
(input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) /
strides_[0] +
1;
output_w =
(input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) /
strides_[1] +
1;
}
output_shape_[1] = output_h;
output_shape_[2] = output_w;
}
// It was used for tensorrt deserialization.
// It should not be called by users.
AvgPoolPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &ceil_mode_);
DeserializeValue(&serialData, &serialLength, &ksize_);
DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_);
DeserializeValue(&serialData, &serialLength, &input_shape_);
}
AvgPoolPlugin *clone() const override {
return new AvgPoolPlugin(ceil_mode_, ksize_, strides_, paddings_,
input_shape_);
}
const char *getPluginType() const override { return "avg_pool"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int initialize() override { return 0; }
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -153,6 +153,37 @@ __global__ void KernelMaxPool2DGrad(
}
}
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
const T* input, const std::vector<int>& input_shape,
const std::vector<int>& output_shape, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& paddings,
PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) {
const int batch_size = input_shape[0];
const int input_channels = input_shape[1];
const int input_height = input_shape[2];
const int input_width = input_shape[3];
const int output_channels = output_shape[1];
const int output_height = output_shape[2];
const int output_width = output_shape[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
int nthreads = batch_size * output_channels * output_height * output_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
nthreads, input, input_channels, input_height, input_width, output_height,
output_width, ksize_height, ksize_width, stride_height, stride_width,
padding_height, padding_width, pool_compute, exclusive, output);
}
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
......@@ -291,6 +322,11 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
}
};
template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
float>;
template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
......
......@@ -82,6 +82,19 @@ class AvgPoolGrad {
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
#ifdef PADDLE_WITH_CUDA
template <typename PoolProcess, typename T>
class Pool2dDirectCUDAFunctor {
public:
void operator()(const T* input, const std::vector<int>& input_shape,
const std::vector<int>& output_shape,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, T* output, cudaStream_t stream);
};
#endif
template <typename DeviceContext, typename PoolProcess, typename T>
class Pool2dFunctor {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册