Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
296b64ac
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
296b64ac
编写于
2月 12, 2023
作者:
W
wangruting
浏览文件
操作
浏览文件
下载
差异文件
fix_conflict
上级
6d73091e
648cb508
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
601 addition
and
54 deletion
+601
-54
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+34
-14
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+22
-7
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+37
-5
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+1
-4
paddle/phi/kernels/selected_rows/full_kernel.cc
paddle/phi/kernels/selected_rows/full_kernel.cc
+14
-0
paddle/phi/kernels/xpu/full_kernel.cc
paddle/phi/kernels/xpu/full_kernel.cc
+22
-16
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
.../paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_inference_predictor.py
...ts/unittests/ir/inference/test_trt_inference_predictor.py
+399
-0
python/paddle/incubate/autograd/composite_rules.py
python/paddle/incubate/autograd/composite_rules.py
+2
-2
python/paddle/incubate/autograd/generate_op_map.py
python/paddle/incubate/autograd/generate_op_map.py
+1
-1
python/paddle/incubate/autograd/primx.py
python/paddle/incubate/autograd/primx.py
+25
-2
python/paddle/incubate/autograd/utils.py
python/paddle/incubate/autograd/utils.py
+43
-3
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
296b64ac
...
...
@@ -62,6 +62,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/utils/string/split.h"
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
...
...
@@ -1890,16 +1891,16 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
void
AnalysisPredictor
::
CollectShapeRangeInfo
()
{
// if use gpu, sync first.
if
(
config_
.
use_gpu
())
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
auto
gpu_place
=
place_
;
auto
*
dev_ctx
=
static_cast
<
const
phi
::
GPUContext
*>
(
pool
.
Get
(
gpu_place
));
if
(
config_
.
use_gpu
())
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto
*
dev_ctx
=
pool
.
Get
(
place_
);
auto
stream
=
static_cast
<
phi
::
GPUContext
*>
(
dev_ctx
)
->
stream
();
#ifdef PADDLE_WITH_HIP
hipStreamSynchronize
(
dev_ctx
->
stream
()
);
hipStreamSynchronize
(
stream
);
#else
cudaStreamSynchronize
(
dev_ctx
->
stream
()
);
cudaStreamSynchronize
(
stream
);
#endif
#endif
}
...
...
@@ -1911,6 +1912,7 @@ void AnalysisPredictor::CollectShapeRangeInfo() {
continue
;
}
auto
tensor
=
var
->
Get
<
phi
::
DenseTensor
>
();
if
(
!
tensor
.
initialized
())
continue
;
framework
::
DDim
dim
=
tensor
.
dims
();
std
::
vector
<
int32_t
>
shape
(
dim
.
size
());
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
shape
[
i
]
=
dim
[
i
];
...
...
@@ -1922,22 +1924,40 @@ void AnalysisPredictor::CollectShapeRangeInfo() {
// This is a simple method to identify all shape tensors with some
// mistakes, but it doesn't matter.
auto
is_shape_tensor
=
tensor
.
numel
()
<=
7
&&
tensor
.
numel
()
>=
1
;
if
(
tensor
.
dtype
()
==
paddle
::
experimental
::
DataType
::
INT32
&&
if
((
tensor
.
dtype
()
==
phi
::
DataType
::
INT32
||
tensor
.
dtype
()
==
phi
::
DataType
::
INT64
)
&&
is_shape_tensor
)
{
std
::
vector
<
int
>
int32_host
(
tensor
.
numel
());
if
(
tensor
.
place
()
==
platform
::
CPUPlace
())
{
if
(
platform
::
is_cpu_place
(
tensor
.
place
()))
{
auto
&
int32_tensor
=
tensor
;
if
(
tensor
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
auto
*
cpu_ctx
=
pool
.
Get
(
platform
::
CPUPlace
());
int32_tensor
=
phi
::
funcs
::
TransDataType
(
reinterpret_cast
<
const
phi
::
CPUContext
&>
(
*
cpu_ctx
),
tensor
,
DataType
::
INT32
);
}
paddle
::
memory
::
Copy
(
platform
::
CPUPlace
(),
int32_host
.
data
(),
platform
::
CPUPlace
(),
tensor
.
data
<
int
>
(),
tensor
.
numel
()
*
sizeof
(
int
));
}
else
if
(
tensor
.
place
()
==
platform
::
CUDAPlace
(
))
{
int32_
tensor
.
data
<
int
>
(),
int32_
tensor
.
numel
()
*
sizeof
(
int
));
}
else
if
(
platform
::
is_gpu_place
(
tensor
.
place
()
))
{
#if defined(PADDLE_WITH_CUDA)
auto
*
dev_ctx
=
pool
.
Get
(
tensor
.
place
());
auto
&
int32_tensor
=
tensor
;
if
(
tensor
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
int32_tensor
=
phi
::
funcs
::
TransDataType
(
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
*
dev_ctx
),
tensor
,
DataType
::
INT32
);
}
paddle
::
memory
::
Copy
(
platform
::
CPUPlace
(),
int32_host
.
data
(),
platform
::
CUDAP
lace
(),
tensor
.
data
<
int
>
(),
tensor
.
numel
()
*
sizeof
(
int
),
int32_tensor
.
p
lace
(),
int32_
tensor
.
data
<
int
>
(),
int32_
tensor
.
numel
()
*
sizeof
(
int
),
nullptr
);
#endif
}
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
296b64ac
...
...
@@ -544,6 +544,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
"index=%d >= total inputs and outputs=%d"
,
bind_index
,
num_bindings
));
auto
type
=
framework
::
TransToProtoVarType
(
t
.
dtype
());
if
(
!
engine
->
with_dynamic_shape
())
{
// check if the input shapes are consistent with model.
if
(
HasAttr
(
x
+
"_shape"
))
{
...
...
@@ -586,12 +587,27 @@ class TensorRTEngineOp : public framework::OperatorBase {
if
(
engine
->
engine
()
->
isShapeBinding
(
bind_index
)
&&
engine
->
engine
()
->
bindingIsInput
(
bind_index
))
{
std
::
vector
<
int
>
shape_v
(
t
.
numel
());
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
paddle
::
memory
::
Copy
(
platform
::
CPUPlace
(),
shape_v
.
data
(),
platform
::
CUDAP
lace
(),
t
.
p
lace
(),
t
.
data
<
int32_t
>
(),
t
.
numel
()
*
sizeof
(
int
),
nullptr
);
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
auto
int32_tensor
=
scope
.
FindVar
(
x
+
"_cast_to_INT32"
)
->
GetMutable
<
phi
::
DenseTensor
>
();
*
int32_tensor
=
phi
::
Cast
<
int64_t
>
(
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
dev_ctx
),
t
,
phi
::
DataType
::
INT32
);
paddle
::
memory
::
Copy
(
platform
::
CPUPlace
(),
shape_v
.
data
(),
int32_tensor
->
place
(),
int32_tensor
->
data
<
int32_t
>
(),
int32_tensor
->
numel
()
*
sizeof
(
int
),
nullptr
);
}
trt_context
->
setInputShapeBinding
(
bind_index
,
shape_v
.
data
());
}
#endif
...
...
@@ -608,7 +624,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
"The TRT Engine OP's input type should equal "
"to the input data type"
));
auto
type
=
framework
::
TransToProtoVarType
(
t
.
dtype
());
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
...
...
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
296b64ac
...
...
@@ -145,6 +145,13 @@
variance
:
Variance
scale
:
Scale
bias
:
Bias
outputs
:
out
:
Y
mean_out
:
MeanOut
variance_out
:
VarianceOut
saved_mean
:
SavedMean
saved_variance
:
SavedVariance
reserve_space
:
ReserveSpace
extra
:
attrs
:
[
bool use_mkldnn = false
,
bool fuse_with_relu = false
]
...
...
@@ -407,6 +414,17 @@
-
op
:
dropout
backward
:
dropout_grad
inputs
:
x
:
X
outputs
:
out
:
Out
mask
:
Mask
attrs
:
p
:
dropout_prob
is_test
:
is_test
mode
:
dropout_implementation
seed
:
seed
fix_seed
:
fix_seed
extra
:
attrs
:
[
bool fix_seed = false
,
int seed = 0
]
...
...
@@ -783,6 +801,14 @@
-
op
:
layer_norm
backward
:
layer_norm_grad
inputs
:
x
:
X
scale
:
Scale
bias
:
Bias
outputs
:
out
:
Y
mean
:
Mean
variance
:
Variance
extra
:
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
,
bool is_test = false
]
...
...
@@ -933,6 +959,17 @@
outputs
:
out
:
Out
-
op
:
mean (reduce_mean)
backward
:
reduce_mean_grad
inputs
:
x
:
X
outputs
:
out
:
Out
attrs
:
{
axis
:
dim
,
keepdim
:
keep_dim
}
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
meshgrid
backward
:
meshgrid_grad
inputs
:
...
...
@@ -1138,11 +1175,6 @@
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
reduce_mean
backward
:
reduce_mean_grad
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
reduce_min
backward
:
reduce_min_grad
extra
:
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
296b64ac
...
...
@@ -248,11 +248,8 @@ XPUOpMap& get_kl2_ops() {
phi
::
DataType
::
INT16
,
phi
::
DataType
::
UINT8
,
phi
::
DataType
::
BOOL
,
phi
::
DataType
::
FLOAT64
,
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
COMPLEX64
,
phi
::
DataType
::
COMPLEX128
})},
phi
::
DataType
::
FLOAT16
})},
{
"flatten2_grad"
,
XPUKernelSet
({
phi
::
DataType
::
INT64
,
phi
::
DataType
::
INT32
,
...
...
paddle/phi/kernels/selected_rows/full_kernel.cc
浏览文件 @
296b64ac
...
...
@@ -70,3 +70,17 @@ PD_REGISTER_KERNEL(full_sr,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
#endif
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
full_sr
,
XPU
,
ALL_LAYOUT
,
phi
::
sr
::
FullKernel
,
float
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
float16
)
{}
#endif
paddle/phi/kernels/xpu/full_kernel.cc
浏览文件 @
296b64ac
...
...
@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
...
...
@@ -59,8 +60,19 @@ void FullKernel(const Context& dev_ctx,
const
Scalar
&
val
,
DataType
dtype
,
DenseTensor
*
out
)
{
using
XPUInTDType
=
typename
XPUTypeTrait
<
T
>::
Type
;
out
->
Resize
(
phi
::
make_ddim
(
shape
.
GetData
()));
FullValueXPU
<
T
>
(
dev_ctx
,
out
,
val
.
to
<
T
>
());
int
numel
=
out
->
numel
();
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
value
=
val
.
to
<
double
>
();
auto
out_data
=
reinterpret_cast
<
XPUInTDType
*>
(
out
->
data
<
T
>
());
if
(
numel
>
0
)
{
int
r
=
xpu
::
constant
(
dev_ctx
.
x_context
(),
out_data
,
out
->
numel
(),
static_cast
<
XPUInTDType
>
(
value
));
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"constant"
);
}
}
template
<
typename
T
,
typename
Context
>
...
...
@@ -103,16 +115,11 @@ void FullLikeKernel(const Context& dev_ctx,
phi
::
errors
::
InvalidArgument
(
"The filled value is Inf."
));
auto
out_data
=
reinterpret_cast
<
XPUInTDType
*>
(
out
->
data
<
T
>
());
int
r
et
=
xpu
::
constant
(
dev_ctx
.
x_context
(),
int
r
=
xpu
::
constant
(
dev_ctx
.
x_context
(),
out_data
,
out
->
numel
(),
static_cast
<
XPUInTDType
>
(
value
));
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
phi
::
errors
::
External
(
"XPU CONSTANT API return wrong value[%d %s]."
,
ret
,
XPUAPIErrorMsg
[
ret
]));
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"constant"
);
}
}
// namespace phi
...
...
@@ -122,24 +129,23 @@ PD_REGISTER_KERNEL(full,
ALL_LAYOUT
,
phi
::
FullKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
phi
::
dtype
::
float16
)
{}
// Registers phi::FullLikeKernel as the `full_like` kernel for the XPU
// backend, for every layout and for the element types listed below.
PD_REGISTER_KERNEL(full_like,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullLikeKernel,
                   float,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {
  // Input 0 is marked as acceptable from any backend.
  // NOTE(review): presumably because only the metadata of that input is
  // needed, so it does not have to live on the XPU -- confirm.
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
浏览文件 @
296b64ac
...
...
@@ -135,6 +135,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
#set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
set_tests_properties
(
test_trt_dynamic_shape PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_trt_inspector PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_trt_inference_predictor PROPERTIES TIMEOUT 60
)
if
(
WITH_NV_JETSON
)
set_tests_properties
(
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_inference_predictor.py
0 → 100644
浏览文件 @
296b64ac
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
sys
import
tempfile
import
unittest
import
numpy
as
np
import
yaml
import
paddle
import
paddle.nn
as
nn
try
:
import
paddle.inference
as
paddle_infer
except
Exception
as
e
:
sys
.
stderr
.
write
(
"Cannot import paddle, maybe paddle is not installed.
\n
"
)
paddle
.
set_device
(
'cpu'
)
paddle
.
disable_signal_handler
()
def str2bool(v):
    """Interpret a command-line string as a boolean flag.

    Args:
        v: the raw option text.

    Returns:
        ``True`` only when ``v`` equals ``'true'`` ignoring case; ``False``
        for every other string.
    """
    return v.lower() == 'true'
def getdtype(dtype="float32"):
    """Translate a dtype name into the matching NumPy scalar type.

    Args:
        dtype: one of ``"float32"``, ``"float"`` (alias of float32),
            ``"float16"``, ``"float64"``, ``"int32"`` or ``"int64"``.

    Returns:
        The corresponding ``np.*`` scalar type.

    Raises:
        ValueError: if ``dtype`` is not one of the supported names.
            Previously such a name fell through and returned ``None``,
            which NumPy interprets as its default dtype (float64), so a
            typo in the config silently produced inputs of the wrong type.
    """
    supported = {
        "float32": np.float32,
        "float": np.float32,
        "float16": np.float16,
        "float64": np.float64,
        "int32": np.int32,
        "int64": np.int64,
    }
    try:
        return supported[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable values passed by mistake.
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected one of {sorted(supported)}"
        ) from None
class BackendPaddle:
    """Wrapper that configures and drives a Paddle Inference predictor.

    All behaviour is controlled by an options object (``self.args``) that is
    handed to :meth:`load`; it carries the parsed command-line flags plus the
    ``yaml_config`` and ``test_num`` attributes read below.
    """

    def __init__(self):
        super().__init__()
        # Timing buckets; within this class they are only created and cleared.
        self.h2d_time = []
        self.compute_time = []
        self.d2h_time = []

    def version(self):
        """Return ``paddle.version.full_version``."""
        return paddle.version.full_version

    def name(self):
        """Return the backend identifier string."""
        return "paddle"

    def load(self, config_arg, inputs=None, outpus=None):
        """Build ``self.predictor`` from the options in ``config_arg``.

        Args:
            config_arg: options object; stored as ``self.args``.
            inputs: unused.
            outpus: unused (name kept as-is for interface compatibility).

        Returns:
            ``self``.

        Raises:
            ValueError: if ``config_arg.model_dir`` does not exist.
            Exception: if ``yaml_config["input_shape"]`` is empty.
        """
        self.args = config_arg
        args = self.args

        if not os.path.exists(args.model_dir):
            raise ValueError(
                f"The model dir {args.model_dir} does not exists!"
            )
        # Join the directory and file name as separate components; the
        # previous single pre-concatenated argument made os.path.join a no-op.
        model_file = os.path.join(args.model_dir, args.paddle_model_file)
        model_params = os.path.join(args.model_dir, args.paddle_params_file)
        config = paddle_infer.Config(model_file, model_params)

        # enable memory optim
        if not args.enable_tune:
            config.enable_memory_optim()

        config.set_cpu_math_library_num_threads(args.cpu_threads)
        config.switch_ir_optim(True)
        # debug
        if args.enable_debug:
            config.switch_ir_debug()

        precision_mode = paddle_infer.PrecisionType.Float32
        if args.precision == 'fp16':
            precision_mode = paddle_infer.PrecisionType.Half
        elif args.precision == 'int8':
            precision_mode = paddle_infer.PrecisionType.Int8

        if not args.enable_gpu:
            config.disable_gpu()
            if args.enable_mkldnn:
                config.enable_mkldnn()
                if args.precision == 'int8':
                    config.enable_mkldnn_int8(
                        {"conv2d", "depthwise_conv2d", "transpose2", "pool2d"}
                    )

        if args.enable_profile:
            config.enable_profile()

        shape_range_file = os.path.join(args.model_dir, args.shape_range_file)
        if args.enable_tune:
            config.collect_shape_range_info(shape_range_file)

        if args.enable_gpu:
            config.enable_use_gpu(256, args.gpu_id)
            if args.enable_trt:
                # A fixed (non -1) leading dimension of input "0" overrides
                # the command-line batch size.
                first_dim = args.yaml_config["input_shape"]["0"]["shape"][
                    args.test_num
                ][0]
                max_batch_size = (
                    first_dim if first_dim != -1 else args.batch_size
                )
                config.enable_tensorrt_engine(
                    workspace_size=1 << 33,
                    precision_mode=precision_mode,
                    max_batch_size=max_batch_size,
                    min_subgraph_size=args.subgraph_size,
                    use_static=False,
                    # The original conditional evaluated to False on both
                    # branches, so calibration mode is always off.
                    use_calib_mode=False,
                )
                if args.enable_dynamic_shape and os.path.exists(
                    shape_range_file
                ):
                    config.enable_tuned_tensorrt_dynamic_shape(
                        shape_range_file, True
                    )

        config.disable_glog_info()
        config.exp_disable_tensorrt_ops(["range"])

        self.predictor = paddle_infer.create_predictor(config)

        input_shape = args.yaml_config["input_shape"]
        if len(input_shape) <= 0:
            raise Exception("input shape is empty.")

        if "input_data" in args.yaml_config:
            input_file = args.yaml_config["input_data"]["data"][args.test_num]
            # NOTE(review): allow_pickle=True executes pickled objects on
            # load; only point input_data at trusted files.
            self.numpy_input = np.load(input_file, allow_pickle=True)

        return self

    def set_input(self):
        """Copy one host array into every input handle of the predictor.

        When ``yaml_config`` has an ``"input_data"`` entry, the i-th array of
        ``self.numpy_input`` is given a new leading axis and repeated
        ``batch_size`` times. Otherwise the input is ``args.test_data[i]``
        (if ``args`` has that attribute) cast to the configured dtype, or an
        all-ones array of the configured shape, where a leading ``-1`` is
        replaced by ``batch_size``.
        """
        args = self.args
        # set input tensor
        for i, name in enumerate(self.predictor.get_input_names()):
            input_tensor = self.predictor.get_input_handle(name)

            if "input_data" in args.yaml_config:
                real_input = np.expand_dims(self.numpy_input[i], 0).repeat(
                    args.batch_size, axis=0
                )
                input_tensor.copy_from_cpu(real_input)
                continue

            spec = args.yaml_config["input_shape"][str(i)]
            input_shape = spec["shape"][args.test_num]
            if input_shape[0] == -1:
                input_shape = [args.batch_size] + input_shape[1:]
            dtype = spec["dtype"][args.test_num]

            if hasattr(args, "test_data"):
                fake_input = args.test_data[i].astype(getdtype(dtype))
            else:
                fake_input = np.ones(input_shape, dtype=getdtype(dtype))
            input_tensor.copy_from_cpu(fake_input)

    def set_output(self):
        """Fetch every predictor output to the host.

        Returns:
            A list with one entry per output when ``args.return_result`` or
            ``args.save_result`` is truthy, otherwise ``None``. Outputs are
            copied to the host in either case.
        """
        keep_results = self.args.return_result or self.args.save_result
        results = []
        # get out data from output tensor
        for name in self.predictor.get_output_names():
            output_data = self.predictor.get_output_handle(name).copy_to_cpu()
            if keep_results:
                results.append(output_data)
        return results if keep_results else None

    def reset(self):
        """Empty the three timing lists in place."""
        self.h2d_time.clear()
        self.d2h_time.clear()
        self.compute_time.clear()

    def warmup(self):
        """No-op."""
        pass

    def predict(self, feed=None):
        """Feed inputs, run the predictor once and fetch the outputs.

        Args:
            feed: unused.

        Returns:
            Whatever :meth:`set_output` returns: the list of outputs when
            results were requested, otherwise ``None``.
        """
        self.set_input()
        self.predictor.run()
        return self.set_output()

    def predict_nocopy(self, feed=None):
        """Run the predictor once without touching inputs or outputs.

        Args:
            feed: unused.
        """
        self.predictor.run()
def parse_args():
    """Parse the predictor test's command-line options.

    Options are registered from a table in a fixed order. Arguments that the
    parser does not recognise are tolerated and discarded.

    Returns:
        The ``argparse.Namespace`` holding the recognised options.
    """
    option_table = (
        ('--batch_size', {'type': int, 'default': 1}),
        ('--cpu_threads', {'type': int, 'default': 1}),
        ('--inter_op_threads', {'type': int, 'default': 1}),
        ('--precision', {'type': str, 'choices': ["fp32", "fp16", "int8"]}),
        (
            '--backend_type',
            {
                'type': str,
                'choices': ["paddle", "onnxruntime", "openvino", "tensorrt"],
                'default': "paddle",
            },
        ),
        ('--gpu_id', {'type': int, 'default': 0}),
        ('--subgraph_size', {'type': int, 'default': 1}),
        ('--model_dir', {'type': str}),
        ('--paddle_model_file', {'type': str, 'default': "model.pdmodel"}),
        ('--paddle_params_file', {'type': str, 'default': "model.pdiparams"}),
        ('--enable_mkldnn', {'type': str2bool, 'default': False}),
        ('--enable_gpu', {'type': str2bool, 'default': True}),
        ('--enable_trt', {'type': str2bool, 'default': True}),
        ('--enable_dynamic_shape', {'type': str2bool, 'default': True}),
        ('--enable_tune', {'type': str2bool, 'default': False}),
        ('--enable_profile', {'type': str2bool, 'default': False}),
        ('--enable_benchmark', {'type': str2bool, 'default': True}),
        ('--save_result', {'type': str2bool, 'default': False}),
        ('--return_result', {'type': str2bool, 'default': False}),
        ('--enable_debug', {'type': str2bool, 'default': False}),
        (
            '--config_file',
            {'type': str, 'required': False, 'default': "config/model.yaml"},
        ),
        ('--shape_range_file', {'type': str, 'default': "shape_range.pbtxt"}),
    )

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    # parse_known_args: flags meant for another tool must not abort the run.
    recognised, _leftover = parser.parse_known_args()
    return recognised
def run_infer(model_path):
    """Load and run the saved model three times with different settings.

    The options come from ``parse_args()`` and are extended with an inline
    YAML description of a single float32 input of shape [-1, 3, 32, 32].
    The three runs are: shape collection without GPU, shape collection with
    GPU, and finally a run with shape collection switched off and GPU on.

    Args:
        model_path: directory that holds the saved model files.
    """
    conf = parse_args()
    conf.yaml_config = yaml.safe_load(
        '''
    input_shape:
      '0':
        dtype: [float32]
        shape:
        - [-1, 3, 32, 32]
    '''
    )
    conf.test_num = 0
    conf.model_dir = model_path

    # (enable_tune, enable_gpu) for each consecutive pass.
    passes = (
        (True, False),  # collect shape use CPU
        (True, True),  # collect shape use GPU
        (False, True),  # run inference predictor
    )
    for tune_enabled, gpu_enabled in passes:
        conf.enable_tune = tune_enabled
        conf.enable_gpu = gpu_enabled
        backend = BackendPaddle()
        backend.load(conf)
        backend.predict()
class ConvBNLayer(paddle.nn.Layer):
    """A bias-free ``Conv2D`` followed by ``BatchNorm``.

    Args:
        num_channels: input channel count of the convolution.
        num_filters: output channel count of the convolution; also the
            channel count handed to the batch norm.
        filter_size: convolution kernel size; the padding is derived from it
            as ``(filter_size - 1) // 2``.
        stride: convolution stride.
        groups: convolution group count.
        act: activation name forwarded to ``BatchNorm`` (``None`` for none).
    """

    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        groups=1,
        act=None,
    ):
        super().__init__()

        derived_padding = (filter_size - 1) // 2
        self._conv = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=derived_padding,
            groups=groups,
            bias_attr=False,
        )
        self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        """Apply the convolution, then the batch norm, and return the result."""
        return self._batch_norm(self._conv(inputs))
class Test(nn.Layer):
    """Small network used to produce the inference model for this test.

    ``forward`` runs a ``ConvBNLayer`` and an adaptive average pool, then
    reshapes the result with a target shape whose two dimensions are given
    as int64 tensors (``-1`` and ``8``).
    """

    def __init__(self):
        super().__init__()
        self.conv = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=3,
            stride=2,
            act='relu',
        )
        # Created here but not called from forward().
        self.pool2d_max = paddle.nn.MaxPool2D(
            kernel_size=3, stride=1, padding=1
        )
        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D(output_size=1)

    def forward(self, x):
        """Return the reshaped, pooled convolution features of ``x``."""
        pooled = self.pool2d_avg(self.conv(x))
        # Both target dimensions are passed as tensors rather than ints.
        target_shape = [
            paddle.to_tensor([-1], dtype=paddle.int64),
            paddle.to_tensor([8], dtype=paddle.int64),
        ]
        return paddle.reshape(pooled, shape=target_shape)
