Unverified commit 0a51098a authored by Pei Yang, committed by GitHub

Add TRT support for BERT (#21135)

* add gelu plugin

* align trt bert with gpu

* add support for fused fc with relu,

* add unittest for bert trt
Parent b0b27ff6
...@@ -938,6 +938,9 @@ USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu);
USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(swish);
USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
#endif
#if PADDLE_WITH_ANAKIN
......
...@@ -76,10 +76,13 @@ const std::vector<std::string> kTRTSubgraphPasses({
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass", //
"multihead_matmul_fuse_pass", //
"conv_bn_fuse_pass", //
"fc_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
"conv_elementwise_add_act_fuse_pass", //
......
...@@ -2,8 +2,8 @@
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
...@@ -44,7 +44,6 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights, // NOLINT
static_cast<float*>(const_cast<void*>(oweights->get().values)),
ostrides);
}
/*
* FC converter convert a MUL op in Fluid to a FC layer in TRT.
*/
...@@ -63,7 +62,6 @@ class FcOpConverter : public OpConverter {
w_name = "W";
i_name = "Input";
}
// Declare inputs
auto* X = engine_->GetITensor(op_desc.Input(i_name).front());
...@@ -71,6 +69,16 @@ class FcOpConverter : public OpConverter {
auto* Y_v = scope.FindVar(op_desc.Input(w_name).front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
const int x_num_col_dims =
op_desc.HasAttr("x_num_col_dims")
? boost::get<int>(op_desc.GetAttr("x_num_col_dims"))
: (op_desc.HasAttr("in_num_col_dims")
? boost::get<int>(op_desc.GetAttr("in_num_col_dims"))
: 1);
const std::string activation_type =
op_desc.HasAttr("activation_type")
? boost::get<std::string>(op_desc.GetAttr("activation_type"))
: "";
// This may trigger a GPU->CPU copy, because TRT's weight can only be
// assigned from CPU memory, which can't be avoided.
float* weight_data = nullptr;
...@@ -128,14 +136,76 @@ class FcOpConverter : public OpConverter {
static_cast<void*>(bias_data),
static_cast<size_t>(bias_num)};
// in order to handle situations in NLP models (input dims < 3,
// x_num_col_dims != 1, etc.), reshape input to perform FC correctly.
auto* reshape_itensor = X;
int input_dims = X->getDimensions().nbDims;
auto input_d = X->getDimensions().d;
int reshape_dim3[3] = {0};
int reshape_dim4[4] = {0};
PADDLE_ENFORCE_EQ(
x_num_col_dims == 1 || x_num_col_dims == 2, true,
platform::errors::InvalidArgument(
"Wrong x_num_col_dims param of op mul. Paddle-TRT FC converter "
"expects x_num_col_dims is either 1 or 2, but got %d",
x_num_col_dims));
PADDLE_ENFORCE_LE(x_num_col_dims, input_dims,
platform::errors::InvalidArgument(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_num_col_dims <= input dims"));
if (x_num_col_dims == 1) {
if (input_dims == 4) {
PADDLE_ENFORCE_EQ(
input_d[3], 1,
platform::errors::InvalidArgument(
"Invalid dimensions. When x_num_col_dims equals to 1 and input "
"dims equals to 4, the last dim of input must be 1, but got %d",
input_d[3]));
}
for (int i = 0; i < 3; i++) {
if (i < input_dims) {
reshape_dim3[i] = input_d[i];
} else {
reshape_dim3[i] = 1;
}
}
nvinfer1::Dims3 reshape_dim(reshape_dim3[0], reshape_dim3[1],
reshape_dim3[2]);
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
reshape_layer->setReshapeDimensions(reshape_dim);
reshape_itensor = reshape_layer->getOutput(0);
} else {
PADDLE_ENFORCE_NE(input_dims, 1,
platform::errors::InvalidArgument(
"Invalid dimensions. When x_num_col_dims equals to "
"2, input_dims should not be 1"));
for (int i = 0; i < 4; i++) {
if (i < input_dims) {
reshape_dim4[i] = input_d[i];
} else {
reshape_dim4[i] = 1;
}
}
nvinfer1::Dims4 reshape_dim(reshape_dim4[0], reshape_dim4[1],
reshape_dim4[2], reshape_dim4[3]);
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
reshape_layer->setReshapeDimensions(reshape_dim);
reshape_itensor = reshape_layer->getOutput(0);
}
auto* fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *reshape_itensor,
n_output, tmp_weight.get(), bias.get());
engine_->SetWeights(op_desc.Input(w_name).front(), std::move(tmp));
auto output_name = op_desc.Output("Out").front();
if (activation_type == "relu") {
nvinfer1::IActivationLayer* relu_layer =
TRT_ENGINE_ADD_LAYER(engine_, Activation, *(fc_layer->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer, "fc", {output_name}, test_mode);
} else {
RreplenishLayerAndOutput(fc_layer, "fc", {output_name}, test_mode);
}
}
};
......
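To make the reshape logic in the FC converter concrete: with x_num_col_dims == 1 the input is padded to 3 dims, and with x_num_col_dims == 2 to 4 dims, by appending 1s. A minimal standalone sketch of that rule (PadShapeForFc is illustrative and not part of the patch):

#include <vector>

// Pad `dims` with trailing 1s up to `target_rank`, mirroring how the converter
// fills reshape_dim3 / reshape_dim4 before adding the FullyConnected layer.
std::vector<int> PadShapeForFc(const std::vector<int>& dims, int target_rank) {
  std::vector<int> padded(target_rank, 1);
  for (int i = 0; i < target_rank && i < static_cast<int>(dims.size()); ++i) {
    padded[i] = dims[i];
  }
  return padded;
}

// e.g. x_num_col_dims == 1, input dims {128, 768}    -> reshaped to {128, 768, 1}
//      x_num_col_dims == 2, input dims {128, 12, 64} -> reshaped to {128, 12, 64, 1}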
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Gelu converter from fluid to tensorRT.
*/
class GeluOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid gelu op to tensorrt gelu layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE_EQ(input_num, 1,
platform::errors::InvalidArgument(
"gelu op has only 1 input, but got %d", input_num));
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE_EQ(output_num, 1,
platform::errors::InvalidArgument(
"gelu op has only 1 output, but got %d", output_num));
// Get input shape and volume
nvinfer1::Dims input_shape = input->getDimensions();
size_t input_volume = 1;
for (int i = 0; i < input_shape.nbDims; i++) {
input_volume *= input_shape.d[i];
}
plugin::GeluPlugin* plugin = new plugin::GeluPlugin(input_volume);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "gelu", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(gelu, GeluOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class LayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid layer_norm op to tensorrt layer_norm plugin";
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(), 1,
platform::errors::InvalidArgument(
"input of layer_norm op converter should be 1, got %d",
op_desc.Input("X").size()));
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1,
platform::errors::InvalidArgument(
"Bias of layer_norm op converter should be 1, got %d",
op_desc.Input("Bias").size())); // Bias is a weight
PADDLE_ENFORCE_EQ(
op_desc.Input("Scale").size(), 1,
platform::errors::InvalidArgument(
"Scale of layer_norm op converter should be 1, got %d",
op_desc.Input("Scale").size())); // Scale is a weight
PADDLE_ENFORCE_EQ(
op_desc.Output("Y").size(), 1,
platform::errors::InvalidArgument(
"output of layer_norm op converter should be 1, got %d",
op_desc.Input("Y").size()));
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
const int begin_norm_axis =
op_desc.HasAttr("begin_norm_axis")
? boost::get<int>(op_desc.GetAttr("begin_norm_axis"))
: 1;
const float eps = op_desc.HasAttr("epsilon")
? boost::get<float>(op_desc.GetAttr("epsilon"))
: 1e-5f;
PADDLE_ENFORCE_NOT_NULL(
Bias_v, platform::errors::InvalidArgument(
"Input(Bias) of layer_norm should not be null."));
PADDLE_ENFORCE_NOT_NULL(
Scale_v, platform::errors::InvalidArgument(
"Input(Scale) of layer_norm should not be null."));
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
int input_num = 1;
for (int i = 0; i < X->getDimensions().nbDims; i++) {
input_num *= X->getDimensions().d[i];
}
std::vector<int64_t> mean_shape{input_num};
std::vector<int64_t> variance_shape{input_num};
std::unique_ptr<framework::LoDTensor> bias_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> scale_tensor(
new framework::LoDTensor());
bias_tensor->Resize(Bias_t->dims());
scale_tensor->Resize(Scale_t->dims());
platform::CPUPlace cpu_place;
TensorCopySync((*Bias_t), cpu_place, &(*bias_tensor));
TensorCopySync((*Scale_t), cpu_place, &(*scale_tensor));
auto* bias_data = bias_tensor->mutable_data<float>(platform::CPUPlace());
auto* scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());
plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
begin_norm_axis, eps, mean_shape, variance_shape);
nvinfer1::IPluginLayer* layernorm_layer = engine_->AddPlugin(&X, 1, plugin);
auto output_name = op_desc.Output("Y").front();
engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor));
engine_->SetWeights(op_desc.Input("Scale").front(),
std::move(scale_tensor));
RreplenishLayerAndOutput(layernorm_layer, "layer_norm", {output_name},
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(layer_norm, LayerNormOpConverter);
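For reference, the plugin registered above performs standard layer normalization over the trailing dimensions starting at begin_norm_axis; as a summary (not text from the patch):

y = \mathrm{scale} \cdot \frac{x - \mathrm{mean}(x)}{\sqrt{\mathrm{var}(x) + \epsilon}} + \mathrm{bias}

where scale and bias are the weights copied to CPU above, and epsilon falls back to 1e-5 when the op has no epsilon attribute.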
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class MultiheadMatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt "
"network structure";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* Q = engine_->GetITensor(op_desc.Input("Q").front());
auto* K = engine_->GetITensor(op_desc.Input("K").front());
auto* V = engine_->GetITensor(op_desc.Input("V").front());
auto* BiasQ = scope.FindVar(op_desc.Input("BiasQ").front());
auto* BiasK = scope.FindVar(op_desc.Input("BiasK").front());
auto* BiasV = scope.FindVar(op_desc.Input("BiasV").front());
auto* BiasQK = engine_->GetITensor(op_desc.Input("BiasQK").front());
PADDLE_ENFORCE_EQ(op_desc.Input("Q").size(), 1,
platform::errors::InvalidArgument(
"size of input Q of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(op_desc.Input("K").size(), 1,
platform::errors::InvalidArgument(
"size of input K of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(op_desc.Input("V").size(), 1,
platform::errors::InvalidArgument(
"size of input V of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(
op_desc.Input("BiasQK").size(), 1,
platform::errors::InvalidArgument(
"size of input BiasQK of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1,
platform::errors::InvalidArgument(
"size of output of multihead_matmul should be 1"));
PADDLE_ENFORCE_NOT_NULL(
BiasQ, platform::errors::InvalidArgument(
"param BiasQ of multihead_matmul should not be null"));
PADDLE_ENFORCE_NOT_NULL(
BiasK, platform::errors::InvalidArgument(
"param BiasK of multihead_matmul should not be null"));
PADDLE_ENFORCE_NOT_NULL(
BiasV, platform::errors::InvalidArgument(
"param BiasV of multihead_matmul should not be null"));
PADDLE_ENFORCE_EQ(
BiasQK->getDimensions().nbDims, 3,
platform::errors::InvalidArgument(
"dims size of input BiasQK of multihead_matmul should be 3"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("alpha"), true,
platform::errors::PreconditionNotMet(
"attribute alpha of multihead_matmul should not be empty"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("head_number"), true,
platform::errors::PreconditionNotMet(
"attribute head_number of multihead_matmul should not be empty"));
// Declare attributes
const bool transpose_q =
op_desc.HasAttr("transpose_Q")
? boost::get<bool>(op_desc.GetAttr("transpose_Q"))
: false;
const bool transpose_k =
op_desc.HasAttr("transpose_K")
? boost::get<bool>(op_desc.GetAttr("transpose_K"))
: true;
const bool transpose_v =
op_desc.HasAttr("transpose_V")
? boost::get<bool>(op_desc.GetAttr("transpose_V"))
: false;
const float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
const int head_number = boost::get<int>(op_desc.GetAttr("head_number"));
nvinfer1::Dims q_shape = Q->getDimensions();
int seq_len = q_shape.d[0];
int size_per_head = q_shape.d[1] / head_number;
std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
framework::DDim alpha_dim = framework::make_ddim({1});
std::unique_ptr<framework::LoDTensor> alpha_t(new framework::LoDTensor());
alpha_t->Resize(alpha_dim);
float* alpha_data = alpha_t->mutable_data<float>(platform::CPUPlace());
alpha_data[0] = alpha;
TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data), 1};
TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0};
TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* bias_q_t = BiasQ->GetMutable<framework::LoDTensor>();
auto* bias_k_t = BiasK->GetMutable<framework::LoDTensor>();
auto* bias_v_t = BiasV->GetMutable<framework::LoDTensor>();
float* bias_q_cpu_data = engine_->GetWeightCPUData(
op_desc.Input("BiasQ").front(), bias_q_t, false);
float* bias_k_cpu_data = engine_->GetWeightCPUData(
op_desc.Input("BiasK").front(), bias_k_t, false);
float* bias_v_cpu_data = engine_->GetWeightCPUData(
op_desc.Input("BiasV").front(), bias_v_t, false);
std::unique_ptr<framework::LoDTensor> bias_q_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> bias_k_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> bias_v_tensor(
new framework::LoDTensor());
bias_q_tensor->Resize(bias_q_t->dims());
bias_k_tensor->Resize(bias_k_t->dims());
bias_v_tensor->Resize(bias_v_t->dims());
platform::CPUPlace cpu_place;
TensorCopySync((*bias_q_t), cpu_place, bias_q_tensor.get());
TensorCopySync((*bias_k_t), cpu_place, bias_k_tensor.get());
TensorCopySync((*bias_v_t), cpu_place, bias_v_tensor.get());
TensorRTEngine::Weight scale_weights_q{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight shift_weights_q{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_q_cpu_data),
bias_q_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights_q{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight scale_weights_k{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight shift_weights_k{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_k_cpu_data),
bias_k_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights_k{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight scale_weights_v{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight shift_weights_v{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_v_cpu_data),
bias_v_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights_v{nvinfer1::DataType::kFLOAT, nullptr,
0};
auto* q_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *Q, nvinfer1::ScaleMode::kCHANNEL,
shift_weights_q.get(), scale_weights_q.get(), power_weights_q.get());
auto* k_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *K, nvinfer1::ScaleMode::kCHANNEL,
shift_weights_k.get(), scale_weights_k.get(), power_weights_k.get());
auto* v_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *V, nvinfer1::ScaleMode::kCHANNEL,
shift_weights_v.get(), scale_weights_v.get(), power_weights_v.get());
auto* v_transpose_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(v_eltadd_layer->getOutput(0)));
auto* q_transpose_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(q_eltadd_layer->getOutput(0)));
auto* k_transpose_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(k_eltadd_layer->getOutput(0)));
nvinfer1::Dims3 head_reshape_dim(seq_len, head_number, size_per_head);
v_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
v_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
q_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
q_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
k_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
k_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
auto* q_scale_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *(q_transpose_reshape_layer->getOutput(0)),
nvinfer1::ScaleMode::kUNIFORM, shift.get(), scale.get(), power.get());
auto* qk_matmul_layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *(q_scale_layer->getOutput(0)), transpose_q,
*(k_transpose_reshape_layer->getOutput(0)), transpose_k);
auto* qk_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *BiasQK, *(qk_matmul_layer->getOutput(0)),
nvinfer1::ElementWiseOperation::kSUM);
auto* softmax_layer = TRT_ENGINE_ADD_LAYER(
engine_, SoftMax, *(qk_eltadd_layer->getOutput(0)));
softmax_layer->setAxes(4);
auto* qkv_matmul_layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *(softmax_layer->getOutput(0)), false,
*(v_transpose_reshape_layer->getOutput(0)), transpose_v);
auto* qkv_transpose_reshape_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *(qkv_matmul_layer->getOutput(0)));
nvinfer1::Dims2 qkv_reshape_dim(seq_len, head_number * size_per_head);
qkv_transpose_reshape_layer->setFirstTranspose({1, 0, 2});
qkv_transpose_reshape_layer->setReshapeDimensions(qkv_reshape_dim);
engine_->SetWeights(alpha_name, std::move(alpha_t));
engine_->SetWeights(op_desc.Input("BiasQ").front(),
std::move(bias_q_tensor));
engine_->SetWeights(op_desc.Input("BiasK").front(),
std::move(bias_k_tensor));
engine_->SetWeights(op_desc.Input("BiasV").front(),
std::move(bias_v_tensor));
auto output_name = op_desc.Output("Out").front();
RreplenishLayerAndOutput(qkv_transpose_reshape_layer, "multihead_matmul",
{output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(multihead_matmul, MultiheadMatMulOpConverter);
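For reference, the layer stack assembled above implements per-head scaled dot-product attention: the three Scale layers add the Q/K/V biases, alpha scales Q, BiasQK is added before the softmax, and the final Shuffle transposes the heads back and reshapes to (seq_len, head_number * size_per_head). In math form, under the default transpose attributes (a summary, not text from the patch):

\mathrm{head}_i = \mathrm{softmax}\big(\alpha\,(Q_i + b_Q)(K_i + b_K)^{\top} + \mathrm{BiasQK}\big)\,(V_i + b_V)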
...@@ -43,18 +43,27 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
default:
return TRT_DT::kINT32;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown fluid datatype in TRT op converter"));
return TRT_DT::kINT32;
}
nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape,
std::string input) {
PADDLE_ENFORCE_GT(shape.size(), 1UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at least 2 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
PADDLE_ENFORCE_LE(shape.size(), 4UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at most 4 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
if (shape.size() == 4UL)
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
else if (shape.size() == 3UL)
return nvinfer1::Dims2(shape[1], shape[2]);
return nvinfer1::DimsCHW(shape[1], 1, 1);
}
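A quick sketch of how the updated Vec2TRT_Dims maps Fluid shapes (batch dim first, dropped by TRT) to TensorRT dims; the new 3-D branch is what the BERT inputs in the test below rely on:

//   {N, C, H, W} -> nvinfer1::DimsCHW(C, H, W)
//   {N, S, 1}    -> nvinfer1::Dims2(S, 1)      // new branch, e.g. token ids shaped [N, 128, 1]
//   {N, C}       -> nvinfer1::DimsCHW(C, 1, 1)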
...@@ -162,7 +171,7 @@ class OpConverter {
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var_shape, input));
}
framework::proto::BlockDesc* block_proto = block_desc->Proto();
ConvertBlock(*block_proto, parameters, scope, engine);
......
...@@ -52,7 +52,10 @@ struct SimpleOpTypeSetTeller : public Teller {
"fc",
"shuffle_channel",
"swish",
"split",
"gelu",
"layer_norm",
"multihead_matmul"}};
};
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
DEPS enforce tensorrt_engine prelu)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// constants for approximating the normal cdf
constexpr float A = 1.41421356237309504; // sqrt(2)
GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
return new GeluPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("gelu plugin", CreateGeluPluginDeserialize);
nvinfer1::Dims GeluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims* in_dims,
int nb_inputs) {
assert(nb_inputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const& input_dims = in_dims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
template <typename T, unsigned TPB>
__global__ void geluKernel(const T a, int n, const T* input, T* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < n) {
const T in = input[idx];
const T cdf = 0.5 * (1.0 + erf(in * 0.5 * a));
output[idx] = in * cdf;
}
}
int computeGelu(cudaStream_t stream, int n, const float* input, float* output) {
constexpr int blockSize = 256;
const int gridSize = (n + blockSize - 1) / blockSize;
geluKernel<float, blockSize><<<gridSize, blockSize, 0, stream>>>(A, n, input,
output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
return 0;
}
int GeluPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void*, cudaStream_t stream) {
int status = -1;
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
status = computeGelu(stream, input_volume_ * batchSize, input, output);
return status;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
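geluKernel above evaluates the exact-erf GELU, gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2))). A CPU reference that could be used to check the plugin output (a sketch, not part of the patch):

#include <cmath>
#include <vector>

void GeluReference(const std::vector<float>& in, std::vector<float>* out) {
  const float kSqrt2 = 1.41421356237309504f;  // same constant as A in the kernel
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    const float cdf = 0.5f * (1.0f + std::erf(in[i] / kSqrt2));
    (*out)[i] = in[i] * cdf;
  }
}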
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class GeluPlugin : public PluginTensorRT {
protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(getPluginType()) +
SerializedSize(input_volume_);
}
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, input_volume_);
}
public:
explicit GeluPlugin(size_t input_volume) : input_volume_(input_volume) {}
// It was used for tensorrt deserialization.
// It should not be called by users.
GeluPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &input_volume_);
}
~GeluPlugin() {}
int initialize() override { return 0; }
GeluPlugin *clone() const override { return new GeluPlugin(input_volume_); }
const char *getPluginType() const override { return "gelu_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
private:
size_t input_volume_;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/layer_norm_op.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
LayerNormPlugin *CreateLayerNormPluginDeserialize(const void *buffer,
size_t length) {
return new LayerNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("layer_norm_plugin", CreateLayerNormPluginDeserialize);
int LayerNormPlugin::initialize() { return 0; }
nvinfer1::Dims LayerNormPlugin::getOutputDimensions(
int index, const nvinfer1::Dims *inputDims, int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
void **outputs, void *workspace,
cudaStream_t stream) {
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
float *output = reinterpret_cast<float **>(outputs)[0];
int begin_norm_axis = begin_norm_axis_;
float eps = eps_;
int c = input_dims.d[begin_norm_axis - 1];
scale_t.Resize(framework::make_ddim({c}));
bias_t.Resize(framework::make_ddim({c}));
mean_t.Resize(framework::make_ddim(mean_shape_));
variance_t.Resize(framework::make_ddim(variance_shape_));
int device_id;
cudaGetDevice(&device_id);
float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *variance_d =
variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * c,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * c,
cudaMemcpyHostToDevice, stream);
std::vector<int> input_shape;
input_shape.push_back(batch_size);
for (int i = 0; i < input_dims.nbDims; i++) {
input_shape.push_back(input_dims.d[i]);
}
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class LayerNormPlugin : public PluginTensorRT {
std::vector<float> bias_;
std::vector<float> scale_;
framework::Tensor scale_t;
framework::Tensor bias_t;
framework::Tensor mean_t;
framework::Tensor variance_t;
int begin_norm_axis_;
float eps_;
std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_;
protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(bias_) +
SerializedSize(scale_) + SerializedSize(begin_norm_axis_) +
SerializedSize(eps_) + SerializedSize(mean_shape_) +
SerializedSize(variance_shape_) + SerializedSize(getPluginType());
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, begin_norm_axis_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, mean_shape_);
SerializeValue(&buffer, variance_shape_);
}
public:
LayerNormPlugin(const float *bias, const int bias_num, const float *scale,
const int scale_num, int begin_norm_axis, float eps,
std::vector<int64_t> mean_shape,
std::vector<int64_t> variance_shape)
: begin_norm_axis_(begin_norm_axis),
eps_(eps),
mean_shape_(mean_shape),
variance_shape_(variance_shape) {
bias_.resize(bias_num);
scale_.resize(scale_num);
std::copy(bias, bias + bias_num, bias_.data());
std::copy(scale, scale + scale_num, scale_.data());
}
// It was used for tensorrt deserialization.
// It should not be called by users.
LayerNormPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &mean_shape_);
DeserializeValue(&serialData, &serialLength, &variance_shape_);
}
~LayerNormPlugin() {}
int initialize() override;
LayerNormPlugin *clone() const override {
return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(),
scale_.size(), begin_norm_axis_, eps_,
mean_shape_, variance_shape_);
}
const char *getPluginType() const override { return "layer_norm_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -330,6 +330,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_resnext_test SRCS trt_resnext_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_bert_test SRCS trt_bert_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${BERT_INSTALL_DIR}/model)
inference_analysis_test(trt_fc_prelu_test SRCS trt_fc_prelu_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(TensorRT, split_converter) {
AnalysisConfig config;
int batch_size = 1;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(1200, 0);
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 30, batch_size, 10,
AnalysisConfig::Precision::kFloat32, false,
false);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
int64_t i0[128] = {
96, 54, 78, 37, 106, 35, 122, 33, 95, 63, 81, 60, 65, 68, 45, 96,
117, 61, 43, 15, 12, 64, 91, 100, 90, 74, 99, 23, 22, 91, 83, 13,
28, 71, 59, 15, 40, 26, 66, 18, 31, 87, 85, 11, 55, 67, 28, 126,
7, 89, 39, 67, 88, 29, 66, 38, 98, 1, 66, 38, 95, 56, 48, 95,
9, 38, 90, 82, 101, 6, 75, 46, 42, 89, 98, 12, 6, 101, 82, 55,
81, 113, 33, 91, 44, 73, 41, 39, 12, 113, 13, 86, 36, 91, 53, 68,
103, 67, 65, 92, 27, 76, 24, 107, 54, 94, 63, 10, 15, 32, 91, 45,
37, 126, 49, 118, 73, 127, 122, 119, 28, 96, 92, 79, 21, 90, 11, 40};
int64_t i1[128] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,
105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127};
int64_t i2[128] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float i3[128 * 128] = {0.0};
int64_t i4[1] = {0};
auto input_names = predictor->GetInputNames();
auto input_t0 = predictor->GetInputTensor(input_names[0]);
input_t0->Reshape({batch_size, 128, 1});
input_t0->copy_from_cpu(i0);
auto input_t1 = predictor->GetInputTensor(input_names[1]);
input_t1->Reshape({batch_size, 128, 1});
input_t1->copy_from_cpu(i1);
auto input_t2 = predictor->GetInputTensor(input_names[2]);
input_t2->Reshape({batch_size, 128, 1});
input_t2->copy_from_cpu(i2);
auto input_t3 = predictor->GetInputTensor(input_names[3]);
input_t3->Reshape({batch_size, 128, 128});
input_t3->copy_from_cpu(i3);
auto input_t4 = predictor->GetInputTensor(input_names[4]);
input_t4->Reshape({batch_size, 1});
input_t4->copy_from_cpu(i4);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
} // namespace inference
} // namespace paddle
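The test above only asserts that ZeroCopyRun succeeds. If the prediction itself needs to be inspected, a possible extension inside the TEST body (a sketch using the same zero-copy API as the inputs, not part of the patch) would be:

auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = 1;
for (int d : output_shape) out_num *= d;
std::vector<float> out_data(out_num);
output_t->copy_to_cpu(out_data.data());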
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h" #include "paddle/fluid/operators/layer_norm_op.h"
namespace paddle { namespace paddle {
...@@ -427,6 +430,29 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, ...@@ -427,6 +430,29 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
} }
} }
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
const T *input,
std::vector<int> input_shape,
const T *bias, const T *scale,
T *output, T *mean, T *variance,
int begin_norm_axis, float eps) {
const auto x_dims = framework::make_ddim(input_shape);
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int batch_size = static_cast<int>(matrix_dim[0]);
int feature_size = static_cast<int>(matrix_dim[1]);
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end in layer_norm must be larger "
"than 1"));
break;
}
}
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
...@@ -512,7 +538,7 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
batch_size, feature_size, stream);
}
};
template class LayerNormDirectCUDAFunctor<float>;
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
} // namespace operators
......
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
...@@ -151,6 +153,17 @@ using Tensor = framework::Tensor; ...@@ -151,6 +153,17 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout; using DataLayout = framework::DataLayout;
#ifdef PADDLE_WITH_CUDA
template <typename T>
class LayerNormDirectCUDAFunctor {
public:
void operator()(cudaStream_t stream, const T* input,
std::vector<int> input_shape, const T* bias, const T* scale,
T* output, T* mean, T* variance, int begin_norm_axis,
float eps);
};
#endif
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
public:
......