未验证 提交 45d1ae21 编写于 作者: S Shang Zhizhou 提交者: GitHub

add dynamic layer_norm plugin (#33293)

* add dynamic layer_norm plugin

* fix bug

* fix numpy.allclose

* fix format

* fix code style

* remove shepe in dynamic shape

* code format

* remove layer norm fp16

* fix format
上级 260f92da
......@@ -46,13 +46,6 @@ class LayerNormOpConverter : public OpConverter {
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
int input_num = 1;
for (int i = 0; i < X->getDimensions().nbDims; i++) {
input_num *= X->getDimensions().d[i];
}
std::vector<int64_t> mean_shape{input_num};
std::vector<int64_t> variance_shape{input_num};
std::unique_ptr<framework::LoDTensor> bias_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> scale_tensor(
......@@ -68,10 +61,33 @@ class LayerNormOpConverter : public OpConverter {
auto* bias_data = bias_tensor->mutable_data<float>(platform::CPUPlace());
auto* scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());
plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
begin_norm_axis, eps, mean_shape, variance_shape);
nvinfer1::IPluginLayer* layernorm_layer = engine_->AddPlugin(&X, 1, plugin);
nvinfer1::ILayer* layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
int input_num = 1;
for (int i = begin_norm_axis; i < X->getDimensions().nbDims; i++) {
input_num *= X->getDimensions().d[i];
}
std::vector<int64_t> mean_shape{input_num};
std::vector<int64_t> variance_shape{input_num};
plugin::LayerNormPluginDynamic* plugin =
new plugin::LayerNormPluginDynamic(bias_data, bias_tensor->numel(),
scale_data, scale_tensor->numel(),
begin_norm_axis, eps, mean_shape,
variance_shape);
layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
} else {
int input_num = 1;
for (int i = begin_norm_axis - 1; i < X->getDimensions().nbDims; i++) {
input_num *= X->getDimensions().d[i];
}
std::vector<int64_t> mean_shape{input_num};
std::vector<int64_t> variance_shape{input_num};
plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
begin_norm_axis, eps, mean_shape, variance_shape);
layernorm_layer = engine_->AddPlugin(
&X, 1, reinterpret_cast<plugin::PluginTensorRT*>(plugin));
}
auto output_name = op_desc.Output("Y").front();
engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor));
......
......@@ -703,7 +703,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
// Paddle-TRT does not support the input tensors: Shape and ShapeTensor
} else if (desc.Input("Shape").size() >= 1 ||
desc.Input("ShapeTensor").size() >= 1) {
desc.Input("ShapeTensor").size() >= 1 || with_dynamic_shape) {
return false;
} else {
std::vector<int> shape =
......
......@@ -57,8 +57,18 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
input_shape.push_back(input_dims.d[i]);
}
const auto input_ddim = framework::make_ddim(input_shape);
auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis - 1);
auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis);
int feature_size = static_cast<int>(matrix_dim[1]);
PADDLE_ENFORCE_EQ(feature_size, scale_.size(),
platform::errors::InvalidArgument(
"scale's size should be equal to the feature_size,"
"but got feature_size:%d, scale's size:%d.",
feature_size, scale_.size()));
PADDLE_ENFORCE_EQ(feature_size, bias_.size(),
platform::errors::InvalidArgument(
"bias's size should be equal to the feature_size,"
"but got feature_size:%d, bias's size:%d.",
feature_size, bias_.size()));
scale_t.Resize(framework::make_ddim({feature_size}));
bias_t.Resize(framework::make_ddim({feature_size}));
......@@ -82,6 +92,163 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
return cudaGetLastError() != cudaSuccess;
}
nvinfer1::DimsExprs LayerNormPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputDims, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
return inputDims[0];
}
bool LayerNormPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of layernorm plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
// TODO(Shangzhizhou) FP16 support
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0,
platform::errors::InvalidArgument(
"The LayerNormPlugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
int LayerNormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
const auto &input_dims = input_desc[0].dims;
int begin_norm_axis = begin_norm_axis_;
float eps = eps_;
std::vector<int> input_shape;
for (int i = 0; i < input_dims.nbDims; i++) {
input_shape.push_back(input_dims.d[i]);
}
const auto input_ddim = framework::make_ddim(input_shape);
auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis);
int feature_size = static_cast<int>(matrix_dim[1]);
PADDLE_ENFORCE_EQ(feature_size, scale_.size(),
platform::errors::InvalidArgument(
"scale's size should be equal to the feature_size,"
"but got feature_size:%d, scale's size:%d.",
feature_size, scale_.size()));
PADDLE_ENFORCE_EQ(feature_size, bias_.size(),
platform::errors::InvalidArgument(
"bias's size should be equal to the feature_size,"
"but got feature_size:%d, bias's size:%d.",
feature_size, bias_.size()));
int device_id;
cudaGetDevice(&device_id);
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp32";
const float *input = reinterpret_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
scale_t.Resize(framework::make_ddim({feature_size}));
bias_t.Resize(framework::make_ddim({feature_size}));
mean_t.Resize(framework::make_ddim(mean_shape_));
variance_t.Resize(framework::make_ddim(variance_shape_));
float *scale_d =
scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *variance_d =
variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp16";
const half *input = reinterpret_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]);
size_t mean_shape_product = 1;
for (auto s : mean_shape_) {
mean_shape_product *= s;
}
size_t variance_shape_product = 1;
for (auto s : variance_shape_) {
variance_shape_product *= s;
}
if (!scale_gpu_half_d_) {
cudaMalloc(&scale_gpu_half_d_, feature_size * sizeof(half));
}
if (!bias_gpu_half_d_) {
cudaMalloc(&bias_gpu_half_d_, feature_size * sizeof(half));
}
if (!mean_gpu_half_d_) {
cudaMalloc(&mean_gpu_half_d_, mean_shape_product * sizeof(half));
}
if (!variance_gpu_half_d_) {
cudaMalloc(&variance_gpu_half_d_, variance_shape_product * sizeof(half));
}
half *scale_cpu_half =
static_cast<half *>(malloc(feature_size * sizeof(half)));
half *bias_cpu_half =
static_cast<half *>(malloc(feature_size * sizeof(half)));
PADDLE_ENFORCE_EQ(
scale_cpu_half && bias_cpu_half, true,
platform::errors::Unavailable("Out of memory, malloc size %d.",
feature_size * sizeof(half)));
for (int i = 0; i < feature_size; i++) {
scale_cpu_half[i] = static_cast<half>(scale_[i]);
bias_cpu_half[i] = static_cast<half>(bias_[i]);
}
cudaMemcpyAsync(scale_gpu_half_d_, scale_cpu_half,
sizeof(half) * feature_size, cudaMemcpyHostToDevice,
stream);
cudaMemcpyAsync(bias_gpu_half_d_, bias_cpu_half,
sizeof(half) * feature_size, cudaMemcpyHostToDevice,
stream);
free(scale_cpu_half);
free(bias_cpu_half);
paddle::operators::LayerNormDirectCUDAFunctor<half> layer_norm;
layer_norm(stream, input, input_shape, bias_gpu_half_d_, scale_gpu_half_d_,
output, mean_gpu_half_d_, variance_gpu_half_d_, begin_norm_axis,
eps);
#else
PADDLE_THROW(platform::errors::Fatal(
"The layer_norm tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The LayerNorm TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -50,7 +50,7 @@ class LayerNormPlugin : public PluginTensorRT {
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, bias_);
......@@ -62,7 +62,7 @@ class LayerNormPlugin : public PluginTensorRT {
}
public:
LayerNormPlugin(const float *bias, const int bias_num, const float *scale,
LayerNormPlugin(const float* bias, const int bias_num, const float* scale,
const int scale_num, int begin_norm_axis, float eps,
std::vector<int64_t> mean_shape,
std::vector<int64_t> variance_shape)
......@@ -78,7 +78,7 @@ class LayerNormPlugin : public PluginTensorRT {
// It was used for tensorrt deserialization.
// It should not be called by users.
LayerNormPlugin(void const *serialData, size_t serialLength) {
LayerNormPlugin(void const* serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &scale_);
......@@ -90,20 +90,180 @@ class LayerNormPlugin : public PluginTensorRT {
~LayerNormPlugin() {}
int initialize() override;
LayerNormPlugin *clone() const override {
LayerNormPlugin* clone() const override {
return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(),
scale_.size(), begin_norm_axis_, eps_,
mean_shape_, variance_shape_);
}
const char *getPluginType() const override { return "layer_norm_plugin"; }
const char* getPluginType() const override { return "layer_norm_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
};
class LayerNormPluginDynamic : public DynamicPluginTensorRT {
public:
LayerNormPluginDynamic(const float* bias, const int bias_num,
const float* scale, const int scale_num,
int begin_norm_axis, float eps,
std::vector<int64_t> mean_shape,
std::vector<int64_t> variance_shape)
: begin_norm_axis_(begin_norm_axis),
eps_(eps),
mean_shape_(mean_shape),
variance_shape_(variance_shape),
scale_gpu_half_d_(nullptr),
bias_gpu_half_d_(nullptr),
mean_gpu_half_d_(nullptr),
variance_gpu_half_d_(nullptr) {
bias_.resize(bias_num);
scale_.resize(scale_num);
std::copy(bias, bias + bias_num, bias_.data());
std::copy(scale, scale + scale_num, scale_.data());
}
LayerNormPluginDynamic(void const* serialData, size_t serialLength)
: scale_gpu_half_d_(nullptr),
bias_gpu_half_d_(nullptr),
mean_gpu_half_d_(nullptr),
variance_gpu_half_d_(nullptr) {
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &mean_shape_);
DeserializeValue(&serialData, &serialLength, &variance_shape_);
}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new LayerNormPluginDynamic(bias_.data(), bias_.size(), scale_.data(),
scale_.size(), begin_norm_axis_, eps_,
mean_shape_, variance_shape_);
}
const char* getPluginType() const override { return "layernorm_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
size_t getSerializationSize() const override {
return SerializedSize(bias_) + SerializedSize(scale_) +
SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
SerializedSize(mean_shape_) + SerializedSize(variance_shape_);
}
void serialize(void* buffer) const override {
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, begin_norm_axis_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, mean_shape_);
SerializeValue(&buffer, variance_shape_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
~LayerNormPluginDynamic() {
if (scale_gpu_half_d_) {
cudaFree(scale_gpu_half_d_);
}
if (bias_gpu_half_d_) {
cudaFree(bias_gpu_half_d_);
}
if (mean_gpu_half_d_) {
cudaFree(mean_gpu_half_d_);
}
if (variance_gpu_half_d_) {
cudaFree(variance_gpu_half_d_);
}
}
void destroy() override { delete this; }
private:
std::vector<float> bias_;
std::vector<float> scale_;
framework::Tensor scale_t;
framework::Tensor bias_t;
framework::Tensor mean_t;
framework::Tensor variance_t;
int begin_norm_axis_;
float eps_;
std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_;
half* scale_gpu_half_d_;
half* bias_gpu_half_d_;
half* mean_gpu_half_d_;
half* variance_gpu_half_d_;
};
class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
LayerNormPluginDynamicCreator() {}
const char* getPluginName() const override { return "layernorm_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new LayerNormPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(LayerNormPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -209,6 +209,73 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
}
}
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForwardFP16(const T *x, const U *scale, const U *bias,
T *y, U *mean, U *var, float epsilon,
int feature_size) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ U mean_share;
__shared__ U var_share;
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
// Step 1: Reduce to calculate mean and var
U mean_val = 0;
U var_val = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
U tmp = static_cast<U>(x[i]);
mean_val += tmp;
var_val += (tmp * tmp);
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<U>(mean_val, var_val),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
auto tmp = pair.first_ / static_cast<U>(feature_size);
mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
var[blockIdx.x] = var_share =
static_cast<U>(pair.second_ / static_cast<U>(feature_size) - tmp * tmp);
}
__syncthreads();
mean_val = mean_share;
U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
// Step 2: Calculate y
if (scale != nullptr) {
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(
scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
invvar);
}
}
} else { // scale == nullptr
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
}
}
}
#endif
}
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
......@@ -872,6 +939,28 @@ void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
}
}
template <>
void LayerNormDirectCUDAFunctor<half>::operator()(
gpuStream_t stream, const half *input, std::vector<int> input_shape,
const half *bias, const half *scale, half *output, half *mean,
half *variance, int begin_norm_axis, float eps) {
const auto x_dims = framework::make_ddim(input_shape);
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int batch_size = static_cast<int>(matrix_dim[0]);
int feature_size = static_cast<int>(matrix_dim[1]);
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForwardFP16<half, half,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end in layer_norm must be larger "
"than 1"));
break;
}
}
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -961,6 +1050,9 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
};
template class LayerNormDirectCUDAFunctor<float>;
#ifdef TRT_PLUGIN_FP16_AVALIABLE
template class LayerNormDirectCUDAFunctor<half>;
#endif
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
......
......@@ -511,6 +511,7 @@ void BindAnalysisConfig(py::module *m) {
py::arg("disable_trt_plugin_fp16") = false)
.def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
.def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
.def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
.def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA,
py::arg("dla_core") = 0)
.def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
......
......@@ -160,7 +160,8 @@ class InferencePassTest(unittest.TestCase):
use_gpu,
atol=1e-5,
flatten=False,
quant=False):
quant=False,
rtol=1e-5):
'''
Check whether calculating on CPU and GPU, enable TensorRT
or disable TensorRT, enable MKLDNN or disable MKLDNN
......@@ -260,7 +261,7 @@ class InferencePassTest(unittest.TestCase):
self.assertTrue(
np.allclose(
out, tensorrt_output, atol=atol),
out, tensorrt_output, rtol=rtol, atol=atol),
"Output has diff between GPU and TensorRT. ")
# Check whether the mkldnn results and the CPU results are the same.
......
......@@ -366,6 +366,61 @@ class TensorRTSubgraphPassLayerNormTest(InferencePassTest):
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassLayerNormDynamicTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
out = fluid.layers.layer_norm(
data, begin_norm_axis=self.begin_norm_axis)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.set_trt_params()
self.fetch_list = [out]
def set_trt_params(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassLayerNormDynamicTest.TensorRTParam(
1 << 30, 32, 0, self.precision, self.serialize, False)
self.dynamic_shape_params = TensorRTSubgraphPassLayerNormDynamicTest.DynamicShapeParam(
{
'data': [1, 3, 64, 64],
}, {'data': [8, 8, 64, 64], }, {'data': [4, 4, 64, 64], }, False)
def set_params(self):
self.begin_norm_axis = 2
self.precision = AnalysisConfig.Precision.Float32
self.serialize = True
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassLayerNormDynamicFP16Test(
TensorRTSubgraphPassLayerNormDynamicTest):
def set_params(self):
self.begin_norm_axis = 2
self.precision = AnalysisConfig.Precision.Half
self.serialize = True
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, atol=0.01, rtol=0.01)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassLayerNormBeginNormAxis2Test(
TensorRTSubgraphPassLayerNormTest):
def set_params(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册