class TestInferencePredictor(unittest.TestCase):
    """Saves a small model and runs it through ``run_infer``."""

    def setUp(self):
        # enable dygraph mode
        paddle.disable_static()
        self.temp_dir = tempfile.TemporaryDirectory()
        # Keep every artifact inside the temporary directory. The path used
        # to be overwritten with "./inference/model" right after being
        # computed, which wrote the model into the current working directory
        # where tearDown() never removed it.
        model_dir = os.path.join(self.temp_dir.name, 'inference')
        os.makedirs(model_dir, exist_ok=True)
        self.path = os.path.join(model_dir, 'model')

    def tearDown(self):
        # Removes the temporary directory together with the saved model.
        self.temp_dir.cleanup()

    def SaveInferenceModel(self):
        """Convert the ``Test`` network to a static model saved at ``self.path``."""
        paddle.disable_static()
        net = Test()
        net.eval()

        # Run the network once in dygraph mode before converting it.
        net(paddle.rand(shape=[1, 3, 32, 32], dtype='float32'))

        input_spec = [
            paddle.static.InputSpec(
                shape=[-1, 3, 32, 32], dtype=paddle.float32, name='input'
            )
        ]
        static_model = paddle.jit.to_static(net, input_spec=input_spec)
        paddle.jit.save(static_model, self.path)

    def testInferencePredictor(self):
        self.SaveInferenceModel()
        # run_infer expects the directory that contains the saved files.
        run_infer(os.path.dirname(self.path))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/incubate/autograd/composite_rules.py
