未验证 提交 1948210c 编写于 作者: P Pei Yang 提交者: GitHub

Bug Fix: Paddle-TRT cannot handle adaptive pooling in pool2d op converter and...

Bug Fix: Paddle-TRT cannot handle adaptive pooling in pool2d op converter and "num" attribute in split op converter (#20733) (#20902)

* fix pool2d trt converter, test=develop

* add fix for split op converter, test=develop
上级 6fb04e8a
...@@ -213,7 +213,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -213,7 +213,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
for (auto *x : node->inputs) { for (auto *x : node->inputs) {
if (x->IsVar() && x->Var()) { if (x->IsVar() && x->Var()) {
framework::VarDesc *var = x->Var(); framework::VarDesc *var = x->Var();
SetAttr(op_desc->Proto(), var->Name() + "_shape", var->GetShape()); op_desc->SetAttr(var->Name() + "_shape", var->GetShape());
} }
} }
......
...@@ -507,7 +507,6 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor< ...@@ -507,7 +507,6 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
} }
} }
if (config.glog_info_disabled()) { if (config.glog_info_disabled()) {
google::InitGoogleLogging("Init");
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
FLAGS_minloglevel = google::WARNING; FLAGS_minloglevel = google::WARNING;
LOG(WARNING) << " - GLOG's LOG(INFO) is disabled."; LOG(WARNING) << " - GLOG's LOG(INFO) is disabled.";
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -75,12 +75,19 @@ class Pool2dOpConverter : public OpConverter { ...@@ -75,12 +75,19 @@ class Pool2dOpConverter : public OpConverter {
std::vector<int> paddings = std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode")); bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
bool adaptive = false;
if (op_desc.HasAttr("adaptive"))
adaptive = boost::get<bool>(op_desc.GetAttr("adaptive"));
nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX; nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX;
plugin::PoolPlugin::PoolType plugin_pool_type =
plugin::PoolPlugin::PoolType::max;
if (pool_type == "max") { if (pool_type == "max") {
nv_pool_type = nvinfer1::PoolingType::kMAX; nv_pool_type = nvinfer1::PoolingType::kMAX;
plugin_pool_type = plugin::PoolPlugin::PoolType::max;
} else if (pool_type == "avg") { } else if (pool_type == "avg") {
nv_pool_type = nvinfer1::PoolingType::kAVERAGE; nv_pool_type = nvinfer1::PoolingType::kAVERAGE;
plugin_pool_type = plugin::PoolPlugin::PoolType::avg;
} else { } else {
PADDLE_THROW("TensorRT unsupported pooling type!"); PADDLE_THROW("TensorRT unsupported pooling type!");
} }
...@@ -108,7 +115,7 @@ class Pool2dOpConverter : public OpConverter { ...@@ -108,7 +115,7 @@ class Pool2dOpConverter : public OpConverter {
return; return;
} }
if (pool_type == "max") { if (!adaptive && pool_type == "max") {
// Under ceil mode, the pre_pad and post_pad are used to // Under ceil mode, the pre_pad and post_pad are used to
// record the the padding size. In some ceil mode cases, // record the the padding size. In some ceil mode cases,
// we do not need padding, so we initialize the two vars to 0. // we do not need padding, so we initialize the two vars to 0.
...@@ -141,10 +148,13 @@ class Pool2dOpConverter : public OpConverter { ...@@ -141,10 +148,13 @@ class Pool2dOpConverter : public OpConverter {
for (int i = 0; i < input_dims; i++) { for (int i = 0; i < input_dims; i++) {
input_shape_v.push_back(input_shape.d[i]); input_shape_v.push_back(input_shape.d[i]);
} }
plugin::AvgPoolPlugin *plugin = new plugin::AvgPoolPlugin( plugin::PoolPlugin *plugin =
ceil_mode, ksize, strides, paddings, input_shape_v); new plugin::PoolPlugin(ceil_mode, plugin_pool_type, adaptive, ksize,
auto *avg_pool_layer = engine_->AddPlugin(&input1, 1, plugin); strides, paddings, input_shape_v);
layer = avg_pool_layer; PADDLE_ENFORCE_NOT_NULL(plugin->getPluginType(),
"The plugin used must not be null");
auto *pool_layer = engine_->AddPlugin(&input1, 1, plugin);
layer = pool_layer;
} }
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
......
...@@ -35,12 +35,23 @@ class SplitOpConverter : public OpConverter { ...@@ -35,12 +35,23 @@ class SplitOpConverter : public OpConverter {
// Get Attrs // Get Attrs
PADDLE_ENFORCE(input_num == 1); PADDLE_ENFORCE(input_num == 1);
int axis = boost::get<int>(op_desc.GetAttr("axis")); int axis = boost::get<int>(op_desc.GetAttr("axis"));
std::vector<int> output_lengths =
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
// split on batch is not supported in TensorRT // split on batch is not supported in TensorRT
PADDLE_ENFORCE(axis != 0); PADDLE_ENFORCE(axis != 0);
axis += (axis < 0) ? input_dims.nbDims : -1; axis += (axis < 0) ? input_dims.nbDims : -1;
std::vector<int> output_lengths =
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
output_lengths.reserve(output_num);
int num = boost::get<int>(op_desc.GetAttr("num"));
if (num > 0) {
int64_t in_axis_dim = input_dims.d[axis];
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
"Tensor split does not result"
" in an equal division");
size_t out_axis_dim = in_axis_dim / num;
for (size_t i = 0; i < output_num; ++i) {
output_lengths.push_back(out_axis_dim);
}
}
PADDLE_ENFORCE(output_lengths.size() == output_num); PADDLE_ENFORCE(output_lengths.size() == output_num);
plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths); plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths);
nvinfer1::IPluginLayer* layer = nvinfer1::IPluginLayer* layer =
......
nv_library(tensorrt_plugin nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc prelu_op_plugin.cu trt_plugin_factory.cc
avg_pool_op_plugin.cu swish_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu) DEPS enforce tensorrt_engine prelu)
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/operators/math/pooling.h"
...@@ -21,14 +21,14 @@ namespace inference { ...@@ -21,14 +21,14 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
AvgPoolPlugin* CreateAvgPoolPluginDeserialize(const void* buffer, PoolPlugin* CreatePoolPluginDeserialize(const void* buffer, size_t length) {
size_t length) { return new PoolPlugin(buffer, length);
return new AvgPoolPlugin(buffer, length);
} }
REGISTER_TRT_PLUGIN("avg_pool_plugin", CreateAvgPoolPluginDeserialize); REGISTER_TRT_PLUGIN("pool_plugin", CreatePoolPluginDeserialize);
nvinfer1::Dims AvgPoolPlugin::getOutputDimensions( nvinfer1::Dims PoolPlugin::getOutputDimensions(int index,
int index, const nvinfer1::Dims* inputDims, int nbInputs) { const nvinfer1::Dims* inputDims,
int nbInputs) {
assert(nbInputs == 1); assert(nbInputs == 1);
assert(index == 0); assert(index == 0);
assert(inputDims[0].nbDims == 3); assert(inputDims[0].nbDims == 3);
...@@ -41,26 +41,33 @@ nvinfer1::Dims AvgPoolPlugin::getOutputDimensions( ...@@ -41,26 +41,33 @@ nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
return output_dims; return output_dims;
} }
int AvgPoolPlugin::enqueue(int batchSize, const void* const* inputs, int PoolPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace, void** outputs, void* workspace, cudaStream_t stream) {
cudaStream_t stream) {
auto const& input_dims = this->getInputDims(0); auto const& input_dims = this->getInputDims(0);
int input_size = 0; int input_size = 0;
float const* idata = reinterpret_cast<float const*>(inputs[0]); float const* idata = reinterpret_cast<float const*>(inputs[0]);
float** odatas = reinterpret_cast<float**>(outputs); float** odatas = reinterpret_cast<float**>(outputs);
paddle::operators::math::AvgPool<float> pool_process;
paddle::operators::math::Pool2dDirectCUDAFunctor<
paddle::operators::math::AvgPool<float>, float>
pool2d_forward;
std::vector<int> input_shape = input_shape_; std::vector<int> input_shape = input_shape_;
std::vector<int> output_shape = output_shape_; std::vector<int> output_shape = output_shape_;
input_shape.insert(input_shape.begin(), batchSize); input_shape.insert(input_shape.begin(), batchSize);
output_shape.insert(output_shape.begin(), batchSize); output_shape.insert(output_shape.begin(), batchSize);
pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_, if (pool_type_ == PoolType::max) {
pool_process, true, odatas[0], stream); paddle::operators::math::MaxPool<float> pool_process;
paddle::operators::math::Pool2dDirectCUDAFunctor<
paddle::operators::math::MaxPool<float>, float>
pool2d_forward;
pool2d_forward(idata, input_shape, output_shape, ksize_, strides_,
paddings_, pool_process, true, adaptive_, odatas[0], stream);
} else if (pool_type_ == PoolType::avg) {
paddle::operators::math::AvgPool<float> pool_process;
paddle::operators::math::Pool2dDirectCUDAFunctor<
paddle::operators::math::AvgPool<float>, float>
pool2d_forward;
pool2d_forward(idata, input_shape, output_shape, ksize_, strides_,
paddings_, pool_process, true, adaptive_, odatas[0], stream);
}
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <stdio.h>
#include <cassert> #include <cassert>
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...@@ -22,18 +24,11 @@ namespace inference { ...@@ -22,18 +24,11 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
class AvgPoolPlugin : public PluginTensorRT { class PoolPlugin : public PluginTensorRT {
private:
bool ceil_mode_;
std::vector<int> ksize_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
protected: protected:
size_t getSerializationSize() override { size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) + return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) +
SerializedSize(pool_type_) + SerializedSize(adaptive_) +
SerializedSize(ksize_) + SerializedSize(strides_) + SerializedSize(ksize_) + SerializedSize(strides_) +
SerializedSize(paddings_) + SerializedSize(input_shape_) + SerializedSize(paddings_) + SerializedSize(input_shape_) +
SerializedSize(output_shape_) + getBaseSerializationSize(); SerializedSize(output_shape_) + getBaseSerializationSize();
...@@ -45,6 +40,8 @@ class AvgPoolPlugin : public PluginTensorRT { ...@@ -45,6 +40,8 @@ class AvgPoolPlugin : public PluginTensorRT {
SerializeValue(&buffer, getPluginType()); SerializeValue(&buffer, getPluginType());
serializeBase(buffer); serializeBase(buffer);
SerializeValue(&buffer, ceil_mode_); SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, pool_type_);
SerializeValue(&buffer, adaptive_);
SerializeValue(&buffer, ksize_); SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_); SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_); SerializeValue(&buffer, paddings_);
...@@ -53,41 +50,54 @@ class AvgPoolPlugin : public PluginTensorRT { ...@@ -53,41 +50,54 @@ class AvgPoolPlugin : public PluginTensorRT {
} }
public: public:
AvgPoolPlugin() {} enum class PoolType {
AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize, max = 0,
std::vector<int> strides, std::vector<int> paddings, avg,
std::vector<int> input_shape) };
PoolPlugin() {}
PoolPlugin(bool ceil_mode, PoolType pool_type, bool adaptive,
std::vector<int> ksize, std::vector<int> strides,
std::vector<int> paddings, std::vector<int> input_shape)
: ceil_mode_(ceil_mode), : ceil_mode_(ceil_mode),
pool_type_(pool_type),
adaptive_(adaptive),
ksize_(ksize), ksize_(ksize),
strides_(strides), strides_(strides),
paddings_(paddings), paddings_(paddings),
input_shape_(input_shape) { input_shape_(input_shape) {
int output_h, output_w;
output_shape_ = input_shape_; output_shape_ = input_shape_;
if (!ceil_mode_) { if (adaptive_) {
output_h = output_shape_[1] = ksize[0];
(input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1; output_shape_[2] = ksize[1];
output_w =
(input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1;
} else { } else {
output_h = int output_h, output_w;
(input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / if (!ceil_mode_) {
strides_[0] + output_h =
1; (input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1;
output_w = output_w =
(input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / (input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1;
strides_[1] + } else {
1; output_h =
(input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) /
strides_[0] +
1;
output_w =
(input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) /
strides_[1] +
1;
}
output_shape_[1] = output_h;
output_shape_[2] = output_w;
} }
output_shape_[1] = output_h;
output_shape_[2] = output_w;
} }
// It was used for tensorrt deserialization. // It was used for tensorrt deserialization.
// It should not be called by users. // It should not be called by users.
AvgPoolPlugin(void const *serialData, size_t serialLength) { PoolPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength); deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &ceil_mode_); DeserializeValue(&serialData, &serialLength, &ceil_mode_);
DeserializeValue(&serialData, &serialLength, &pool_type_);
DeserializeValue(&serialData, &serialLength, &adaptive_);
DeserializeValue(&serialData, &serialLength, &ksize_); DeserializeValue(&serialData, &serialLength, &ksize_);
DeserializeValue(&serialData, &serialLength, &strides_); DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_); DeserializeValue(&serialData, &serialLength, &paddings_);
...@@ -95,18 +105,28 @@ class AvgPoolPlugin : public PluginTensorRT { ...@@ -95,18 +105,28 @@ class AvgPoolPlugin : public PluginTensorRT {
DeserializeValue(&serialData, &serialLength, &output_shape_); DeserializeValue(&serialData, &serialLength, &output_shape_);
} }
AvgPoolPlugin *clone() const override { PoolPlugin *clone() const override {
return new AvgPoolPlugin(ceil_mode_, ksize_, strides_, paddings_, return new PoolPlugin(ceil_mode_, pool_type_, adaptive_, ksize_, strides_,
input_shape_); paddings_, input_shape_);
} }
const char *getPluginType() const override { return "avg_pool_plugin"; } const char *getPluginType() const override { return "pool_plugin"; }
int getNbOutputs() const override { return 1; } int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override; int nbInputDims) override;
int initialize() override { return 0; } int initialize() override { return 0; }
int enqueue(int batchSize, const void *const *inputs, void **outputs, int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override; void *workspace, cudaStream_t stream) override;
private:
bool ceil_mode_;
PoolType pool_type_;
bool adaptive_;
std::vector<int> ksize_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
}; };
} // namespace plugin } // namespace plugin
......
...@@ -268,6 +268,10 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -268,6 +268,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR}) if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_inference_test_models.tar.gz") inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_inference_test_models.tar.gz")
endif() endif()
set(TEST_SPLIT_CONVERTER_MODEL "${TRT_MODEL_INSTALL_DIR}/trt_split_op_converter_test")
if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL})
inference_download_and_uncompress(${TEST_SPLIT_CONVERTER_MODEL} ${INFERENCE_URL}/tensorrt_test "split_converter.tgz")
endif()
inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
...@@ -283,6 +287,9 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -283,6 +287,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_cascade_rcnn_test SRCS trt_cascade_rcnn_test.cc inference_analysis_test(trt_cascade_rcnn_test SRCS trt_cascade_rcnn_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_split_converter_test SRCS trt_split_converter_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_SPLIT_CONVERTER_MODEL}/)
inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(TensorRT, split_converter) {
std::string model_dir = FLAGS_infer_model + "/split_converter";
AnalysisConfig config;
int batch_size = 4;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 20, batch_size, 1,
AnalysisConfig::Precision::kFloat32, false);
auto predictor = CreatePaddlePredictor(config);
int channels = 4;
int height = 4;
int width = 4;
int input_num = batch_size * channels * height * width;
float *input = new float[input_num];
memset(input, 1.0, input_num * sizeof(float));
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({batch_size, channels, height, width});
input_t->copy_from_cpu(input);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
} // namespace inference
} // namespace paddle
...@@ -236,7 +236,8 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()( ...@@ -236,7 +236,8 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
const T* input, const std::vector<int>& input_shape, const T* input, const std::vector<int>& input_shape,
const std::vector<int>& output_shape, const std::vector<int>& ksize, const std::vector<int>& output_shape, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int>& strides, const std::vector<int>& paddings,
PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) { PoolProcess pool_compute, bool exclusive, bool adaptive, T* output,
cudaStream_t stream) {
const int batch_size = input_shape[0]; const int batch_size = input_shape[0];
const int input_channels = input_shape[1]; const int input_channels = input_shape[1];
const int input_height = input_shape[2]; const int input_height = input_shape[2];
...@@ -259,7 +260,7 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()( ...@@ -259,7 +260,7 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>( KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
nthreads, input, input_channels, input_height, input_width, output_height, nthreads, input, input_channels, input_height, input_width, output_height,
output_width, ksize_height, ksize_width, stride_height, stride_width, output_width, ksize_height, ksize_width, stride_height, stride_width,
padding_height, padding_width, pool_compute, exclusive, false, output); padding_height, padding_width, pool_compute, exclusive, adaptive, output);
} }
/* /*
......
...@@ -105,7 +105,8 @@ class Pool2dDirectCUDAFunctor { ...@@ -105,7 +105,8 @@ class Pool2dDirectCUDAFunctor {
const std::vector<int>& ksize, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute, const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, T* output, cudaStream_t stream); bool exclusive, bool adaptive, T* output,
cudaStream_t stream);
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册