未验证 提交 50bee83f 编写于 作者: P Pei Yang 提交者: GitHub

add TRT support for instance_norm op (#21928)

* add TRT support for instance_norm op
上级 3dbd4087
...@@ -938,6 +938,7 @@ USE_TRT_CONVERTER(conv2d_transpose); ...@@ -938,6 +938,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu); USE_TRT_CONVERTER(leaky_relu);
USE_TRT_CONVERTER(shuffle_channel); USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(swish); USE_TRT_CONVERTER(swish);
USE_TRT_CONVERTER(instance_norm);
USE_TRT_CONVERTER(layer_norm); USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu); USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul); USE_TRT_CONVERTER(multihead_matmul);
......
...@@ -3,7 +3,7 @@ nv_library(tensorrt_converter ...@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class InstanceNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid prelu op to tensorrt instance norm layer";
framework::OpDesc op_desc(op, nullptr);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
auto* bias_var = scope.FindVar(op_desc.Input("Bias")[0]);
PADDLE_ENFORCE_NOT_NULL(
scale_var,
platform::errors::InvalidArgument(
"Input [Scale] of instance_norm op converter should not be null"));
PADDLE_ENFORCE_NOT_NULL(
bias_var,
platform::errors::InvalidArgument(
"Input [Bias] of instance_norm op converter should not be null"));
auto* scale_tensor = scale_var->GetMutable<framework::LoDTensor>();
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
scale_tensor->numel(), bias_tensor->numel(),
platform::errors::InvalidArgument(
"Num of input [Scale] and [Bias] of instance_norm op converter "
"should be equal. Got Scale num = %ld, but Bias num = %ld",
scale_tensor->numel(), bias_tensor->numel()));
auto* scale_d = scale_tensor->data<float>();
auto* bias_d = bias_tensor->data<float>();
std::vector<float> scale_v;
std::vector<float> bias_v;
for (int i = 0; i < scale_tensor->numel(); i++) {
scale_v.push_back(scale_d[i]);
bias_v.push_back(bias_d[i]);
}
plugin::InstanceNormPlugin* plugin =
new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
plugin->getPluginType();
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(instance_norm, InstanceNormOpConverter);
...@@ -32,7 +32,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -32,7 +32,8 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
private: private:
std::unordered_set<std::string> teller_set{{"mul", std::unordered_set<std::string> teller_set{{
"mul",
"conv2d", "conv2d",
"pool2d", "pool2d",
"relu", "relu",
...@@ -53,9 +54,11 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -53,9 +54,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"shuffle_channel", "shuffle_channel",
"swish", "swish",
"split", "split",
"instance_norm",
"gelu", "gelu",
"layer_norm", "layer_norm",
"multihead_matmul"}}; "multihead_matmul",
}};
}; };
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
......
nv_library(tensorrt_plugin nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
DEPS enforce tensorrt_engine prelu) DEPS enforce tensorrt_engine prelu tensor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
cudnnDataType_t *cudnn_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
*cudnn_dtype = CUDNN_DATA_FLOAT;
break;
case nvinfer1::DataType::kHALF:
*cudnn_dtype = CUDNN_DATA_HALF;
break;
default:
return CUDNN_STATUS_BAD_PARAM;
}
return CUDNN_STATUS_SUCCESS;
}
InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
size_t length) {
return new InstanceNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("instance_norm_plugin",
CreateInstanceNormPluginDeserialize);
int InstanceNormPlugin::initialize() {
platform::dynload::cudnnCreate(&handle_);
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
return 0;
}
nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
int index, const nvinfer1::Dims *inputDims, int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
void **outputs, void *workspace,
cudaStream_t stream) {
const auto &input_dims = this->getInputDims(0);
PADDLE_ENFORCE_EQ(input_dims.nbDims, 3,
platform::errors::InvalidArgument(
"Input Dims should be 3 (except the batch), got %d",
input_dims.nbDims));
int n = batch_size;
int c = input_dims.d[0];
int h = input_dims.d[1];
int w = input_dims.d[2];
scale_t.Resize(framework::make_ddim({batch_size, c}));
bias_t.Resize(framework::make_ddim({batch_size, c}));
int device_id;
cudaGetDevice(&device_id);
float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
for (int i = 0; i < batch_size; i++) {
cudaMemcpyAsync(scale_d + i * c, scale_.data(), sizeof(float) * c,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(bias_d + i * c, bias_.data(), sizeof(float) * c,
cudaMemcpyHostToDevice, stream);
}
platform::dynload::cudnnSetTensor4dDescriptor(
b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
cudnnDataType_t cudnn_dtype;
nvinfer1::DataType data_type = getDataType();
convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
platform::dynload::cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
cudnn_dtype, 1, n * c, h, w);
platform::dynload::cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW,
cudnn_dtype, 1, n * c, h, w);
float alpha = 1;
float beta = 0;
platform::dynload::cudnnSetStream(handle_, stream);
void const *x_ptr = inputs[0];
void *y_ptr = outputs[0];
platform::dynload::cudnnBatchNormalizationForwardTraining(
handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, x_desc_,
x_ptr, y_desc_, y_ptr, b_desc_, scale_d, bias_d, 1., nullptr, nullptr,
eps_, nullptr, nullptr);
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class InstanceNormPlugin : public PluginTensorRT {
private:
float eps_;
std::vector<float> scale_;
std::vector<float> bias_;
framework::Tensor scale_t;
framework::Tensor bias_t;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(eps_) +
SerializedSize(scale_) + SerializedSize(bias_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
}
public:
explicit InstanceNormPlugin(const float eps, const std::vector<float> scale,
const std::vector<float> bias)
: eps_(eps), scale_(scale), bias_(bias) {
PADDLE_ENFORCE_EQ(scale.size(), bias.size(),
platform::errors::InvalidArgument(
"The instanceNorm's scale and bias should be the "
"same size. Got scale size = %d, but bias size = %d",
scale.size(), bias.size()));
}
// It was used for tensorrt deserialization.
// It should not be called by users.
InstanceNormPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &bias_);
}
~InstanceNormPlugin() {}
int initialize() override;
InstanceNormPlugin *clone() const override {
return new InstanceNormPlugin(eps_, scale_, bias_);
}
const char *getPluginType() const override { return "instance_norm_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
}
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -321,6 +321,10 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -321,6 +321,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL}) if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL})
inference_download_and_uncompress(${TEST_SPLIT_CONVERTER_MODEL} ${INFERENCE_URL}/tensorrt_test "split_converter.tgz") inference_download_and_uncompress(${TEST_SPLIT_CONVERTER_MODEL} ${INFERENCE_URL}/tensorrt_test "split_converter.tgz")
endif() endif()
set(TEST_INSTANCE_NORM_MODEL "${TRT_MODEL_INSTALL_DIR}/trt_instance_norm_test")
if (NOT EXISTS ${TEST_INSTANCE_NORM_MODEL})
inference_download_and_uncompress(${TEST_INSTANCE_NORM_MODEL} ${INFERENCE_URL}/tensorrt_test "instance_norm.tgz")
endif()
inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
...@@ -342,6 +346,9 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -342,6 +346,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_split_converter_test SRCS trt_split_converter_test.cc inference_analysis_test(trt_split_converter_test SRCS trt_split_converter_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_SPLIT_CONVERTER_MODEL}/) ARGS --infer_model=${TEST_SPLIT_CONVERTER_MODEL}/)
inference_analysis_test(trt_instance_norm_test SRCS trt_instance_norm_converter_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_INSTANCE_NORM_MODEL}/)
inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(TensorRT, instance_norm) {
std::string model_dir = FLAGS_infer_model + "/instance_norm";
AnalysisConfig config;
int batch_size = 4;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 20, batch_size, 0,
AnalysisConfig::Precision::kFloat32, false);
auto predictor = CreatePaddlePredictor(config);
int length = 4;
int input_num = batch_size * length;
float *input = new float[input_num];
memset(input, 1.0, input_num * sizeof(float));
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({batch_size, length});
input_t->copy_from_cpu(input);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册