浏览文件 @
296b64ac
...
...
@@ -98,9 +98,9 @@ def composite_batchnorm(
run_mean_
=
assign
(
run_mean
)
run_var_
=
assign
(
run_var
)
if
trainable_statistics
or
not
is_test
:
return
run_mean_
,
None
,
batch_mean_
,
batch_var_
,
run_var_
,
y
return
y
,
run_mean_
,
run_var_
,
batch_mean_
,
batch_var_
,
None
else
:
return
run_mean_
,
batch_mean_
,
batch_var_
,
run_var_
,
y
return
y
,
run_mean_
,
run_var_
,
batch_mean_
,
batch_var_
@
REGISTER_COMPOSITE
(
'layer_norm'
)
...
...
python/paddle/incubate/autograd/generate_op_map.py
浏览文件 @
296b64ac
...
...
@@ -84,7 +84,7 @@ def generate_code(
else
:
op_name
=
key
map_dct
[
op_name
]
=
{
"phi_name"
:
op_name
}
for
element
in
[
"inputs"
,
"attrs"
]:
for
element
in
[
"inputs"
,
"
outputs"
,
"
attrs"
]:
if
element
in
item
.
keys
():
map_dct
[
op_name
][
element
]
=
item
[
element
]
for
element
in
[
"scalar"
,
"int_array"
]:
...
...
python/paddle/incubate/autograd/primx.py
浏览文件 @
296b64ac
...
...
@@ -36,6 +36,7 @@ from .utils import (
flatten_and_remove_none
,
get_input_var_list
,
get_output_var_list
,
get_output_vars_from_comosite
,
prepare_python_api_arguments
,
)
...
...
@@ -596,19 +597,37 @@ def _lower_composite(block, blacklist=[]):
# if output var of composite rule is None, this means this var is not needed
none_vars_to_remove
=
set
()
change
=
None
# Step2: Process all ops in the target block
for
op_idx
in
range
(
len
(
block
.
ops
)):
op
=
block
.
ops
[
op_idx
]
ops_to_remove
.
append
(
op_idx
)
if
lookup_fn
(
op
.
type
)
is
not
None
and
op
.
type
not
in
blacklist
:
change
=
True
op_name
=
op
.
type
input_args
=
prepare_python_api_arguments
(
op
)
bind
(
input_args
,
to_bind
,
value_table
)
orig_outs
=
expand_nested_list
(
get_output_vars_from_comosite
(
op
)
)
new_outs
=
expand_nested_list
(
as_tensors
(
lower_fn
(
op
,
*
input_args
))
)
assert
len
(
orig_outs
)
==
len
(
new_outs
),
(
f
'when replace origin op
{
op_name
}
with composite rule, num of origin outs should be equal to new outs, '
f
'but len(orig_outs) =
{
len
(
orig_outs
)
}
and len(new_outs) =
{
len
(
new_outs
)
}
'
)
for
orig_out
,
new_out
in
zip
(
expand_nested_list
(
get_output_var_list
(
op
))
,
expand_nested_list
(
as_tensors
(
lower_fn
(
op
,
*
input_args
)))
,
orig_outs
,
new_outs
,
):
if
new_out
is
not
None
:
if
orig_out
.
shape
and
new_out
.
shape
:
assert
orig_out
.
shape
==
new_out
.
shape
,
(
f
'when replace origin op
{
op_name
}
with composite rule, origin out shape should be equal to new out shape, '
f
'but orig_out.shape=
{
orig_out
.
shape
}
and new_out.shape=
{
new_out
.
shape
}
'
)
assert
not
(
orig_out
is
None
)
^
(
new_out
is
None
),
"orig_out and new_out should match."
...
...
@@ -675,6 +694,10 @@ def _lower_composite(block, blacklist=[]):
block
.
desc
.
_remove_var
(
var_name
.
encode
())
del
block
.
vars
[
var_name
]
block
.
_sync_with_cpp
()
# composite ops may contain other composite ops, thus, call _lower_composite again.
if
change
:
_lower_composite
(
block
,
blacklist
)
return
elif
isinstance
(
block
,
typing
.
Sequence
):
...
...
python/paddle/incubate/autograd/utils.py
浏览文件 @
296b64ac
...
...
@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
arg_type
,
arg_name
=
_solve_arg
(
item
)
op_content
=
op_map
[
op
.
type
]
if
arg_type
in
(
"Tensor"
,
"Tensor[]"
):
# assume Tensor type must belong to inputs
if
(
"inputs"
in
op_content
.
keys
()
and
arg_name
in
op_content
[
"inputs"
].
keys
()
...
...
@@ -182,7 +183,10 @@ def _get_args_values(op, phi_name):
"attrs"
in
op_content
.
keys
()
and
arg_name
in
op_content
[
"attrs"
].
keys
()
):
attrs
.
append
(
op
.
attr
(
op_content
[
"attrs"
][
arg_name
]))
arg_name
=
op_content
[
"attrs"
][
arg_name
]
if
arg_name
not
in
op
.
attr_names
:
attrs
.
append
(
None
)
else
:
attrs
.
append
(
op
.
attr
(
arg_name
))
return
inputs
,
attrs
...
...
@@ -202,7 +206,12 @@ def prepare_python_api_arguments(op):
else
:
phi_name
=
op
.
type
inputs
,
attrs
=
_get_args_values
(
op
,
phi_name
)
res
=
[
get_var_block
(
op
.
block
,
op
.
input
(
n
))
for
n
in
inputs
]
res
=
[]
for
item
in
inputs
:
if
item
in
op
.
input_names
:
res
.
append
(
get_var_block
(
op
.
block
,
op
.
input
(
item
)))
else
:
res
.
append
(
None
)
if
attrs
:
res
.
extend
(
attrs
)
return
res
...
...
@@ -218,6 +227,37 @@ def get_output_var_list(op):
]
def get_output_vars_from_comosite(op):
    """Collect the origin op's output variables in composite-rule order.

    Origin op outputs must be mapped into outputs of the composite rule. The
    mapping is read from ``op_map[op.type]["outputs"]``.

    Returns:
        ``[]`` when ``op.output_names`` is ``None``; otherwise the list of
        values obtained via ``get_var_block`` for each mapped output that the
        op actually has (or for the single output when no mapping exists).

    Raises:
        ValueError: when there is no output mapping and the op does not have
            exactly one output name.
    """
    declared_names = op.output_names
    if declared_names is None:
        return []

    output_map = op_map[op.type].get("outputs")
    if output_map:
        # In some cases an output of the origin op is optional, so a mapped
        # name may be absent from the op's output names; such entries are
        # skipped.
        return [
            get_var_block(op.block, op.output(mapped_name))
            for mapped_name in output_map.values()
            if mapped_name in declared_names
        ]

    if len(declared_names) == 1:
        # When origin output num is 1, map info is not needed.
        return [get_var_block(op.block, op.output(declared_names[0]))]

    raise ValueError(
        "When replace op with composite rule, there must exist output map info from origin op to composite rule."
    )
def
flatten
(
inp
):
if
inp
is
None
or
isinstance
(
inp
,
paddle
.
fluid
.
framework
.
Variable
):
return
[
inp
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录