未验证 提交 26ae6d49 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Update trt5 for paddle-trt (#18645)

* update paddle-trt for:
    1. fix bug: when batch > 2, core in split plugin.
    2. add leaky_relu trt5.0 support (yolov3 from 65ms to 42ms.)
    3. add new attr to dropout.
    4. shuffle channel, swish, relu6 support
    test=develop

* 1. fix ci
test=develop
上级 d8396281
......@@ -1880,6 +1880,9 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
reshape1_op->assert_more([&](Node *x) {
return boost::get<std::vector<int>>(x->Op()->GetAttr("shape")).size() == 5;
});
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
......
......@@ -968,6 +968,8 @@ USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu);
USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(swish);
#endif
#if PADDLE_WITH_ANAKIN
......
......@@ -74,6 +74,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"infer_clean_graph_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
// "fc_fuse_pass", //
......
......@@ -3,6 +3,7 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc
shuffle_channel_op.cc swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......@@ -42,3 +43,9 @@ nv_test(test_op_converter SRCS test_op_converter.cc DEPS
# prelu_op)
#nv_test(test_trt_leaky_relu_op SRCS test_leaky_relu_op.cc leaky_relu_op.cc
# DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op)
#nv_test(test_shuffle_channel_op SRCS test_shuffle_channel_op.cc shuffle_channel_op.cc
# DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine shuffle_channel_op)
#nv_test(test_swish_op SRCS test_swish_op.cc swish_op.cc
# DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op tensorrt_plugin)
......@@ -42,11 +42,20 @@ class ActivationOpConverter : public OpConverter {
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
op_pair->second);
#if IS_TRT_VERSION_GE(5130)
// max(alpha, min(beta, x))
if (op_type_ == "relu6") {
layer->setAlpha(0.);
layer->setBeta(6.);
}
#endif
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
if (op_desc.HasAttr("out_scale")) {
#if IS_TRT_VERSION_GE(5000)
#if IS_TRT_VERSION_GE(5130)
float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
#endif
......@@ -63,6 +72,9 @@ const std::unordered_map<std::string, nvinfer1::ActivationType>
{"relu", nvinfer1::ActivationType::kRELU},
{"sigmoid", nvinfer1::ActivationType::kSIGMOID},
{"tanh", nvinfer1::ActivationType::kTANH},
#if IS_TRT_VERSION_GE(5130)
{"relu6", nvinfer1::ActivationType::kCLIP},
#endif
};
class ReluOpConverter : public ActivationOpConverter {
......@@ -80,6 +92,11 @@ class TanhOpConverter : public ActivationOpConverter {
TanhOpConverter() { op_type_ = "tanh"; }
};
class Relu6OpConverter : public ActivationOpConverter {
public:
Relu6OpConverter() { op_type_ = "relu6"; }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -87,3 +104,4 @@ class TanhOpConverter : public ActivationOpConverter {
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
REGISTER_TRT_OP_CONVERTER(relu6, Relu6OpConverter);
......@@ -31,6 +31,20 @@ class DropoutOpConverter : public OpConverter {
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
float dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
std::string downgrade_in_infer = "";
if (op_desc.HasAttr("dropout_implementation")) {
downgrade_in_infer =
boost::get<std::string>(op_desc.GetAttr("dropout_implementation"));
}
if (!downgrade_in_infer.empty() &&
downgrade_in_infer == "upscale_in_train") {
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input1);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "dropout", {output_name}, test_mode);
return;
}
platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
......
......@@ -35,7 +35,14 @@ class LeakyReluOpConverter : public OpConverter {
PADDLE_ENFORCE(output_num == 1);
// Get attrs
float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
nvinfer1::ILayer* output_layer = nullptr;
#if IS_TRT_VERSION_GE(5100)
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *input, nvinfer1::ActivationType::kLEAKY_RELU);
layer->setAlpha(alpha);
output_layer = layer;
#else
platform::CPUPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor(
new framework::LoDTensor());
......@@ -65,7 +72,7 @@ class LeakyReluOpConverter : public OpConverter {
nvinfer1::ScaleMode::kUNIFORM, shift.get(),
sub_scale.get(), power.get());
PADDLE_ENFORCE(nullptr != scale_relu_layer);
auto* output_layer =
output_layer =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *(scale_layer->getOutput(0)),
*(scale_relu_layer->getOutput(0)),
nvinfer1::ElementWiseOperation::kSUM);
......@@ -75,7 +82,7 @@ class LeakyReluOpConverter : public OpConverter {
PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
engine_->weight_map.end());
engine_->weight_map[alpha_name] = std::move(alpha_tensor);
#endif
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(output_layer, "leaky_relu", {output_name},
test_mode);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* ConcatOp
*/
class ShuffleChannelOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto input_dims = input->getDimensions();
PADDLE_ENFORCE(input_dims.nbDims == 3);
int c = input_dims.d[0];
int h = input_dims.d[1];
int w = input_dims.d[2];
int group = boost::get<int>(op_desc.GetAttr("group"));
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Dims4 reshape_dim(group, c / group, h, w);
layer->setReshapeDimensions(reshape_dim);
layer->setSecondTranspose({1, 0, 2, 3});
auto* output = layer->getOutput(0);
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *output);
nvinfer1::DimsCHW reshape_dim2(c, h, w);
reshape_layer->setReshapeDimensions(reshape_dim2);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(reshape_layer, "concat", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class SwishOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid swish op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE(input_num == 1);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE(output_num == 1);
// Get attrs
float beta = boost::get<float>(op_desc.GetAttr("beta"));
plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "swish", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(swish, SwishOpConverter);
......@@ -46,6 +46,8 @@ TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); }
TEST(TanhOpConverter, main) { test_activation("tanh"); }
TEST(Relu6OpConverter, main) { test_activation("relu6"); }
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -53,3 +55,4 @@ TEST(TanhOpConverter, main) { test_activation("tanh"); }
USE_OP(relu);
USE_OP(sigmoid);
USE_OP(tanh);
USE_OP(relu6);
......@@ -34,6 +34,7 @@ TEST(DropoutOpConverter, main) {
framework::OpDesc desc;
int is_test = 1;
float dropout_prob = 0.4;
std::string dropout_implementation = "upscale_in_train";
desc.SetType("dropout");
desc.SetInput("X", {"dropout-X"});
......@@ -42,6 +43,8 @@ TEST(DropoutOpConverter, main) {
desc.SetAttr("is_test", is_test);
desc.SetAttr("dropout_prob", dropout_prob);
desc.SetAttr("dropout_implementation", dropout_implementation);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(leaky_relu_op, test_leaky_relu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("sc_input", nvinfer1::DimsCHW(4, 2, 2));
validator.DeclOutputVar("sc_out", nvinfer1::DimsCHW(4, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("shuffle_channel");
desc.SetInput("X", {"sc_input"});
desc.SetOutput("Out", {"sc_out"});
int group = 2;
desc.SetAttr("group", group);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// USE_OP(leaky_relu);
USE_OP(shuffle_channel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(swish_op, test_swish) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("sw_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("sw_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("swish");
desc.SetInput("X", {"sw_input"});
desc.SetOutput("Out", {"sw_out"});
desc.SetAttr("beta", 2.0f);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(swish);
......@@ -20,7 +20,11 @@ namespace tensorrt {
// Just tell by the op_types.
struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller() {}
SimpleOpTypeSetTeller() {
#if IS_TRT_VERSION_GE(5130)
teller_set.insert("relu6");
#endif
}
bool operator()(const std::string& op_type,
const framework::OpDesc& desc) override {
......@@ -28,11 +32,27 @@ struct SimpleOpTypeSetTeller : public Teller {
}
private:
std::unordered_set<std::string> teller_set{
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "elementwise_mul", "dropout", "prelu",
"conv2d_transpose", "leaky_relu", "fc"}};
std::unordered_set<std::string> teller_set{{"mul",
"conv2d",
"pool2d",
"relu",
"softmax",
"sigmoid",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_mul",
"dropout",
"prelu",
"conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"split"}};
};
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
......
......@@ -18,6 +18,7 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace inference {
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc
avg_pool_op_plugin.cu
avg_pool_op_plugin.cu swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu)
......@@ -34,6 +34,7 @@ int PReluPlugin::initialize() {
cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
cudaMemcpyHostToDevice);
return 0;
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
......
......@@ -27,50 +27,20 @@ SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
}
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
// copied from operators::math::SplitFunctor
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int* out_cols,
int out_cols_size, T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0;
int curr_offset = out_cols[0];
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
int curr_col_offset = out_cols[curr_segment + 1];
while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
curr_col_offset = out_cols[curr_segment + 1];
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs_data[curr_segment];
if (output_ptr != nullptr) {
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input_data[tid_y * in_col + tid_x];
}
}
}
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int fixed_out_col,
T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x / fixed_out_col;
int in_offset = tid_x - split * fixed_out_col;
T* output_ptr = outputs_data[split];
if (output_ptr != nullptr) {
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * fixed_out_col + in_offset] =
input_data[tid_y * in_col + tid_x];
__device__ int upper_bound(T const* vals, int n, T const& key) {
int i = 0;
while (n > 0) {
int m = n / 2;
int j = i + m;
if (!(key < vals[j])) {
i = j + 1;
n -= m + 1;
} else {
n = m;
}
}
return i;
}
nvinfer1::Dims SplitPlugin::getOutputDimensions(
......@@ -101,80 +71,61 @@ int SplitPlugin::initialize() {
if (output_length_[i] != output_length_[0]) {
same_shape_ = false;
}
segment_offsets.push_back(segment_offsets.back() +
output_length_[i] * inner_cols_);
segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
}
inner_cols_ *= dims.d[axis_];
axis_shape_ = dims.d[axis_];
d_segment_offsets_ = segment_offsets;
segment_offsets_ = std::move(segment_offsets);
d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
return 0;
}
// The following part of the code refers to onnx-tensorrt
// https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu
template <typename T>
inline void Split(cudaStream_t stream, const bool same_shape,
const int outer_rows, const int inner_cols,
const std::vector<int>& segment_offsets,
const int* d_segment_offsets, const T* input, T** outputs) {
const int kThreadsPerBlock = 1024;
const int kMaxBlocks = 65535;
int block_cols = kThreadsPerBlock;
if (inner_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((inner_cols + 31) >> 5) << 5;
__global__ void split_kernel(int nsegment,
int const* __restrict__ segment_offsets,
T const* __restrict__ idata, T* const* odatas,
int inner_cols, int axis_shape, int outer_rows) {
int x0 = threadIdx.x + blockIdx.x * blockDim.x;
int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
int z0 = threadIdx.z + blockIdx.z * blockDim.z;
for (int z = z0; z < outer_rows; z += blockDim.z * gridDim.z) {
for (int src_y = src_y0; src_y < axis_shape;
src_y += blockDim.y * gridDim.y) {
for (int x = x0; x < inner_cols; x += blockDim.x * gridDim.x) {
int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
int dst_y = src_y - segment_offsets[segment];
int dst_ny = segment_offsets[segment + 1] - segment_offsets[segment];
odatas[segment][x + inner_cols * (dst_y + dst_ny * z)] =
idata[x + inner_cols * (src_y + axis_shape * z)];
}
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int grid_cols =
std::min((inner_cols + block_cols - 1) / block_cols, kMaxBlocks);
int grid_rows =
std::min(kMaxBlocks / grid_cols, std::max(outer_rows / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (same_shape) {
SplitKernel<<<grid_size, block_size, 0, stream>>>(
input, outer_rows, inner_cols, segment_offsets[1], outputs);
} else {
SplitKernel<<<grid_size, block_size, 0, stream>>>(
input, outer_rows, inner_cols, d_segment_offsets,
static_cast<int>(segment_offsets.size()), outputs);
}
}
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace, cudaStream_t stream) {
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
if (((batchSize == 1 && axis_ == 0) || axis_ == -1) &&
this->getNbOutputs() < 10) {
float** output_ptrs = reinterpret_cast<float**>(outputs);
int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
? sizeof(float)
: sizeof(__half);
for (int i = 0; i < this->getNbOutputs(); ++i) {
PADDLE_ENFORCE(
cudaMemcpyAsync(
output_ptrs[i], input_ptr + segment_offsets_[i],
(segment_offsets_[i + 1] - segment_offsets_[i]) * data_type_size,
cudaMemcpyDeviceToDevice, stream) == cudaSuccess);
}
} else {
outer_rows_ *= batchSize;
const int* d_segment_offsets_ptr =
thrust::raw_pointer_cast(&d_segment_offsets_[0]);
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, outputs,
this->getNbOutputs() * sizeof(float*),
PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, h_odatas,
d_output_ptrs_.size() * sizeof(float*),
cudaMemcpyHostToDevice,
stream) == cudaSuccess);
if (this->getDataType() == nvinfer1::DataType::kFLOAT) {
Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
d_segment_offsets_ptr, input_ptr, output_ptrs);
} else {
Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
d_segment_offsets_ptr, (__half*)input_ptr, // NOLINT
(__half**)output_ptrs); // NOLINT
}
}
int outer_rows = outer_rows_ * batchSize;
dim3 block(32, 16);
dim3 grid(std::min((inner_cols_ - 1) / block.x + 1, 65535u),
std::min((axis_shape_ - 1) / block.y + 1, 65535u),
std::min((outer_rows_ - 1) / block.z + 1, 65535u));
split_kernel<<<grid, block, 0, stream>>>(
d_segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
inner_cols_, axis_shape_, outer_rows);
return cudaGetLastError() != cudaSuccess;
}
......
......@@ -66,6 +66,7 @@ class SplitPlugin : public PluginTensorRT {
int axis_;
int outer_rows_;
int inner_cols_;
int axis_shape_;
bool same_shape_;
std::vector<int> output_length_;
std::vector<int> segment_offsets_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
SwishPlugin *CreateSwishPluginDeserialize(const void *buffer, size_t length) {
return new SwishPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("swish_plugin", CreateSwishPluginDeserialize);
int SwishPlugin::initialize() { return 0; }
nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
__global__ void swish_kernel(int num, const float *input, float *output,
float beta) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] =
__ldg(input + index) / (1.0f + expf(-beta * __ldg(input + index)));
#else
output[index] = input[index] / (1.0f + expf(-beta * input[index]));
#endif
}
}
int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
float *output = reinterpret_cast<float **>(outputs)[0];
int num = batch_size;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
}
int threads = 1024;
int blocks = (num + threads - 1) / threads;
swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class SwishPlugin : public PluginTensorRT {
private:
float beta_;
protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(beta_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, beta_);
}
public:
explicit SwishPlugin(const float beta) : beta_(beta) {}
// It was used for tensorrt deserialization.
// It should not be called by users.
SwishPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &beta_);
}
~SwishPlugin() {}
int initialize() override;
SwishPlugin *clone() const override { return new SwishPlugin(beta_); }
const char *getPluginType() const override { return "swish_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册