Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
50bee83f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
50bee83f
编写于
1月 07, 2020
作者:
P
Pei Yang
提交者:
GitHub
1月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add TRT support for instance_norm op (#21928)
* add TRT support for instance_norm op
上级
3dbd4087
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
385 addition
and
27 deletion
+385
-27
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
+75
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+27
-24
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+2
-2
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
...luid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
+119
-0
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
...fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
+103
-0
paddle/fluid/inference/tests/api/CMakeLists.txt
paddle/fluid/inference/tests/api/CMakeLists.txt
+7
-0
paddle/fluid/inference/tests/api/trt_instance_norm_converter_test.cc
...d/inference/tests/api/trt_instance_norm_converter_test.cc
+50
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
50bee83f
...
@@ -938,6 +938,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
...
@@ -938,6 +938,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER
(
leaky_relu
);
USE_TRT_CONVERTER
(
leaky_relu
);
USE_TRT_CONVERTER
(
shuffle_channel
);
USE_TRT_CONVERTER
(
shuffle_channel
);
USE_TRT_CONVERTER
(
swish
);
USE_TRT_CONVERTER
(
swish
);
USE_TRT_CONVERTER
(
instance_norm
);
USE_TRT_CONVERTER
(
layer_norm
);
USE_TRT_CONVERTER
(
layer_norm
);
USE_TRT_CONVERTER
(
gelu
);
USE_TRT_CONVERTER
(
gelu
);
USE_TRT_CONVERTER
(
multihead_matmul
);
USE_TRT_CONVERTER
(
multihead_matmul
);
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
50bee83f
...
@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
...
@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc
shuffle_channel_op.cc swish_op.cc
instance_norm_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
...
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
0 → 100644
浏览文件 @
50bee83f
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
InstanceNormOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert fluid prelu op to tensorrt instance norm layer"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
float
eps
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"epsilon"
));
auto
*
scale_var
=
scope
.
FindVar
(
op_desc
.
Input
(
"Scale"
)[
0
]);
auto
*
bias_var
=
scope
.
FindVar
(
op_desc
.
Input
(
"Bias"
)[
0
]);
PADDLE_ENFORCE_NOT_NULL
(
scale_var
,
platform
::
errors
::
InvalidArgument
(
"Input [Scale] of instance_norm op converter should not be null"
));
PADDLE_ENFORCE_NOT_NULL
(
bias_var
,
platform
::
errors
::
InvalidArgument
(
"Input [Bias] of instance_norm op converter should not be null"
));
auto
*
scale_tensor
=
scale_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias_tensor
=
bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
scale_tensor
->
numel
(),
bias_tensor
->
numel
(),
platform
::
errors
::
InvalidArgument
(
"Num of input [Scale] and [Bias] of instance_norm op converter "
"should be equal. Got Scale num = %ld, but Bias num = %ld"
,
scale_tensor
->
numel
(),
bias_tensor
->
numel
()));
auto
*
scale_d
=
scale_tensor
->
data
<
float
>
();
auto
*
bias_d
=
bias_tensor
->
data
<
float
>
();
std
::
vector
<
float
>
scale_v
;
std
::
vector
<
float
>
bias_v
;
for
(
int
i
=
0
;
i
<
scale_tensor
->
numel
();
i
++
)
{
scale_v
.
push_back
(
scale_d
[
i
]);
bias_v
.
push_back
(
bias_d
[
i
]);
}
plugin
::
InstanceNormPlugin
*
plugin
=
new
plugin
::
InstanceNormPlugin
(
eps
,
scale_v
,
bias_v
);
plugin
->
getPluginType
();
nvinfer1
::
IPluginLayer
*
layer
=
engine_
->
AddPlugin
(
&
input
,
1
,
plugin
);
auto
output_name
=
op_desc
.
Output
(
"Y"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"instance_norm"
,
{
output_name
},
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
instance_norm
,
InstanceNormOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
50bee83f
...
@@ -32,30 +32,33 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -32,30 +32,33 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
private:
private:
std
::
unordered_set
<
std
::
string
>
teller_set
{{
"mul"
,
std
::
unordered_set
<
std
::
string
>
teller_set
{{
"conv2d"
,
"mul"
,
"pool2d"
,
"conv2d"
,
"relu"
,
"pool2d"
,
"softmax"
,
"relu"
,
"sigmoid"
,
"softmax"
,
"depthwise_conv2d"
,
"sigmoid"
,
"batch_norm"
,
"depthwise_conv2d"
,
"concat"
,
"batch_norm"
,
"tanh"
,
"concat"
,
"pad"
,
"tanh"
,
"elementwise_add"
,
"pad"
,
"elementwise_mul"
,
"elementwise_add"
,
"dropout"
,
"elementwise_mul"
,
"prelu"
,
"dropout"
,
"conv2d_transpose"
,
"prelu"
,
"leaky_relu"
,
"conv2d_transpose"
,
"fc"
,
"leaky_relu"
,
"shuffle_channel"
,
"fc"
,
"swish"
,
"shuffle_channel"
,
"split"
,
"swish"
,
"gelu"
,
"split"
,
"layer_norm"
,
"instance_norm"
,
"multihead_matmul"
}};
"gelu"
,
"layer_norm"
,
"multihead_matmul"
,
}};
};
};
bool
OpTeller
::
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
{
bool
OpTeller
::
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
{
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
50bee83f
nv_library
(
tensorrt_plugin
nv_library
(
tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu
DEPS enforce tensorrt_engine prelu
)
DEPS enforce tensorrt_engine prelu
tensor
)
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
0 → 100644
浏览文件 @
50bee83f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
cudnnStatus_t
convert_trt2cudnn_dtype
(
nvinfer1
::
DataType
trt_dtype
,
cudnnDataType_t
*
cudnn_dtype
)
{
switch
(
trt_dtype
)
{
case
nvinfer1
::
DataType
::
kFLOAT
:
*
cudnn_dtype
=
CUDNN_DATA_FLOAT
;
break
;
case
nvinfer1
::
DataType
::
kHALF
:
*
cudnn_dtype
=
CUDNN_DATA_HALF
;
break
;
default:
return
CUDNN_STATUS_BAD_PARAM
;
}
return
CUDNN_STATUS_SUCCESS
;
}
InstanceNormPlugin
*
CreateInstanceNormPluginDeserialize
(
const
void
*
buffer
,
size_t
length
)
{
return
new
InstanceNormPlugin
(
buffer
,
length
);
}
REGISTER_TRT_PLUGIN
(
"instance_norm_plugin"
,
CreateInstanceNormPluginDeserialize
);
int
InstanceNormPlugin
::
initialize
()
{
platform
::
dynload
::
cudnnCreate
(
&
handle_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
x_desc_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
y_desc_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
b_desc_
);
return
0
;
}
nvinfer1
::
Dims
InstanceNormPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
{
assert
(
nbInputs
==
1
);
assert
(
index
<
this
->
getNbOutputs
());
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
return
output_dims
;
}
int
InstanceNormPlugin
::
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
const
auto
&
input_dims
=
this
->
getInputDims
(
0
);
PADDLE_ENFORCE_EQ
(
input_dims
.
nbDims
,
3
,
platform
::
errors
::
InvalidArgument
(
"Input Dims should be 3 (except the batch), got %d"
,
input_dims
.
nbDims
));
int
n
=
batch_size
;
int
c
=
input_dims
.
d
[
0
];
int
h
=
input_dims
.
d
[
1
];
int
w
=
input_dims
.
d
[
2
];
scale_t
.
Resize
(
framework
::
make_ddim
({
batch_size
,
c
}));
bias_t
.
Resize
(
framework
::
make_ddim
({
batch_size
,
c
}));
int
device_id
;
cudaGetDevice
(
&
device_id
);
float
*
scale_d
=
scale_t
.
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
device_id
));
float
*
bias_d
=
bias_t
.
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
device_id
));
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
cudaMemcpyAsync
(
scale_d
+
i
*
c
,
scale_
.
data
(),
sizeof
(
float
)
*
c
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
bias_d
+
i
*
c
,
bias_
.
data
(),
sizeof
(
float
)
*
c
,
cudaMemcpyHostToDevice
,
stream
);
}
platform
::
dynload
::
cudnnSetTensor4dDescriptor
(
b_desc_
,
CUDNN_TENSOR_NCHW
,
CUDNN_DATA_FLOAT
,
1
,
n
*
c
,
1
,
1
);
cudnnDataType_t
cudnn_dtype
;
nvinfer1
::
DataType
data_type
=
getDataType
();
convert_trt2cudnn_dtype
(
data_type
,
&
cudnn_dtype
);
platform
::
dynload
::
cudnnSetTensor4dDescriptor
(
x_desc_
,
CUDNN_TENSOR_NCHW
,
cudnn_dtype
,
1
,
n
*
c
,
h
,
w
);
platform
::
dynload
::
cudnnSetTensor4dDescriptor
(
y_desc_
,
CUDNN_TENSOR_NCHW
,
cudnn_dtype
,
1
,
n
*
c
,
h
,
w
);
float
alpha
=
1
;
float
beta
=
0
;
platform
::
dynload
::
cudnnSetStream
(
handle_
,
stream
);
void
const
*
x_ptr
=
inputs
[
0
];
void
*
y_ptr
=
outputs
[
0
];
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
handle_
,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
,
&
alpha
,
&
beta
,
x_desc_
,
x_ptr
,
y_desc_
,
y_ptr
,
b_desc_
,
scale_d
,
bias_d
,
1.
,
nullptr
,
nullptr
,
eps_
,
nullptr
,
nullptr
);
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
0 → 100644
浏览文件 @
50bee83f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
InstanceNormPlugin
:
public
PluginTensorRT
{
private:
float
eps_
;
std
::
vector
<
float
>
scale_
;
std
::
vector
<
float
>
bias_
;
framework
::
Tensor
scale_t
;
framework
::
Tensor
bias_t
;
cudnnHandle_t
handle_
;
cudnnTensorDescriptor_t
x_desc_
,
y_desc_
,
b_desc_
;
protected:
size_t
getSerializationSize
()
override
{
return
getBaseSerializationSize
()
+
SerializedSize
(
eps_
)
+
SerializedSize
(
scale_
)
+
SerializedSize
(
bias_
);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void
serialize
(
void
*
buffer
)
override
{
SerializeValue
(
&
buffer
,
getPluginType
());
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
scale_
);
SerializeValue
(
&
buffer
,
bias_
);
}
public:
explicit
InstanceNormPlugin
(
const
float
eps
,
const
std
::
vector
<
float
>
scale
,
const
std
::
vector
<
float
>
bias
)
:
eps_
(
eps
),
scale_
(
scale
),
bias_
(
bias
)
{
PADDLE_ENFORCE_EQ
(
scale
.
size
(),
bias
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The instanceNorm's scale and bias should be the "
"same size. Got scale size = %d, but bias size = %d"
,
scale
.
size
(),
bias
.
size
()));
}
// It was used for tensorrt deserialization.
// It should not be called by users.
InstanceNormPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
eps_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
scale_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
bias_
);
}
~
InstanceNormPlugin
()
{}
int
initialize
()
override
;
InstanceNormPlugin
*
clone
()
const
override
{
return
new
InstanceNormPlugin
(
eps_
,
scale_
,
bias_
);
}
const
char
*
getPluginType
()
const
override
{
return
"instance_norm_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
bool
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
override
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kNCHW
));
}
};
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tests/api/CMakeLists.txt
浏览文件 @
50bee83f
...
@@ -321,6 +321,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
...
@@ -321,6 +321,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
if
(
NOT EXISTS
${
TEST_SPLIT_CONVERTER_MODEL
}
)
if
(
NOT EXISTS
${
TEST_SPLIT_CONVERTER_MODEL
}
)
inference_download_and_uncompress
(
${
TEST_SPLIT_CONVERTER_MODEL
}
${
INFERENCE_URL
}
/tensorrt_test
"split_converter.tgz"
)
inference_download_and_uncompress
(
${
TEST_SPLIT_CONVERTER_MODEL
}
${
INFERENCE_URL
}
/tensorrt_test
"split_converter.tgz"
)
endif
()
endif
()
set
(
TEST_INSTANCE_NORM_MODEL
"
${
TRT_MODEL_INSTALL_DIR
}
/trt_instance_norm_test"
)
if
(
NOT EXISTS
${
TEST_INSTANCE_NORM_MODEL
}
)
inference_download_and_uncompress
(
${
TEST_INSTANCE_NORM_MODEL
}
${
INFERENCE_URL
}
/tensorrt_test
"instance_norm.tgz"
)
endif
()
inference_analysis_test
(
trt_mobilenet_test SRCS trt_mobilenet_test.cc
inference_analysis_test
(
trt_mobilenet_test SRCS trt_mobilenet_test.cc
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
ARGS --infer_model=
${
TRT_MODEL_INSTALL_DIR
}
/trt_inference_test_models
)
ARGS --infer_model=
${
TRT_MODEL_INSTALL_DIR
}
/trt_inference_test_models
)
...
@@ -342,6 +346,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
...
@@ -342,6 +346,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test
(
trt_split_converter_test SRCS trt_split_converter_test.cc
inference_analysis_test
(
trt_split_converter_test SRCS trt_split_converter_test.cc
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
ARGS --infer_model=
${
TEST_SPLIT_CONVERTER_MODEL
}
/
)
ARGS --infer_model=
${
TEST_SPLIT_CONVERTER_MODEL
}
/
)
inference_analysis_test
(
trt_instance_norm_test SRCS trt_instance_norm_converter_test.cc
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
ARGS --infer_model=
${
TEST_INSTANCE_NORM_MODEL
}
/
)
inference_analysis_test
(
test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
inference_analysis_test
(
test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
paddle_fluid_c
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
paddle_fluid_c
ARGS --infer_model=
${
TRT_MODEL_INSTALL_DIR
}
/trt_inference_test_models
)
ARGS --infer_model=
${
TRT_MODEL_INSTALL_DIR
}
/trt_inference_test_models
)
...
...
paddle/fluid/inference/tests/api/trt_instance_norm_converter_test.cc
0 → 100644
浏览文件 @
50bee83f
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace
paddle
{
namespace
inference
{
TEST
(
TensorRT
,
instance_norm
)
{
std
::
string
model_dir
=
FLAGS_infer_model
+
"/instance_norm"
;
AnalysisConfig
config
;
int
batch_size
=
4
;
config
.
EnableUseGpu
(
100
,
0
);
config
.
SetModel
(
model_dir
);
config
.
SwitchUseFeedFetchOps
(
false
);
config
.
EnableTensorRtEngine
(
1
<<
20
,
batch_size
,
0
,
AnalysisConfig
::
Precision
::
kFloat32
,
false
);
auto
predictor
=
CreatePaddlePredictor
(
config
);
int
length
=
4
;
int
input_num
=
batch_size
*
length
;
float
*
input
=
new
float
[
input_num
];
memset
(
input
,
1.0
,
input_num
*
sizeof
(
float
));
auto
input_names
=
predictor
->
GetInputNames
();
auto
input_t
=
predictor
->
GetInputTensor
(
input_names
[
0
]);
input_t
->
Reshape
({
batch_size
,
length
});
input_t
->
copy_from_cpu
(
input
);
ASSERT_TRUE
(
predictor
->
ZeroCopyRun
());
}
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录