未验证 提交 50bee83f 编写于 作者: P Pei Yang 提交者: GitHub

add TRT support for instance_norm op (#21928)

* add TRT support for instance_norm op
上级 3dbd4087
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 improve_sccache incubate/infrt inplace_addto make_flag_adding_easier move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/1.7 release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0 v1.7.2 v1.7.1 v1.7.0
无相关合并请求
......@@ -938,6 +938,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu);
USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(swish);
USE_TRT_CONVERTER(instance_norm);
USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
......
......@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class InstanceNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid prelu op to tensorrt instance norm layer";
framework::OpDesc op_desc(op, nullptr);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
auto* bias_var = scope.FindVar(op_desc.Input("Bias")[0]);
PADDLE_ENFORCE_NOT_NULL(
scale_var,
platform::errors::InvalidArgument(
"Input [Scale] of instance_norm op converter should not be null"));
PADDLE_ENFORCE_NOT_NULL(
bias_var,
platform::errors::InvalidArgument(
"Input [Bias] of instance_norm op converter should not be null"));
auto* scale_tensor = scale_var->GetMutable<framework::LoDTensor>();
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
scale_tensor->numel(), bias_tensor->numel(),
platform::errors::InvalidArgument(
"Num of input [Scale] and [Bias] of instance_norm op converter "
"should be equal. Got Scale num = %ld, but Bias num = %ld",
scale_tensor->numel(), bias_tensor->numel()));
auto* scale_d = scale_tensor->data<float>();
auto* bias_d = bias_tensor->data<float>();
std::vector<float> scale_v;
std::vector<float> bias_v;
for (int i = 0; i < scale_tensor->numel(); i++) {
scale_v.push_back(scale_d[i]);
bias_v.push_back(bias_d[i]);
}
plugin::InstanceNormPlugin* plugin =
new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
plugin->getPluginType();
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(instance_norm, InstanceNormOpConverter);
......@@ -32,30 +32,33 @@ struct SimpleOpTypeSetTeller : public Teller {
}
private:
std::unordered_set<std::string> teller_set{{"mul",
"conv2d",
"pool2d",
"relu",
"softmax",
"sigmoid",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_mul",
"dropout",
"prelu",
"conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"split",
"gelu",
"layer_norm",
"multihead_matmul"}};
std::unordered_set<std::string> teller_set{{
"mul",
"conv2d",
"pool2d",
"relu",
"softmax",
"sigmoid",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_mul",
"dropout",
"prelu",
"conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"split",
"instance_norm",
"gelu",
"layer_norm",
"multihead_matmul",
}};
};
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
DEPS enforce tensorrt_engine prelu)
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
cudnnDataType_t *cudnn_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
*cudnn_dtype = CUDNN_DATA_FLOAT;
break;
case nvinfer1::DataType::kHALF:
*cudnn_dtype = CUDNN_DATA_HALF;
break;
default:
return CUDNN_STATUS_BAD_PARAM;
}
return CUDNN_STATUS_SUCCESS;
}
InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
size_t length) {
return new InstanceNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("instance_norm_plugin",
CreateInstanceNormPluginDeserialize);
int InstanceNormPlugin::initialize() {
platform::dynload::cudnnCreate(&handle_);
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
return 0;
}
nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
int index, const nvinfer1::Dims *inputDims, int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
void **outputs, void *workspace,
cudaStream_t stream) {
const auto &input_dims = this->getInputDims(0);
PADDLE_ENFORCE_EQ(input_dims.nbDims, 3,
platform::errors::InvalidArgument(
"Input Dims should be 3 (except the batch), got %d",
input_dims.nbDims));
int n = batch_size;
int c = input_dims.d[0];
int h = input_dims.d[1];
int w = input_dims.d[2];
scale_t.Resize(framework::make_ddim({batch_size, c}));
bias_t.Resize(framework::make_ddim({batch_size, c}));
int device_id;
cudaGetDevice(&device_id);
float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
for (int i = 0; i < batch_size; i++) {
cudaMemcpyAsync(scale_d + i * c, scale_.data(), sizeof(float) * c,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(bias_d + i * c, bias_.data(), sizeof(float) * c,
cudaMemcpyHostToDevice, stream);
}
platform::dynload::cudnnSetTensor4dDescriptor(
b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
cudnnDataType_t cudnn_dtype;
nvinfer1::DataType data_type = getDataType();
convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
platform::dynload::cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
cudnn_dtype, 1, n * c, h, w);
platform::dynload::cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW,
cudnn_dtype, 1, n * c, h, w);
float alpha = 1;
float beta = 0;
platform::dynload::cudnnSetStream(handle_, stream);
void const *x_ptr = inputs[0];
void *y_ptr = outputs[0];
platform::dynload::cudnnBatchNormalizationForwardTraining(
handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, x_desc_,
x_ptr, y_desc_, y_ptr, b_desc_, scale_d, bias_d, 1., nullptr, nullptr,
eps_, nullptr, nullptr);
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class InstanceNormPlugin : public PluginTensorRT {
private:
float eps_;
std::vector<float> scale_;
std::vector<float> bias_;
framework::Tensor scale_t;
framework::Tensor bias_t;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(eps_) +
SerializedSize(scale_) + SerializedSize(bias_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
}
public:
explicit InstanceNormPlugin(const float eps, const std::vector<float> scale,
const std::vector<float> bias)
: eps_(eps), scale_(scale), bias_(bias) {
PADDLE_ENFORCE_EQ(scale.size(), bias.size(),
platform::errors::InvalidArgument(
"The instanceNorm's scale and bias should be the "
"same size. Got scale size = %d, but bias size = %d",
scale.size(), bias.size()));
}
// It was used for tensorrt deserialization.
// It should not be called by users.
InstanceNormPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &bias_);
}
~InstanceNormPlugin() {}
int initialize() override;
InstanceNormPlugin *clone() const override {
return new InstanceNormPlugin(eps_, scale_, bias_);
}
const char *getPluginType() const override { return "instance_norm_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
}
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -321,6 +321,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL})
inference_download_and_uncompress(${TEST_SPLIT_CONVERTER_MODEL} ${INFERENCE_URL}/tensorrt_test "split_converter.tgz")
endif()
set(TEST_INSTANCE_NORM_MODEL "${TRT_MODEL_INSTALL_DIR}/trt_instance_norm_test")
if (NOT EXISTS ${TEST_INSTANCE_NORM_MODEL})
inference_download_and_uncompress(${TEST_INSTANCE_NORM_MODEL} ${INFERENCE_URL}/tensorrt_test "instance_norm.tgz")
endif()
inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
......@@ -342,6 +346,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_split_converter_test SRCS trt_split_converter_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_SPLIT_CONVERTER_MODEL}/)
inference_analysis_test(trt_instance_norm_test SRCS trt_instance_norm_converter_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_INSTANCE_NORM_MODEL}/)
inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(TensorRT, instance_norm) {
std::string model_dir = FLAGS_infer_model + "/instance_norm";
AnalysisConfig config;
int batch_size = 4;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 20, batch_size, 0,
AnalysisConfig::Precision::kFloat32, false);
auto predictor = CreatePaddlePredictor(config);
int length = 4;
int input_num = batch_size * length;
float *input = new float[input_num];
memset(input, 1.0, input_num * sizeof(float));
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({batch_size, length});
input_t->copy_from_cpu(input);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部