提交 21f33b42 编写于 作者: H hjchen2

Complete PRelu plugin and Conv2d transpose op converter

上级 23544096
...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { ...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout", "split"}); "elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"});
if (!node->IsOp()) return false; if (!node->IsOp()) return false;
if (teller_set.count(node->Op()->Type())) { if (teller_set.count(node->Op()->Type())) {
......
...@@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat); ...@@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split); USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
#endif #endif
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc pad_op.cc split_op.cc prelu_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -16,7 +16,7 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc ...@@ -16,7 +16,7 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL)
nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op conv_transpose_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
...@@ -33,4 +33,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc ...@@ -33,4 +33,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
split_op concat_op SERIAL) split_op concat_op SERIAL)
nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
prelu_op SERIAL)
...@@ -18,33 +18,35 @@ namespace paddle { ...@@ -18,33 +18,35 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
bool to_skip_merging_optimize(TensorRTEngine* engine_, bool to_skip_merging_optimize(TensorRTEngine* engine,
const std::vector<int>& filters, const std::vector<int>& filters,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
std::string input_name) { std::string input_name) {
if (engine_->itensor_quote_num[input_name] > 0) { if (engine->itensor_quote_num[input_name] > 0) {
return true; return true;
} }
if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 &&
strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0)
engine_->itensor_quote_num[input_name] += 1; engine->itensor_quote_num[input_name] += 1;
return false; return false;
} }
class Conv2dOpConverter : public OpConverter { template <typename RegistFunc, typename SetDilationFunc>
public: void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode,
const framework::Scope& scope, bool test_mode) override { RegistFunc fadd_layer, SetDilationFunc fset_dilation,
VLOG(3) << "convert a fluid conv2d op to tensorrt conv layer without bias"; const std::string& name) {
VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1);
auto* X = engine_->GetITensor(op_desc.Input("Input").front()); PADDLE_ENFORCE(engine != nullptr);
auto* X = engine->GetITensor(op_desc.Input("Input").front());
// Declare weights // Declare weights
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); auto* Y_v = scope.FindVar(op_desc.Input("Filter").front());
...@@ -57,14 +59,13 @@ class Conv2dOpConverter : public OpConverter { ...@@ -57,14 +59,13 @@ class Conv2dOpConverter : public OpConverter {
weight_tensor->Resize(Y_t->dims()); weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); TensorCopySync((*Y_t), cpu_place, weight_tensor.get());
auto* weight_data = auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
weight_tensor->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
const int n_output = weight_tensor->dims()[0]; const int n_output = weight_tensor->dims()[0];
const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2]; const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3]; const int filter_w = weight_tensor->dims()[3];
const int groups = boost::get<int>(op_desc.GetAttr("groups")); const int groups = boost::get<int>(op_desc.GetAttr("groups"));
const std::vector<int> dilations = const std::vector<int> dilations =
boost::get<std::vector<int>>(op_desc.GetAttr("dilations")); boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
...@@ -80,30 +81,76 @@ class Conv2dOpConverter : public OpConverter { ...@@ -80,30 +81,76 @@ class Conv2dOpConverter : public OpConverter {
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
weight_tensor->memory_size() / sizeof(float)}; static_cast<size_t>(weight_tensor->numel())};
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = TRT_ENGINE_ADD_LAYER( auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output, n_input,
engine_, Convolution, *const_cast<nvinfer1::ITensor*>(X), n_output, nv_ksize, weight, bias);
nv_ksize, weight.get(), bias.get());
PADDLE_ENFORCE(layer != nullptr); PADDLE_ENFORCE(layer != nullptr);
layer->setStride(nv_strides); layer->setStride(nv_strides);
layer->setPadding(nv_paddings); layer->setPadding(nv_paddings);
layer->setDilation(nv_dilations);
layer->setNbGroups(groups); layer->setNbGroups(groups);
// set dilations
fset_dilation(layer, nv_dilations);
auto output_name = op_desc.Output("Output").front(); auto output_name = op_desc.Output("Output").front();
layer->setName(("conv2d (Output: " + output_name + ")").c_str()); layer->setName((name + " (Output: " + output_name + ")").c_str());
engine_->weight_map[op_desc.Input("Filter").front()] = engine->weight_map[op_desc.Input("Filter").front()] =
std::move(weight_tensor); std::move(weight_tensor);
layer->getOutput(0)->setName(output_name.c_str()); layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine->SetITensor(output_name, layer->getOutput(0));
if (test_mode || if (test_mode ||
to_skip_merging_optimize(engine_, {filter_h, filter_w}, strides, to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings,
paddings, op_desc.Input("Input").front())) { op_desc.Input("Input").front())) {
engine_->DeclareOutput(output_name); engine->DeclareOutput(output_name);
} }
}
class Conv2dOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
ConvertConv2d(
engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */
int n_input, /* Conv input maps */
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
ksize, weight.get(), bias.get());
return layer;
},
[](nvinfer1::IConvolutionLayer* layer, nvinfer1::DimsHW& dilations) {
layer->setDilation(dilations);
},
"conv2d");
}
};
class Deconv2dOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
ConvertConv2d(
engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */
int n_input, /* Deconv output maps */
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input,
ksize, weight.get(), bias.get());
return layer;
},
[](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::DimsHW& dilations) {
PADDLE_ENFORCE(
dilations.d[0] == 1 && dilations.d[1] == 1,
"Dilations must be (1, 1) for tensorRT, but given (%d, %d)",
dilations.d[0], dilations.d[1]);
},
"conv2d_transpose");
} }
}; };
...@@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter { ...@@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter {
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter);
...@@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* X = engine_->GetITensor(op_desc.Input("X").front()); auto* X = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims dims_x = X->getDimensions(); nvinfer1::Dims dims_x = X->getDimensions();
PADDLE_ENFORCE(dims_x.nbDims >= 3); PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims experts 3, but %d is given.",
dims_x.nbDims);
auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v); PADDLE_ENFORCE_NOT_NULL(Y_v);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* PRelu converter from fluid to tensorRT.
*/
class PReluOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(40) << "convert fluid prelu op to tensorrt prelu layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE(input_num == 1);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE(output_num == 1);
// Get attrs
std::string mode = boost::get<std::string>(op_desc.GetAttr("mode"));
//
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
PADDLE_ENFORCE_NOT_NULL(alpha_var);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
platform::CPUPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_host(
new framework::LoDTensor());
alpha_tensor_host->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_host.get());
float* alpha_data = alpha_tensor_host->mutable_data<float>(place);
// Transform alpha to TensorRTEngine::Weight
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data),
alpha_tensor_host->numel());
engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_host);
//
PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
std::string layer_name = "prelu (Output: ";
for (size_t i = 0; i < output_num; i++) {
auto output_name = op_desc.Output("Out")[i];
layer->getOutput(i)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(i));
layer_name += output_name;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
}
layer->setName((layer_name + ")").c_str());
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter);
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_OP(conv2d);
USE_OP(conv2d_transpose);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
...@@ -51,7 +54,38 @@ TEST(conv2d_op, test) { ...@@ -51,7 +54,38 @@ TEST(conv2d_op, test) {
validator.Execute(3); validator.Execute(3);
} }
TEST(conv2d_transpose_op, test) {
std::unordered_set<std::string> parameters({"deconv2d-Y"});
framework::Scope scope;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
validator.DeclInputVar("deconv2d-X", nvinfer1::Dims3(3, 5, 5));
validator.DeclParamVar("deconv2d-Y", nvinfer1::Dims4(3, 2, 3, 3));
validator.DeclOutputVar("deconv2d-Out", nvinfer1::Dims3(2, 5, 5));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("conv2d_transpose");
desc.SetInput("Input", {"deconv2d-X"});
desc.SetInput("Filter", {"deconv2d-Y"});
desc.SetOutput("Output", {"deconv2d-Out"});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
const int groups = 1;
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("dilations", dilations);
desc.SetAttr("groups", groups);
validator.SetOp(*desc.Proto());
validator.Execute(3);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(conv2d);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(prelu_op, test_channel_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("channel"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(prelu_op, test_element_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("element"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(prelu_op, test_scalar) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("all"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// USE_OP(prelu);
USE_CPU_ONLY_OP(prelu);
...@@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst, ...@@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
Buffer &TensorRTEngine::buffer(const std::string &name) { Buffer &TensorRTEngine::buffer(const std::string &name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name); auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end()); PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s",
name);
auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
return buffers_[slot_offset]; return buffers_[slot_offset];
} }
......
...@@ -40,12 +40,13 @@ class TensorRTEngine : public EngineBase { ...@@ -40,12 +40,13 @@ class TensorRTEngine : public EngineBase {
// Weight is model parameter. // Weight is model parameter.
class Weight { class Weight {
public: public:
Weight() = default;
Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) { Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) {
w_.type = dtype; w_.type = dtype;
w_.values = value; w_.values = value;
w_.count = num_elem; w_.count = num_elem;
} }
const nvinfer1::Weights& get() { return w_; } nvinfer1::Weights& get() { return w_; }
std::vector<int64_t> dims; std::vector<int64_t> dims;
......
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce) nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
float *output, int channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float *out = output + offset;
float scale = alpha[blockIdx.x % channel];
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
const float *scale = alpha + offset;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale[i] * x;
}
}
__global__ void PReluScalarKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float scale = *alpha;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
static inline void PReluChannelWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, dims.d[0], spatial_size);
}
static inline void PReluElementWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
static inline void PReluScalar(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size, const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
int PReluPlugin::initialize() {
nvinfer1::Weights &alpha = cuda_alpha_.get();
alpha.type = alpha_.get().type;
alpha.count = alpha_.get().count;
CHECK_EQ(cudaMalloc(&alpha.values, alpha.count * sizeof(float)), cudaSuccess);
CHECK_EQ(cudaMemcpy(const_cast<void *>(alpha.values), alpha_.get().values,
alpha.count * sizeof(float), cudaMemcpyHostToDevice),
cudaSuccess);
return 0;
}
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha =
reinterpret_cast<const float *>(cuda_alpha_.get().values);
float *output = reinterpret_cast<float **>(outputs)[0];
if (mode_ == "channel") {
PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
} else if (mode_ == "element") {
PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
} else {
PReluScalar(stream, input, alpha, output, batchSize, input_dims);
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PReluPlugin : public PluginTensorRT {
TensorRTEngine::Weight alpha_;
TensorRTEngine::Weight cuda_alpha_;
std::string mode_;
protected:
size_t getSerializationSize() override {
// return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
return 0;
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
// serializeBase(buffer);
// SerializeValue(&buffer, alpha_);
// SerializeValue(&buffer, mode_);
}
public:
PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
: alpha_(alpha), mode_(mode) {}
// It was used for tensorrt deserialization.
// It should not be called by users.
PReluPlugin(void const *serialData, size_t serialLength) {
// deserializeBase(serialData, serialLength);
// DeserializeValue(&serialData, &serialLength, &alpha_);
// DeserializeValue(&serialData, &serialLength, &mode_);
}
PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }
const char *getPluginType() const override { return "prelu"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int initialize() override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册