Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6abd05f2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6abd05f2
编写于
9月 07, 2020
作者:
J
jingqinghe
浏览文件
操作
浏览文件
下载
差异文件
Merge
https://github.com/PaddlePaddle/Paddle
into doublegrad
上级
836f341d
b150f2b3
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
903 addition
and
232 deletion
+903
-232
paddle/fluid/inference/tests/api/CMakeLists.txt
paddle/fluid/inference/tests/api/CMakeLists.txt
+9
-8
paddle/fluid/inference/tests/api/analyzer_capi_tester.cc
paddle/fluid/inference/tests/api/analyzer_capi_tester.cc
+2
-1
paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
...ference/tests/api/analyzer_image_classification_tester.cc
+4
-0
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
+1
-1
paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py
...uid/inference/tests/api/full_ILSVRC2012_val_preprocess.py
+2
-2
paddle/fluid/inference/tests/api/full_pascalvoc_test_preprocess.py
...uid/inference/tests/api/full_pascalvoc_test_preprocess.py
+2
-2
paddle/fluid/inference/tests/test.cmake
paddle/fluid/inference/tests/test.cmake
+3
-2
paddle/fluid/operators/cudnn_lstm_op.cc
paddle/fluid/operators/cudnn_lstm_op.cc
+28
-20
paddle/fluid/operators/cudnn_lstm_op.cu.cc
paddle/fluid/operators/cudnn_lstm_op.cu.cc
+141
-64
paddle/fluid/operators/reduce_ops/logsumexp_op.cu
paddle/fluid/operators/reduce_ops/logsumexp_op.cu
+0
-6
paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu
paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu
+22
-0
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+266
-0
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+8
-0
paddle/scripts/paddle_build.bat
paddle/scripts/paddle_build.bat
+1
-1
python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
...ddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
+7
-9
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+15
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py
...uid/tests/unittests/dygraph_to_static/test_declarative.py
+18
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
.../unittests/dygraph_to_static/test_save_inference_model.py
+1
-1
python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
+373
-111
未找到文件。
paddle/fluid/inference/tests/api/CMakeLists.txt
浏览文件 @
6abd05f2
...
@@ -125,7 +125,7 @@ endfunction()
...
@@ -125,7 +125,7 @@ endfunction()
if
(
NOT APPLE AND WITH_MKLML
)
if
(
NOT APPLE AND WITH_MKLML
)
# RNN1
# RNN1
set
(
RNN1_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/rnn1"
)
set
(
RNN1_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/rnn1"
)
download_model_and_data
(
${
RNN1_INSTALL_DIR
}
"rnn1
%2Fmodel.tar.gz"
"rnn1%2F
data.txt.tar.gz"
)
download_model_and_data
(
${
RNN1_INSTALL_DIR
}
"rnn1
/model.tar.gz"
"rnn1/
data.txt.tar.gz"
)
inference_analysis_api_test
(
test_analyzer_rnn1
${
RNN1_INSTALL_DIR
}
analyzer_rnn1_tester.cc
)
inference_analysis_api_test
(
test_analyzer_rnn1
${
RNN1_INSTALL_DIR
}
analyzer_rnn1_tester.cc
)
# seq_pool1
# seq_pool1
...
@@ -210,7 +210,7 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana
...
@@ -210,7 +210,7 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana
# transformer, the dataset only works on batch_size=8 now
# transformer, the dataset only works on batch_size=8 now
set
(
TRANSFORMER_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/transformer"
)
set
(
TRANSFORMER_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/transformer"
)
download_model_and_data
(
${
TRANSFORMER_INSTALL_DIR
}
"temp
%2Ftransformer_model.tar.gz"
"temp%2F
transformer_data.txt.tar.gz"
)
download_model_and_data
(
${
TRANSFORMER_INSTALL_DIR
}
"temp
/transformer_model.tar.gz"
"temp/
transformer_data.txt.tar.gz"
)
inference_analysis_test
(
test_analyzer_transformer SRCS analyzer_transformer_tester.cc
inference_analysis_test
(
test_analyzer_transformer SRCS analyzer_transformer_tester.cc
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
ARGS --infer_model=
${
TRANSFORMER_INSTALL_DIR
}
/model --infer_data=
${
TRANSFORMER_INSTALL_DIR
}
/data.txt --batch_size=8
ARGS --infer_model=
${
TRANSFORMER_INSTALL_DIR
}
/model --infer_data=
${
TRANSFORMER_INSTALL_DIR
}
/data.txt --batch_size=8
...
@@ -219,7 +219,7 @@ inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_test
...
@@ -219,7 +219,7 @@ inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_test
# ocr
# ocr
set
(
OCR_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/ocr"
)
set
(
OCR_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/ocr"
)
if
(
NOT EXISTS
${
OCR_INSTALL_DIR
}
/ocr.tar.gz
)
if
(
NOT EXISTS
${
OCR_INSTALL_DIR
}
/ocr.tar.gz
)
inference_download_and_uncompress
(
${
OCR_INSTALL_DIR
}
"http://paddlemodels.bj.bcebos.com/"
"inference-vis-demos
%2F
ocr.tar.gz"
)
inference_download_and_uncompress
(
${
OCR_INSTALL_DIR
}
"http://paddlemodels.bj.bcebos.com/"
"inference-vis-demos
/
ocr.tar.gz"
)
endif
()
endif
()
inference_analysis_api_test
(
test_analyzer_ocr
${
OCR_INSTALL_DIR
}
analyzer_vis_tester.cc
)
inference_analysis_api_test
(
test_analyzer_ocr
${
OCR_INSTALL_DIR
}
analyzer_vis_tester.cc
)
...
@@ -235,7 +235,7 @@ set_property(TEST test_analyzer_detect PROPERTY ENVIRONMENT GLOG_vmodule=analysi
...
@@ -235,7 +235,7 @@ set_property(TEST test_analyzer_detect PROPERTY ENVIRONMENT GLOG_vmodule=analysi
# mobilenet with transpose op
# mobilenet with transpose op
set
(
MOBILENET_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/mobilenet"
)
set
(
MOBILENET_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/mobilenet"
)
if
(
NOT EXISTS
${
MOBILENET_INSTALL_DIR
}
/mobilenet.tar.gz
)
if
(
NOT EXISTS
${
MOBILENET_INSTALL_DIR
}
/mobilenet.tar.gz
)
inference_download_and_uncompress
(
${
MOBILENET_INSTALL_DIR
}
"http://paddlemodels.bj.bcebos.com/"
"inference-vis-demos
%2F
mobilenet.tar.gz"
)
inference_download_and_uncompress
(
${
MOBILENET_INSTALL_DIR
}
"http://paddlemodels.bj.bcebos.com/"
"inference-vis-demos
/
mobilenet.tar.gz"
)
endif
()
endif
()
inference_analysis_api_test
(
test_analyzer_mobilenet_transpose
${
MOBILENET_INSTALL_DIR
}
analyzer_vis_tester.cc
)
inference_analysis_api_test
(
test_analyzer_mobilenet_transpose
${
MOBILENET_INSTALL_DIR
}
analyzer_vis_tester.cc
)
...
@@ -363,9 +363,9 @@ if(WITH_MKLDNN)
...
@@ -363,9 +363,9 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build
(
${
QUANT_IMG_CLASS_TEST_APP
}
${
QUANT_IMG_CLASS_TEST_APP_SRC
}
)
inference_analysis_api_test_build
(
${
QUANT_IMG_CLASS_TEST_APP
}
${
QUANT_IMG_CLASS_TEST_APP_SRC
}
)
# MobileNetV1 FP32 vs. Quant INT8
# MobileNetV1 FP32 vs. Quant INT8
# The FP32 model should already be downloaded for slim Quant unit tests
set
(
QUANT2_MobileNetV1_MODEL_DIR
"
${
QUANT_DATA_DIR
}
/MobileNetV1_quant2"
)
set
(
QUANT2_MobileNetV1_MODEL_DIR
"
${
QUANT_DATA_DIR
}
/MobileNetV1_quant2"
)
set
(
QUANT2_INT8_MobileNetV1_MODEL_DIR
"
${
QUANT_DATA_DIR
}
/MobileNetV1_quant2_int8"
)
set
(
QUANT2_INT8_MobileNetV1_MODEL_DIR
"
${
QUANT_DATA_DIR
}
/MobileNetV1_quant2_int8"
)
download_quant_data
(
${
QUANT2_MobileNetV1_MODEL_DIR
}
"MobileNet_qat_perf.tar.gz"
)
download_quant_data
(
${
QUANT2_INT8_MobileNetV1_MODEL_DIR
}
"MobileNet_qat_perf_int8.tar.gz"
)
download_quant_data
(
${
QUANT2_INT8_MobileNetV1_MODEL_DIR
}
"MobileNet_qat_perf_int8.tar.gz"
)
inference_analysis_api_quant_test_run
(
test_analyzer_quant_performance_benchmark
${
QUANT_IMG_CLASS_TEST_APP
}
${
QUANT2_MobileNetV1_MODEL_DIR
}
/MobileNet_qat_perf/float
${
QUANT2_INT8_MobileNetV1_MODEL_DIR
}
/MobileNet_qat_perf_int8
${
IMAGENET_DATA_PATH
}
)
inference_analysis_api_quant_test_run
(
test_analyzer_quant_performance_benchmark
${
QUANT_IMG_CLASS_TEST_APP
}
${
QUANT2_MobileNetV1_MODEL_DIR
}
/MobileNet_qat_perf/float
${
QUANT2_INT8_MobileNetV1_MODEL_DIR
}
/MobileNet_qat_perf_int8
${
IMAGENET_DATA_PATH
}
)
...
@@ -477,9 +477,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
...
@@ -477,9 +477,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_download_and_uncompress
(
${
TEST_TRT_ERNIE_MODEL
}
${
INFERENCE_URL
}
/tensorrt_test
"ernie_model_4_unserialized.tgz"
)
inference_download_and_uncompress
(
${
TEST_TRT_ERNIE_MODEL
}
${
INFERENCE_URL
}
/tensorrt_test
"ernie_model_4_unserialized.tgz"
)
endif
()
endif
()
inference_analysis_test
(
test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
# disable test_trt_dynamic_shape_ernie_ser_deser temporary
EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
#inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
ARGS --infer_model=
${
TEST_TRT_ERNIE_MODEL
}
/ernie_model_4_unserialized
)
# EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
# ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
endif
()
endif
()
...
...
paddle/fluid/inference/tests/api/analyzer_capi_tester.cc
浏览文件 @
6abd05f2
...
@@ -44,7 +44,7 @@ void zero_copy_run() {
...
@@ -44,7 +44,7 @@ void zero_copy_run() {
const
int
channels
=
3
;
const
int
channels
=
3
;
const
int
height
=
318
;
const
int
height
=
318
;
const
int
width
=
318
;
const
int
width
=
318
;
float
input
[
batch_size
*
channels
*
height
*
width
]
=
{
0
}
;
float
*
input
=
new
float
[
batch_size
*
channels
*
height
*
width
]()
;
int
shape
[
4
]
=
{
batch_size
,
channels
,
height
,
width
};
int
shape
[
4
]
=
{
batch_size
,
channels
,
height
,
width
};
int
shape_size
=
4
;
int
shape_size
=
4
;
...
@@ -65,6 +65,7 @@ void zero_copy_run() {
...
@@ -65,6 +65,7 @@ void zero_copy_run() {
PD_PredictorZeroCopyRun
(
config
,
inputs
,
in_size
,
&
outputs
,
&
out_size
);
PD_PredictorZeroCopyRun
(
config
,
inputs
,
in_size
,
&
outputs
,
&
out_size
);
delete
[]
input
;
delete
[]
inputs
;
delete
[]
inputs
;
delete
[]
outputs
;
delete
[]
outputs
;
}
}
...
...
paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
浏览文件 @
6abd05f2
...
@@ -112,7 +112,11 @@ TEST(Analyzer_resnet50, compare_determine) {
...
@@ -112,7 +112,11 @@ TEST(Analyzer_resnet50, compare_determine) {
TEST
(
Analyzer_resnet50
,
save_optim_model
)
{
TEST
(
Analyzer_resnet50
,
save_optim_model
)
{
AnalysisConfig
cfg
;
AnalysisConfig
cfg
;
std
::
string
optimModelPath
=
FLAGS_infer_model
+
"/saved_optim_model"
;
std
::
string
optimModelPath
=
FLAGS_infer_model
+
"/saved_optim_model"
;
#ifdef _WIN32
_mkdir
(
optimModelPath
.
c_str
());
#else
mkdir
(
optimModelPath
.
c_str
(),
0777
);
mkdir
(
optimModelPath
.
c_str
(),
0777
);
#endif
SetConfig
(
&
cfg
);
SetConfig
(
&
cfg
);
SaveOptimModel
(
&
cfg
,
optimModelPath
);
SaveOptimModel
(
&
cfg
,
optimModelPath
);
}
}
...
...
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
浏览文件 @
6abd05f2
...
@@ -123,7 +123,7 @@ void profile(bool memory_load = false) {
...
@@ -123,7 +123,7 @@ void profile(bool memory_load = false) {
size_t
size
=
GetSize
(
output
[
0
]);
size_t
size
=
GetSize
(
output
[
0
]);
PADDLE_ENFORCE_GT
(
size
,
0
);
PADDLE_ENFORCE_GT
(
size
,
0
);
int64_t
*
result
=
static_cast
<
int64_t
*>
(
output
[
0
].
data
.
data
());
int64_t
*
result
=
static_cast
<
int64_t
*>
(
output
[
0
].
data
.
data
());
for
(
size_t
i
=
0
;
i
<
std
::
min
(
11UL
,
size
);
i
++
)
{
for
(
size_t
i
=
0
;
i
<
std
::
min
<
size_t
>
(
11
,
size
);
i
++
)
{
EXPECT_EQ
(
result
[
i
],
chinese_ner_result_data
[
i
]);
EXPECT_EQ
(
result
[
i
],
chinese_ner_result_data
[
i
]);
}
}
}
}
...
...
paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py
浏览文件 @
6abd05f2
...
@@ -23,7 +23,7 @@ from PIL import Image
...
@@ -23,7 +23,7 @@ from PIL import Image
import
math
import
math
from
paddle.dataset.common
import
download
from
paddle.dataset.common
import
download
import
tarfile
import
tarfile
import
StringIO
from
six.moves
import
StringIO
import
argparse
import
argparse
random
.
seed
(
0
)
random
.
seed
(
0
)
...
@@ -152,7 +152,7 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
...
@@ -152,7 +152,7 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
idx
=
0
idx
=
0
for
imagedata
in
dataset
.
values
():
for
imagedata
in
dataset
.
values
():
img
=
Image
.
open
(
StringIO
.
StringIO
(
imagedata
))
img
=
Image
.
open
(
StringIO
(
imagedata
))
img
=
process_image
(
img
)
img
=
process_image
(
img
)
np_img
=
np
.
array
(
img
)
np_img
=
np
.
array
(
img
)
ofs
.
write
(
np_img
.
astype
(
'float32'
).
tobytes
())
ofs
.
write
(
np_img
.
astype
(
'float32'
).
tobytes
())
...
...
paddle/fluid/inference/tests/api/full_pascalvoc_test_preprocess.py
浏览文件 @
6abd05f2
...
@@ -19,7 +19,7 @@ import os
...
@@ -19,7 +19,7 @@ import os
import
sys
import
sys
from
paddle.dataset.common
import
download
from
paddle.dataset.common
import
download
import
tarfile
import
tarfile
import
StringIO
from
six.moves
import
StringIO
import
hashlib
import
hashlib
import
tarfile
import
tarfile
import
argparse
import
argparse
...
@@ -191,7 +191,7 @@ def convert_pascalvoc_tar2bin(tar_path, data_out_path):
...
@@ -191,7 +191,7 @@ def convert_pascalvoc_tar2bin(tar_path, data_out_path):
gt_labels
[
name_prefix
]
=
tar
.
extractfile
(
tarInfo
).
read
()
gt_labels
[
name_prefix
]
=
tar
.
extractfile
(
tarInfo
).
read
()
for
line_idx
,
name_prefix
in
enumerate
(
lines
):
for
line_idx
,
name_prefix
in
enumerate
(
lines
):
im
=
Image
.
open
(
StringIO
.
StringIO
(
images
[
name_prefix
]))
im
=
Image
.
open
(
StringIO
(
images
[
name_prefix
]))
if
im
.
mode
==
'L'
:
if
im
.
mode
==
'L'
:
im
=
im
.
convert
(
'RGB'
)
im
=
im
.
convert
(
'RGB'
)
im_width
,
im_height
=
im
.
size
im_width
,
im_height
=
im
.
size
...
...
paddle/fluid/inference/tests/test.cmake
浏览文件 @
6abd05f2
...
@@ -25,7 +25,8 @@ endfunction()
...
@@ -25,7 +25,8 @@ endfunction()
function
(
inference_download_and_uncompress INSTALL_DIR URL FILENAME
)
function
(
inference_download_and_uncompress INSTALL_DIR URL FILENAME
)
message
(
STATUS
"Download inference test stuff from
${
URL
}
/
${
FILENAME
}
"
)
message
(
STATUS
"Download inference test stuff from
${
URL
}
/
${
FILENAME
}
"
)
string
(
REGEX REPLACE
"[-%.]"
"_"
FILENAME_EX
${
FILENAME
}
)
string
(
REGEX REPLACE
"[-%./
\\
]"
"_"
FILENAME_EX
${
FILENAME
}
)
string
(
REGEX MATCH
"[^/
\\
]+$"
DOWNLOAD_NAME
${
FILENAME
}
)
set
(
EXTERNAL_PROJECT_NAME
"extern_inference_download_
${
FILENAME_EX
}
"
)
set
(
EXTERNAL_PROJECT_NAME
"extern_inference_download_
${
FILENAME_EX
}
"
)
set
(
UNPACK_DIR
"
${
INSTALL_DIR
}
/src/
${
EXTERNAL_PROJECT_NAME
}
"
)
set
(
UNPACK_DIR
"
${
INSTALL_DIR
}
/src/
${
EXTERNAL_PROJECT_NAME
}
"
)
ExternalProject_Add
(
ExternalProject_Add
(
...
@@ -38,7 +39,7 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
...
@@ -38,7 +39,7 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
DOWNLOAD_NO_PROGRESS 1
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
${
CMAKE_COMMAND
}
-E chdir
${
INSTALL_DIR
}
BUILD_COMMAND
${
CMAKE_COMMAND
}
-E chdir
${
INSTALL_DIR
}
${
CMAKE_COMMAND
}
-E tar xzf
${
FILE
NAME
}
${
CMAKE_COMMAND
}
-E tar xzf
${
DOWNLOAD_
NAME
}
UPDATE_COMMAND
""
UPDATE_COMMAND
""
INSTALL_COMMAND
""
INSTALL_COMMAND
""
)
)
...
...
paddle/fluid/operators/cudnn_lstm_op.cc
浏览文件 @
6abd05f2
...
@@ -37,41 +37,42 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
...
@@ -37,41 +37,42 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LastC"
),
"Output"
,
"LastC"
,
"CudnnLSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LastC"
),
"Output"
,
"LastC"
,
"CudnnLSTM"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
init_dims
=
ctx
->
GetInputDim
(
"InitH"
);
auto
init_h_dims
=
ctx
->
GetInputDim
(
"InitH"
);
auto
init_c_dims
=
ctx
->
GetInputDim
(
"InitC"
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
3
,
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
3
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The rank of Input in CudnnLSTM must be 3. But "
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d."
,
"received Input's rank is %d."
,
in_dims
.
size
()));
in_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
init_dims
.
size
(),
3
,
PADDLE_ENFORCE_EQ
(
init_
h_
dims
.
size
(),
3
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The rank of InitH in CudnnLSTM must be 3. But "
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d."
,
"received InitH's rank is %d."
,
init_dims
.
size
()));
init_
h_
dims
.
size
()));
PADDLE_ENFORCE_EQ
(
in_dims
[
1
],
init_dims
[
1
],
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
in_dims
[
1
],
init_h_dims
[
1
],
"The in_dims[1] (Input dims) and init_dims[1] (InitH "
platform
::
errors
::
InvalidArgument
(
"dims) should be equal. But "
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"received in_dims[1] is %d and init_dims[1] is %d."
,
"dims) should be equal. But "
in_dims
[
1
],
init_dims
[
1
]));
"received in_dims[1] is %d and init_h_dims[1] is %d."
,
PADDLE_ENFORCE_EQ
(
in_dims
[
2
],
init_dims
[
2
],
in_dims
[
1
],
init_h_dims
[
1
]));
PADDLE_ENFORCE_EQ
(
init_c_dims
,
init_h_dims
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The
in_dims[2] (Input dims) and init_dims[2] (
InitH "
"The
InitC dims and
InitH "
"dims
)
should be equal. But "
"dims should be equal. But "
"received in
_dims[2] is %d and init_dims[2]
is %d."
,
"received in
it_c_dims is %d and init_h_dims
is %d."
,
in
_dims
[
2
],
init_dims
[
2
]
));
in
it_c_dims
,
init_h_dims
));
auto
out_dims
=
in_dims
;
auto
out_dims
=
in_dims
;
auto
hidden_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"hidden_size"
);
auto
hidden_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"hidden_size"
);
bool
is_bidirec
=
ctx
->
Attrs
().
Get
<
bool
>
(
"is_bidirec"
);
bool
is_bidirec
=
ctx
->
Attrs
().
Get
<
bool
>
(
"is_bidirec"
);
out_dims
[
2
]
=
is_bidirec
?
hidden_size
*
2
:
hidden_size
;
out_dims
[
2
]
=
is_bidirec
?
hidden_size
*
2
:
hidden_size
;
auto
last_dims
=
init_dims
;
last_dims
[
0
]
=
is_bidirec
?
last_dims
[
0
]
*
2
:
last_dims
[
0
];
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
ctx
->
SetOutputDim
(
"LastH"
,
last
_dims
);
ctx
->
SetOutputDim
(
"LastH"
,
init_c
_dims
);
ctx
->
SetOutputDim
(
"LastC"
,
last
_dims
);
ctx
->
SetOutputDim
(
"LastC"
,
init_h
_dims
);
}
}
protected:
protected:
...
@@ -95,7 +96,7 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -95,7 +96,7 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"different batch)"
"different batch)"
"batch_size is the instance number of this batch"
"batch_size is the instance number of this batch"
"input_size is the hidden size of the input."
"input_size is the hidden size of the input."
"input_
hidden_
size and the hidden_size in the next may not be same"
);
"input_size and the hidden_size in the next may not be same"
);
AddInput
(
"InitH"
,
AddInput
(
"InitH"
,
"(Tensor) the initial hidden state of the LSTM"
"(Tensor) the initial hidden state of the LSTM"
"input. This is a tensor with shape (num_layers x batch_size x "
"input. This is a tensor with shape (num_layers x batch_size x "
...
@@ -154,6 +155,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -154,6 +155,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
(
1
);
.
SetDefault
(
1
);
AddAttr
<
bool
>
(
"is_test"
,
"True if in test phase."
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"is_test"
,
"True if in test phase."
).
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"seed to used if fix_seed is True"
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"seed"
,
"seed to used if fix_seed is True"
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
int
>>
(
"sequence_length"
,
"(vector<int>) When the input data is padding, "
"set this parameter. This parameter represents "
"the variable sequence"
"lengths in a batch. The size of the vector has "
"to equal the batch_size."
)
.
SetDefault
({});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
CUDNN LSTM implementation
CUDNN LSTM implementation
...
...
paddle/fluid/operators/cudnn_lstm_op.cu.cc
浏览文件 @
6abd05f2
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -55,50 +56,96 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
...
@@ -55,50 +56,96 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
int
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
auto
sequence_length
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"sequence_length"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
CudnnRNNCache
*
cudnn_rnn_cache
=
new
CudnnRNNCache
();
int
seq_length
=
x
->
dims
()[
0
];
int
batch_size
=
x
->
dims
()[
1
];
int
input_size
=
x
->
dims
()[
2
];
int
weight_numel
=
w
->
numel
();
bool
state_initialized
=
state_out
->
IsInitialized
()
?
true
:
false
;
auto
input_w_numel
=
w
->
numel
();
size_t
workspace_size
;
auto
seq_len
=
x
->
dims
()[
0
];
auto
batch_size
=
x
->
dims
()[
1
];
auto
input_dim
=
x
->
dims
()[
2
];
size_t
reserve_size
;
size_t
reserve_size
;
bool
state_initialized
=
state_out
->
IsInitialized
()
?
true
:
false
;
cudnnDataType_t
cudnn_type
=
platform
::
ToCudnnDataType
(
platform
::
ScopedRNNBase
rnn
(
seq_length
,
batch_size
,
input_size
,
hidden_size
,
framework
::
ToDataType
(
std
::
type_index
(
typeid
(
T
))));
num_layers
,
dropout_prob
,
seed
,
weight_numel
,
cudnn_rnn_cache
->
init
(
handle
,
ctx
.
GetPlace
(),
seq_len
,
batch_size
,
state_initialized
,
is_bidirec
);
input_dim
,
hidden_size
,
num_layers
,
dropout_prob
,
rnn
.
Create
<
T
>
(
handle
,
ctx
.
GetPlace
(),
sequence_length
,
&
workspace_size
,
is_bidirec
,
seed
,
input_w_numel
,
&
reserve_size
,
&
reserve_size
,
state_out
);
state_out
,
state_initialized
,
cudnn_type
);
framework
::
Tensor
workspace_data_
;
workspace_data_
.
Resize
({
static_cast
<
int64_t
>
(
workspace_size
)});
workspace_data_
.
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
auto
*
reserve_data
=
reserve
->
mutable_data
<
uint8_t
>
(
auto
*
reserve_data
=
reserve
->
mutable_data
<
uint8_t
>
(
{
static_cast
<
int64_t
>
(
reserve_size
)},
ctx
.
GetPlace
());
{
static_cast
<
int64_t
>
(
reserve_size
)},
ctx
.
GetPlace
());
if
(
is_test
)
{
if
(
is_test
)
{
// for inference
if
(
sequence_length
.
empty
())
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNForwardInference
(
// for inference
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
seq_len
,
cudnn_rnn_cache
->
x_desc_
,
// This interface is used when the input/output is unpadded.
x_data
,
cudnn_rnn_cache
->
hx_desc_
,
init_h_data
,
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNForwardInference
(
cudnn_rnn_cache
->
cx_desc_
,
init_c_data
,
cudnn_rnn_cache
->
w_desc_
,
handle
,
rnn
.
rnn_desc
(),
seq_length
,
rnn
.
x_desc
(),
x_data
,
w_data
,
cudnn_rnn_cache
->
y_desc_
,
out_data
,
cudnn_rnn_cache
->
hy_desc_
,
rnn
.
hx_desc
(),
init_h_data
,
rnn
.
cx_desc
(),
init_c_data
,
last_h_data
,
cudnn_rnn_cache
->
cy_desc_
,
last_c_data
,
rnn
.
w_desc
(),
w_data
,
rnn
.
y_desc
(),
out_data
,
rnn
.
hy_desc
(),
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
(),
last_h_data
,
rnn
.
cy_desc
(),
last_c_data
,
cudnn_rnn_cache
->
workspace_size_
));
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
));
}
else
{
#if CUDNN_VERSION >= 7201
// for inference
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNForwardInferenceEx
(
handle
,
rnn
.
rnn_desc
(),
rnn
.
x_seq_desc
(),
x_data
,
rnn
.
hx_desc
(),
init_h_data
,
rnn
.
cx_desc
(),
init_c_data
,
rnn
.
w_desc
(),
w_data
,
rnn
.
y_seq_desc
(),
out_data
,
rnn
.
hy_desc
(),
last_h_data
,
rnn
.
cy_desc
(),
last_c_data
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
));
#else
PADDLE_ENFORCE_NOT_NULL
(
nullptr
,
platform
::
errors
::
Unavailable
(
"The padded input is supported by "
"cudnnRNNForwardInferenceEx, but it only works when "
"the version of cudnn is larger than 7.2.1"
));
#endif
}
}
else
{
}
else
{
// for train
if
(
sequence_length
.
empty
())
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNForwardTraining
(
// for train
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
seq_len
,
cudnn_rnn_cache
->
x_desc_
,
// This interface is used when the input/output is unpadded.
x_data
,
cudnn_rnn_cache
->
hx_desc_
,
init_h_data
,
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNForwardTraining
(
cudnn_rnn_cache
->
cx_desc_
,
init_c_data
,
cudnn_rnn_cache
->
w_desc_
,
handle
,
rnn
.
rnn_desc
(),
seq_length
,
rnn
.
x_desc
(),
x_data
,
w_data
,
cudnn_rnn_cache
->
y_desc_
,
out_data
,
cudnn_rnn_cache
->
hy_desc_
,
rnn
.
hx_desc
(),
init_h_data
,
rnn
.
cx_desc
(),
init_c_data
,
last_h_data
,
cudnn_rnn_cache
->
cy_desc_
,
last_c_data
,
rnn
.
w_desc
(),
w_data
,
rnn
.
y_desc
(),
out_data
,
rnn
.
hy_desc
(),
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
(),
last_h_data
,
rnn
.
cy_desc
(),
last_c_data
,
cudnn_rnn_cache
->
workspace_size_
,
reserve_data
,
reserve_size
));
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
,
reserve_data
,
reserve_size
));
}
else
{
#if CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNForwardTrainingEx
(
handle
,
rnn
.
rnn_desc
(),
rnn
.
x_seq_desc
(),
x_data
,
rnn
.
hx_desc
(),
init_h_data
,
rnn
.
cx_desc
(),
init_c_data
,
rnn
.
w_desc
(),
w_data
,
rnn
.
y_seq_desc
(),
out_data
,
rnn
.
hy_desc
(),
last_h_data
,
rnn
.
cy_desc
(),
last_c_data
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
,
reserve_data
,
reserve_size
));
#else
PADDLE_ENFORCE_NOT_NULL
(
nullptr
,
platform
::
errors
::
Unavailable
(
"The padded input is supported by "
"cudnnRNNForwardTrainingEx, but it only works when "
"the version of cudnn is larger than 7.2.1"
));
#endif
}
}
}
delete
cudnn_rnn_cache
;
}
}
};
};
...
@@ -156,44 +203,74 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
...
@@ -156,44 +203,74 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int
hidden_size
=
ctx
.
Attr
<
int
>
(
"hidden_size"
);
int
hidden_size
=
ctx
.
Attr
<
int
>
(
"hidden_size"
);
int
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
int
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
auto
sequence_length
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"sequence_length"
);
CudnnRNNCache
*
cudnn_rnn_cache
=
new
CudnnRNNCache
();
int
seq_length
=
input_dims
[
0
];
int
batch_size
=
input
->
dims
()[
1
];
int
input_size
=
input
->
dims
()[
2
];
int
weight_numel
=
weight
->
numel
();
auto
input_w_numel
=
weight
->
numel
();
size_t
workspace_size
;
auto
seq_len
=
input_dims
[
0
];
auto
batch_size
=
input
->
dims
()[
1
];
auto
input_dim
=
input
->
dims
()[
2
];
size_t
reserve_size
;
size_t
reserve_size
;
cudnnDataType_t
cudnn_type
=
platform
::
ToCudnnDataType
(
framework
::
ToDataType
(
std
::
type_index
(
typeid
(
T
))));
platform
::
ScopedRNNBase
rnn
(
seq_length
,
batch_size
,
input_size
,
hidden_size
,
cudnn_rnn_cache
->
init
(
handle
,
ctx
.
GetPlace
(),
seq_len
,
batch_size
,
num_layers
,
dropout_prob
,
seed
,
weight_numel
,
input_dim
,
hidden_size
,
num_layers
,
dropout_prob
,
true
,
is_bidirec
);
is_bidirec
,
seed
,
input_w_numel
,
&
reserve_size
,
const_cast
<
Tensor
*>
(
state_out
),
true
,
cudnn_type
);
rnn
.
Create
<
T
>
(
handle
,
ctx
.
GetPlace
(),
sequence_length
,
&
workspace_size
,
&
reserve_size
,
const_cast
<
Tensor
*>
(
state_out
));
auto
work_data
=
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
();
framework
::
Tensor
workspace_data_
;
workspace_data_
.
Resize
({
static_cast
<
int64_t
>
(
workspace_size
)});
workspace_data_
.
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
const
uint8_t
*
reserve_data
=
reserve
->
data
<
uint8_t
>
();
const
uint8_t
*
reserve_data
=
reserve
->
data
<
uint8_t
>
();
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNBackwardData
(
if
(
sequence_length
.
empty
())
{
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
seq_len
,
cudnn_rnn_cache
->
y_desc_
,
// This interface is used when the input/output is unpadded.
out_data
,
cudnn_rnn_cache
->
y_desc_
,
out_grad_data
,
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNBackwardData
(
cudnn_rnn_cache
->
hy_desc_
,
last_h_grad_data
,
cudnn_rnn_cache
->
cy_desc_
,
handle
,
rnn
.
rnn_desc
(),
seq_length
,
rnn
.
y_desc
(),
out_data
,
last_c_grad_data
,
cudnn_rnn_cache
->
w_desc_
,
weight_data
,
rnn
.
y_desc
(),
out_grad_data
,
rnn
.
hy_desc
(),
last_h_grad_data
,
cudnn_rnn_cache
->
hx_desc_
,
init_h_data
,
cudnn_rnn_cache
->
cx_desc_
,
rnn
.
cy_desc
(),
last_c_grad_data
,
rnn
.
w_desc
(),
weight_data
,
init_c_data
,
cudnn_rnn_cache
->
x_desc_
,
in_grad_data
,
rnn
.
hx_desc
(),
init_h_data
,
rnn
.
cx_desc
(),
init_c_data
,
rnn
.
x_desc
(),
cudnn_rnn_cache
->
hx_desc_
,
init_h_grad_data
,
cudnn_rnn_cache
->
cx_desc_
,
in_grad_data
,
rnn
.
hx_desc
(),
init_h_grad_data
,
rnn
.
cx_desc
(),
init_c_grad_data
,
work_data
,
cudnn_rnn_cache
->
workspace_size_
,
init_c_grad_data
,
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
,
const_cast
<
uint8_t
*>
(
reserve_data
),
reserve_size
));
const_cast
<
uint8_t
*>
(
reserve_data
),
reserve_size
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNBackwardWeights
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNBackwardWeights
(
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
seq_len
,
cudnn_rnn_cache
->
x_desc_
,
handle
,
rnn
.
rnn_desc
(),
seq_length
,
rnn
.
x_desc
(),
input
->
data
<
T
>
(),
input
->
data
<
T
>
(),
cudnn_rnn_cache
->
hx_desc_
,
init_h
->
data
<
T
>
(),
rnn
.
hx_desc
(),
init_h
->
data
<
T
>
(),
rnn
.
y_desc
(),
out
->
data
<
T
>
(),
cudnn_rnn_cache
->
y_desc_
,
out
->
data
<
T
>
(),
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
,
rnn
.
w_desc
(),
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
(),
weight_grad
->
data
<
T
>
(),
const_cast
<
uint8_t
*>
(
reserve_data
),
cudnn_rnn_cache
->
workspace_size_
,
cudnn_rnn_cache
->
w_desc_
,
reserve_size
));
weight_grad
->
data
<
T
>
(),
const_cast
<
uint8_t
*>
(
reserve_data
),
}
else
{
reserve_size
));
#if CUDNN_VERSION >= 7201
delete
cudnn_rnn_cache
;
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNBackwardDataEx
(
handle
,
rnn
.
rnn_desc
(),
rnn
.
y_seq_desc
(),
out_data
,
rnn
.
y_seq_desc
(),
out_grad_data
,
nullptr
,
nullptr
,
rnn
.
hy_desc
(),
last_h_grad_data
,
rnn
.
cy_desc
(),
last_c_grad_data
,
rnn
.
w_desc
(),
weight_data
,
rnn
.
hx_desc
(),
init_h_data
,
rnn
.
cx_desc
(),
init_c_data
,
rnn
.
x_seq_desc
(),
in_grad_data
,
rnn
.
hx_desc
(),
init_h_grad_data
,
rnn
.
cx_desc
(),
init_c_grad_data
,
nullptr
,
nullptr
,
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
,
const_cast
<
uint8_t
*>
(
reserve_data
),
reserve_size
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnRNNBackwardWeightsEx
(
handle
,
rnn
.
rnn_desc
(),
rnn
.
x_seq_desc
(),
input
->
data
<
T
>
(),
rnn
.
hx_desc
(),
init_h
->
data
<
T
>
(),
rnn
.
y_seq_desc
(),
out
->
data
<
T
>
(),
workspace_data_
.
data
<
uint8_t
>
(),
workspace_size
,
rnn
.
w_desc
(),
weight_grad
->
data
<
T
>
(),
const_cast
<
uint8_t
*>
(
reserve_data
),
reserve_size
));
#else
PADDLE_ENFORCE_NOT_NULL
(
nullptr
,
platform
::
errors
::
Unavailable
(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
"cudnnRNNBackwardWeightsEx, but it only works when the version "
"of cudnn is larger than 7.2.1"
));
#endif
}
}
}
};
};
...
...
paddle/fluid/operators/reduce_ops/logsumexp_op.cu
浏览文件 @
6abd05f2
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
REGISTER_OP_CUDA_KERNEL
(
logsumexp
,
REGISTER_OP_CUDA_KERNEL
(
logsumexp
,
...
@@ -20,8 +19,3 @@ REGISTER_OP_CUDA_KERNEL(logsumexp,
...
@@ -20,8 +19,3 @@ REGISTER_OP_CUDA_KERNEL(logsumexp,
float
,
ops
::
LogsumexpFunctor
>
,
float
,
ops
::
LogsumexpFunctor
>
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
LogsumexpFunctor
>
);
double
,
ops
::
LogsumexpFunctor
>
);
REGISTER_OP_CUDA_KERNEL
(
logsumexp_grad
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
LogsumexpGradFunctor
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
LogsumexpGradFunctor
>
);
paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu
0 → 100644
浏览文件 @
6abd05f2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// .part used to speed up nvcc compile
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
REGISTER_OP_CUDA_KERNEL
(
logsumexp_grad
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
LogsumexpGradFunctor
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
LogsumexpGradFunctor
>
);
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
6abd05f2
...
@@ -273,11 +273,116 @@ class ScopedTensorDescriptor {
...
@@ -273,11 +273,116 @@ class ScopedTensorDescriptor {
groups
);
groups
);
}
}
inline
cudnnTensorDescriptor_t
descriptor
(
const
cudnnDataType_t
cudnn_type
,
const
std
::
vector
<
int
>&
dim
,
const
std
::
vector
<
int
>&
stride
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnSetTensorNdDescriptor
(
desc_
,
cudnn_type
,
dim
.
size
(),
dim
.
data
(),
stride
.
data
()));
return
desc_
;
}
template
<
typename
T
>
inline
cudnnTensorDescriptor_t
descriptor
(
const
std
::
vector
<
int
>&
dim
,
const
std
::
vector
<
int
>&
stride
)
{
return
descriptor
(
CudnnDataType
<
T
>::
type
,
dim
,
stride
);
}
private:
private:
cudnnTensorDescriptor_t
desc_
;
cudnnTensorDescriptor_t
desc_
;
DISABLE_COPY_AND_ASSIGN
(
ScopedTensorDescriptor
);
DISABLE_COPY_AND_ASSIGN
(
ScopedTensorDescriptor
);
};
};
class
ScopedRNNTensorDescriptor
{
public:
ScopedRNNTensorDescriptor
()
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnCreateRNNDataDescriptor
(
&
desc_
));
}
~
ScopedRNNTensorDescriptor
()
PADDLE_MAY_THROW
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnDestroyRNNDataDescriptor
(
desc_
));
}
inline
cudnnRNNDataDescriptor_t
descriptor
(
const
cudnnDataType_t
cudnn_type
,
int
max_seq_length
,
int
batch_size
,
int
input_size
,
bool
time_major
,
const
std
::
vector
<
int
>&
seq_length
)
{
static
float
padding_fill
=
0.0
f
;
cudnnRNNDataLayout_t
layout
;
if
(
time_major
)
{
layout
=
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED
;
}
else
{
layout
=
CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED
;
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnSetRNNDataDescriptor
(
desc_
,
cudnn_type
,
layout
,
max_seq_length
,
batch_size
,
input_size
,
seq_length
.
data
(),
static_cast
<
void
*>
(
&
padding_fill
)));
return
desc_
;
}
template
<
typename
T
>
inline
cudnnRNNDataDescriptor_t
descriptor
(
int
max_length
,
int
batch_size
,
int
input_size
,
bool
time_major
,
const
std
::
vector
<
int
>&
seq_length
)
{
return
descriptor
(
CudnnDataType
<
T
>::
type
,
max_length
,
batch_size
,
input_size
,
time_major
,
seq_length
);
}
private:
cudnnRNNDataDescriptor_t
desc_
;
DISABLE_COPY_AND_ASSIGN
(
ScopedRNNTensorDescriptor
);
};
class
ScopedDropoutDescriptor
{
public:
ScopedDropoutDescriptor
()
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnCreateDropoutDescriptor
(
&
desc_
));
}
~
ScopedDropoutDescriptor
()
PADDLE_MAY_THROW
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnDestroyDropoutDescriptor
(
desc_
));
}
inline
cudnnDropoutDescriptor_t
descriptor
(
const
cudnnHandle_t
&
handle
,
const
platform
::
Place
&
place
,
bool
initialized
,
float
dropout_prob_
,
framework
::
Tensor
*
dropout_state_
,
int
seed
,
size_t
state_size
)
{
auto
*
dropout_state_data
=
dropout_state_
->
data
<
uint8_t
>
();
if
(
!
initialized
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnSetDropoutDescriptor
(
desc_
,
handle
,
dropout_prob_
,
dropout_state_data
,
state_size
,
seed
));
}
else
{
auto
dropout_state_dims
=
dropout_state_
->
dims
();
state_size
=
dropout_state_dims
[
0
];
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnRestoreDropoutDescriptor
(
desc_
,
handle
,
dropout_prob_
,
dropout_state_data
,
state_size
,
0
));
}
return
desc_
;
}
private:
cudnnDropoutDescriptor_t
desc_
;
DISABLE_COPY_AND_ASSIGN
(
ScopedDropoutDescriptor
);
};
class
ScopedRNNDescriptor
{
public:
ScopedRNNDescriptor
()
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnCreateRNNDescriptor
(
&
desc_
));
}
~
ScopedRNNDescriptor
()
PADDLE_MAY_THROW
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnDestroyRNNDescriptor
(
desc_
));
}
inline
cudnnRNNDescriptor_t
descriptor
()
{
return
desc_
;
}
private:
cudnnRNNDescriptor_t
desc_
;
DISABLE_COPY_AND_ASSIGN
(
ScopedRNNDescriptor
);
};
class
ScopedFilterDescriptor
{
class
ScopedFilterDescriptor
{
public:
public:
ScopedFilterDescriptor
()
{
ScopedFilterDescriptor
()
{
...
@@ -319,6 +424,167 @@ class ScopedFilterDescriptor {
...
@@ -319,6 +424,167 @@ class ScopedFilterDescriptor {
DISABLE_COPY_AND_ASSIGN
(
ScopedFilterDescriptor
);
DISABLE_COPY_AND_ASSIGN
(
ScopedFilterDescriptor
);
};
};
class
ScopedRNNBase
{
public:
ScopedRNNBase
(
int
seq_length
,
int
batch_size
,
int
input_size
,
int
hidden_size
,
int
num_layers
,
float
dropout_prob
,
int
seed
,
int
weight_numel
,
bool
initialized
,
bool
is_bidirec
)
:
seq_length_
(
seq_length
),
batch_size_
(
batch_size
),
input_size_
(
input_size
),
hidden_size_
(
hidden_size
),
num_layers_
(
num_layers
),
dropout_prob_
(
dropout_prob
),
seed_
(
seed
),
weight_numel_
(
weight_numel
),
initialized_
(
initialized
),
is_bidirec_
(
is_bidirec
)
{}
template
<
typename
T
>
void
Create
(
const
cudnnHandle_t
&
handle
,
const
platform
::
Place
&
place
,
std
::
vector
<
int
>
sequence_length
,
size_t
*
workspace_size
,
size_t
*
reserve_size
,
framework
::
Tensor
*
dropout_state
)
{
int
numDirections
=
is_bidirec_
?
2
:
1
;
cudnnDataType_t
cudnn_type
=
platform
::
CudnnDataType
<
T
>::
type
;
// ------------------- cudnn x, y descriptors ---------------------
std
::
vector
<
int
>
dims_x
=
{
batch_size_
,
input_size_
,
1
};
std
::
vector
<
int
>
strides_x
=
{
input_size_
,
1
,
1
};
std
::
vector
<
int
>
dims_y
=
{
batch_size_
,
hidden_size_
*
numDirections
,
1
};
std
::
vector
<
int
>
strides_y
=
{
hidden_size_
*
numDirections
,
1
,
1
};
for
(
int
i
=
0
;
i
<
seq_length_
;
++
i
)
{
x_desc_
.
emplace_back
(
x_d
.
descriptor
<
T
>
(
dims_x
,
strides_x
));
y_desc_
.
emplace_back
(
y_d
.
descriptor
<
T
>
(
dims_y
,
strides_y
));
}
if
(
!
sequence_length
.
empty
())
{
x_seq_desc_
=
x_seq_d
.
descriptor
<
T
>
(
seq_length_
,
batch_size_
,
input_size_
,
true
,
sequence_length
);
y_seq_desc_
=
y_seq_d
.
descriptor
<
T
>
(
seq_length_
,
batch_size_
,
hidden_size_
*
numDirections
,
true
,
sequence_length
);
}
// ------------------- cudnn hx, hy, cx, cy descriptors----------
std
::
vector
<
int
>
dims_hx
=
{
num_layers_
*
numDirections
,
batch_size_
,
hidden_size_
};
std
::
vector
<
int
>
strides_hx
=
{
hidden_size_
*
batch_size_
,
hidden_size_
,
1
};
hx_desc_
=
hx_d
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
cx_desc_
=
cx_d
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
hy_desc_
=
hy_d
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
cy_desc_
=
cy_d
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
// ------------------- cudnn dropout descriptors ---------------------
size_t
state_size
;
if
(
!
initialized_
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
cudnnDropoutGetStatesSize
(
handle
,
&
state_size
));
dropout_state
->
mutable_data
<
uint8_t
>
({
static_cast
<
int64_t
>
(
state_size
)},
place
);
}
dropout_desc_
=
dropout_d
.
descriptor
(
handle
,
place
,
initialized_
,
dropout_prob_
,
dropout_state
,
seed_
,
state_size
);
// ------------------- cudnn rnn descriptors ---------------------
rnn_desc_
=
rnn_d
.
descriptor
();
#if CUDNN_VERSION >= 6000
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetRNNDescriptor_v6
(
handle
,
rnn_desc_
,
hidden_size_
,
num_layers_
,
dropout_desc_
,
CUDNN_LINEAR_INPUT
,
is_bidirec_
?
CUDNN_BIDIRECTIONAL
:
CUDNN_UNIDIRECTIONAL
,
CUDNN_LSTM
,
CUDNN_RNN_ALGO_STANDARD
,
cudnn_type
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetRNNDescriptor
(
rnn_desc_
,
hidden_size_
,
num_layers_
,
dropout_desc_
,
CUDNN_LINEAR_INPUT
,
is_bidirec_
?
CUDNN_BIDIRECTIONAL
:
CUDNN_UNIDIRECTIONAL
,
CUDNN_LSTM
,
cudnn_type
));
#endif
if
(
!
sequence_length
.
empty
())
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetRNNPaddingMode
(
rnn_desc_
,
CUDNN_RNN_PADDED_IO_ENABLED
));
}
// ------------------- cudnn weights_size ---------------------
size_t
weights_size_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnGetRNNParamsSize
(
handle
,
rnn_desc_
,
x_desc_
[
0
],
&
weights_size_
,
cudnn_type
));
PADDLE_ENFORCE_EQ
(
weights_size_
,
sizeof
(
T
)
*
weight_numel_
,
platform
::
errors
::
InvalidArgument
(
"The cudnn lstm and setting weight size should be same."
));
// ------------------- cudnn weight descriptors ---------------------
platform
::
DataLayout
layout
=
platform
::
DataLayout
::
kNCHW
;
int
dim_tmp
=
weights_size_
/
sizeof
(
T
);
std
::
vector
<
int
>
dim_w
=
{
dim_tmp
,
1
,
1
};
w_desc_
=
w_d
.
descriptor
<
T
>
(
layout
,
dim_w
);
// ------------------- cudnn workspace, reserve size ---------------------
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnGetRNNWorkspaceSize
(
handle
,
rnn_desc_
,
seq_length_
,
x_desc_
.
data
(),
workspace_size
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnGetRNNTrainingReserveSize
(
handle
,
rnn_desc_
,
seq_length_
,
x_desc_
.
data
(),
reserve_size
));
}
cudnnTensorDescriptor_t
*
x_desc
()
{
return
x_desc_
.
data
();
}
cudnnTensorDescriptor_t
*
y_desc
()
{
return
y_desc_
.
data
();
}
cudnnRNNDataDescriptor_t
x_seq_desc
()
{
return
x_seq_desc_
;
}
cudnnRNNDataDescriptor_t
y_seq_desc
()
{
return
y_seq_desc_
;
}
cudnnTensorDescriptor_t
hx_desc
()
{
return
hx_desc_
;
}
cudnnTensorDescriptor_t
cx_desc
()
{
return
cx_desc_
;
}
cudnnTensorDescriptor_t
hy_desc
()
{
return
hy_desc_
;
}
cudnnTensorDescriptor_t
cy_desc
()
{
return
cy_desc_
;
}
cudnnRNNDescriptor_t
rnn_desc
()
{
return
rnn_desc_
;
}
cudnnDropoutDescriptor_t
dropout_desc
()
{
return
dropout_desc_
;
}
cudnnFilterDescriptor_t
w_desc
()
{
return
w_desc_
;
}
private:
int
seq_length_
;
int
batch_size_
;
int
input_size_
;
int
hidden_size_
;
int
num_layers_
;
float
dropout_prob_
;
int
seed_
;
int
weight_numel_
;
bool
initialized_
;
bool
is_bidirec_
;
std
::
vector
<
cudnnTensorDescriptor_t
>
x_desc_
;
std
::
vector
<
cudnnTensorDescriptor_t
>
y_desc_
;
cudnnRNNDataDescriptor_t
x_seq_desc_
;
cudnnRNNDataDescriptor_t
y_seq_desc_
;
// A tensor descriptor describing the initial hidden state of the RNN.
cudnnTensorDescriptor_t
hx_desc_
;
// A tensor descriptor describing the initial cell state for LSTM networks.
cudnnTensorDescriptor_t
cx_desc_
;
// A tensor descriptor describing the final hidden state of the RNN.
cudnnTensorDescriptor_t
hy_desc_
;
// A tensor descriptor describing the final cell state for LSTM networks.
cudnnTensorDescriptor_t
cy_desc_
;
cudnnDropoutDescriptor_t
dropout_desc_
;
cudnnFilterDescriptor_t
w_desc_
;
cudnnRNNDescriptor_t
rnn_desc_
;
ScopedTensorDescriptor
x_d
;
ScopedTensorDescriptor
y_d
;
ScopedRNNTensorDescriptor
x_seq_d
;
ScopedRNNTensorDescriptor
y_seq_d
;
ScopedTensorDescriptor
hx_d
;
ScopedTensorDescriptor
cx_d
;
ScopedTensorDescriptor
hy_d
;
ScopedTensorDescriptor
cy_d
;
ScopedDropoutDescriptor
dropout_d
;
ScopedFilterDescriptor
w_d
;
ScopedRNNDescriptor
rnn_d
;
};
class
ScopedConvolutionDescriptor
{
class
ScopedConvolutionDescriptor
{
public:
public:
ScopedConvolutionDescriptor
()
{
ScopedConvolutionDescriptor
()
{
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
6abd05f2
...
@@ -101,6 +101,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
...
@@ -101,6 +101,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNWorkspaceSize); \
...
@@ -109,6 +112,11 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
...
@@ -109,6 +112,11 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx);
__macro(cudnnSetTensorNdDescriptorEx);
...
...
paddle/scripts/paddle_build.bat
浏览文件 @
6abd05f2
...
@@ -58,7 +58,7 @@ if not defined WITH_AVX set WITH_AVX=ON
...
@@ -58,7 +58,7 @@ if not defined WITH_AVX set WITH_AVX=ON
if
not
defined
WITH_TESTING
set
WITH_TESTING
=
ON
if
not
defined
WITH_TESTING
set
WITH_TESTING
=
ON
if
not
defined
WITH_PYTHON
set
WITH_PYTHON
=
ON
if
not
defined
WITH_PYTHON
set
WITH_PYTHON
=
ON
if
not
defined
ON_INFER
set
ON_INFER
=
ON
if
not
defined
ON_INFER
set
ON_INFER
=
ON
if
not
defined
WITH_INFERENCE_API_TEST
set
WITH_INFERENCE_API_TEST
=
O
FF
if
not
defined
WITH_INFERENCE_API_TEST
set
WITH_INFERENCE_API_TEST
=
O
N
if
not
defined
WITH_TPCACHE
set
WITH_TPCACHE
=
ON
if
not
defined
WITH_TPCACHE
set
WITH_TPCACHE
=
ON
rem ------set cache third_party------
rem ------set cache third_party------
...
...
python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
浏览文件 @
6abd05f2
...
@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
...
@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from
paddle.fluid.dygraph.dygraph_to_static.logging_utils
import
TranslatorLogger
from
paddle.fluid.dygraph.dygraph_to_static.logging_utils
import
TranslatorLogger
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
StaticLayer
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
StaticLayer
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
convert_to_static
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
convert_to_static
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
unwrap_decorators
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.fluid.dygraph.layers
import
Layer
# TODO(liym27): A better way to do this.
# TODO(liym27): A better way to do this.
...
@@ -118,14 +119,9 @@ def convert_call(func):
...
@@ -118,14 +119,9 @@ def convert_call(func):
func_self
=
None
func_self
=
None
converted_call
=
None
converted_call
=
None
# Function in convert_call may be decorated by another `@
declarative
`,
# Function in convert_call may be decorated by another `@
to_static
`,
# in this case, unwraps it into a raw method or function.
# in this case, unwraps it into a raw method or function.
if
isinstance
(
func
,
StaticLayer
):
_
,
func
=
unwrap_decorators
(
func
)
instance
=
func
.
_class_instance
if
instance
is
not
None
:
func
=
func
.
dygraph_function
.
__get__
(
instance
)
else
:
func
=
func
.
dygraph_function
if
is_builtin_len
(
func
):
if
is_builtin_len
(
func
):
return
convert_len
return
convert_len
...
@@ -155,7 +151,8 @@ def convert_call(func):
...
@@ -155,7 +151,8 @@ def convert_call(func):
if
inspect
.
isfunction
(
fn
):
if
inspect
.
isfunction
(
fn
):
global_functions
.
add
(
fn
)
global_functions
.
add
(
fn
)
elif
isinstance
(
fn
,
StaticLayer
):
elif
isinstance
(
fn
,
StaticLayer
):
global_functions
.
add
(
fn
.
dygraph_function
)
_
,
fn
=
unwrap_decorators
(
fn
)
global_functions
.
add
(
fn
)
if
func
in
global_functions
:
if
func
in
global_functions
:
converted_call
=
convert_to_static
(
func
)
converted_call
=
convert_to_static
(
func
)
...
@@ -189,7 +186,8 @@ def convert_call(func):
...
@@ -189,7 +186,8 @@ def convert_call(func):
elif
hasattr
(
func
,
'__class__'
)
and
hasattr
(
func
.
__class__
,
'__call__'
):
elif
hasattr
(
func
,
'__class__'
)
and
hasattr
(
func
.
__class__
,
'__call__'
):
if
hasattr
(
func
,
'forward'
)
and
isinstance
(
func
,
Layer
):
if
hasattr
(
func
,
'forward'
)
and
isinstance
(
func
,
Layer
):
try
:
try
:
forward_func
=
convert_to_static
(
func
.
forward
)
_
,
forward_func
=
unwrap_decorators
(
func
.
forward
)
forward_func
=
convert_to_static
(
forward_func
)
setattr
(
func
,
'forward'
,
forward_func
)
setattr
(
func
,
'forward'
,
forward_func
)
func_self
=
func
func_self
=
func
except
Exception
:
except
Exception
:
...
...
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
6abd05f2
...
@@ -21,6 +21,7 @@ import six
...
@@ -21,6 +21,7 @@ import six
import
textwrap
import
textwrap
import
threading
import
threading
import
warnings
import
warnings
import
weakref
import
gast
import
gast
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
...
@@ -245,6 +246,7 @@ class StaticLayer(object):
...
@@ -245,6 +246,7 @@ class StaticLayer(object):
self
.
_input_spec
=
input_spec
self
.
_input_spec
=
input_spec
self
.
_function_spec
=
FunctionSpec
(
function
,
input_spec
)
self
.
_function_spec
=
FunctionSpec
(
function
,
input_spec
)
self
.
_program_cache
=
ProgramCache
()
self
.
_program_cache
=
ProgramCache
()
self
.
_descriptor_cache
=
weakref
.
WeakKeyDictionary
()
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
self
.
_program_trans
=
ProgramTranslator
()
self
.
_program_trans
=
ProgramTranslator
()
...
@@ -271,8 +273,19 @@ class StaticLayer(object):
...
@@ -271,8 +273,19 @@ class StaticLayer(object):
of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__`
of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__`
to parse the class instance correctly instead of the `StaticLayer` instance.
to parse the class instance correctly instead of the `StaticLayer` instance.
"""
"""
self
.
_class_instance
=
instance
if
instance
not
in
self
.
_descriptor_cache
:
return
self
if
instance
is
None
:
return
self
# Note(Aurelius84): To construct new instance of StaticLayer when we
# first encouter the bound function of layer and cache it.
new_static_layer
=
self
.
_clone
()
new_static_layer
.
_class_instance
=
instance
self
.
_descriptor_cache
[
instance
]
=
new_static_layer
return
self
.
_descriptor_cache
[
instance
]
def
_clone
(
self
):
return
self
.
__class__
(
self
.
_dygraph_function
,
self
.
_input_spec
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
"""
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py
浏览文件 @
6abd05f2
...
@@ -19,7 +19,7 @@ import paddle
...
@@ -19,7 +19,7 @@ import paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.static
import
InputSpec
from
paddle.static
import
InputSpec
from
paddle.fluid.dygraph
import
to_variable
,
declarative
,
ProgramTranslator
,
Layer
,
jit
from
paddle.fluid.dygraph
import
to_variable
,
declarative
,
ProgramTranslator
,
Layer
,
jit
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
ConcreteProgram
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
ConcreteProgram
,
StaticLayer
from
test_basic_api_transformation
import
dyfunc_to_variable
from
test_basic_api_transformation
import
dyfunc_to_variable
...
@@ -84,6 +84,23 @@ class SimpleNet(Layer):
...
@@ -84,6 +84,23 @@ class SimpleNet(Layer):
return
z
return
z
class
TestStaticLayerInstance
(
unittest
.
TestCase
):
def
test_instance_same_class
(
self
):
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
net_1
=
SimpleNet
()
net_2
=
SimpleNet
()
self
.
assertTrue
(
isinstance
(
net_1
.
forward
,
StaticLayer
))
self
.
assertTrue
(
isinstance
(
net_2
.
forward
,
StaticLayer
))
self
.
assertNotEqual
(
net_1
.
forward
,
net_2
.
forward
)
# convert layer into static progam of net_1
net_1
.
forward
.
concrete_program
self
.
assertTrue
(
len
(
net_1
.
forward
.
program_cache
)
==
1
)
# check no conversion applid with net_2
self
.
assertTrue
(
len
(
net_2
.
forward
.
program_cache
)
==
0
)
class
TestInputSpec
(
unittest
.
TestCase
):
class
TestInputSpec
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
pass
pass
...
@@ -224,7 +241,6 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
...
@@ -224,7 +241,6 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
# 1. specific InputSpec for `x`/`y`
# 1. specific InputSpec for `x`/`y`
concrete_program_1
=
foo
.
get_concrete_program
(
concrete_program_1
=
foo
.
get_concrete_program
(
InputSpec
([
None
,
10
]),
InputSpec
([
10
]))
InputSpec
([
None
,
10
]),
InputSpec
([
10
]))
print
(
concrete_program_1
)
self
.
assertTrue
(
len
(
foo
.
program_cache
)
==
1
)
self
.
assertTrue
(
len
(
foo
.
program_cache
)
==
1
)
# 2. specific `c`/`d` explicitly with same default value
# 2. specific `c`/`d` explicitly with same default value
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
浏览文件 @
6abd05f2
...
@@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase):
...
@@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase):
x
=
fluid
.
dygraph
.
to_variable
(
x_data
)
x
=
fluid
.
dygraph
.
to_variable
(
x_data
)
out
=
net
(
x
)
out
=
net
(
x
)
program_cache
=
SimpleFcLayer
.
forward
.
program_cache
program_cache
=
net
.
forward
.
program_cache
_
,
(
concrete_program
,
_
)
=
program_cache
.
last
()
_
,
(
concrete_program
,
_
)
=
program_cache
.
last
()
params
=
concrete_program
.
parameters
params
=
concrete_program
.
parameters
...
...
python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
浏览文件 @
6abd05f2
...
@@ -16,6 +16,7 @@ from __future__ import print_function
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
math
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
from
op_test
import
OpTest
...
@@ -27,120 +28,372 @@ SIGMOID_THRESHOLD_MAX = 13.0
...
@@ -27,120 +28,372 @@ SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT
=
40.0
EXP_MAX_INPUT
=
40.0
def
lstm_naive
(
input
,
w
):
class
LayerMixin
(
object
):
seq_len
,
batch_size
,
hidden_size
=
input
.
shape
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
forward
(
*
args
,
**
kwargs
)
offset
=
0
wi
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
(
hidden_size
,
hidden_size
)).
transpose
()
class
LayerListMixin
(
LayerMixin
):
offset
+=
hidden_size
*
hidden_size
def
__init__
(
self
,
layers
=
None
):
wf
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
self
.
_layers
=
list
(
layers
)
if
layers
else
[]
(
hidden_size
,
hidden_size
)).
transpose
()
offset
+=
hidden_size
*
hidden_size
def
append
(
self
,
layer
):
wc
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
self
.
_layers
.
append
(
layer
)
(
hidden_size
,
hidden_size
)).
transpose
()
offset
+=
hidden_size
*
hidden_size
def
__iter__
(
self
):
wo
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
return
iter
(
self
.
_layers
)
(
hidden_size
,
hidden_size
)).
transpose
()
offset
+=
hidden_size
*
hidden_size
ri
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
class
LSTMCell
(
LayerMixin
):
(
hidden_size
,
hidden_size
)).
transpose
()
def
__init__
(
self
,
input_size
,
hidden_size
,
bias
=
True
):
offset
+=
hidden_size
*
hidden_size
self
.
input_size
=
input_size
rf
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
self
.
hidden_size
=
hidden_size
(
hidden_size
,
hidden_size
)).
transpose
()
self
.
bias
=
bias
offset
+=
hidden_size
*
hidden_size
self
.
dtype
=
np
.
float64
rc
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
self
.
parameters
=
dict
()
(
hidden_size
,
hidden_size
)).
transpose
()
std
=
1.0
/
math
.
sqrt
(
hidden_size
)
offset
+=
hidden_size
*
hidden_size
self
.
weight_ih
=
np
.
ones
(
ro
=
w
[
offset
:
offset
+
hidden_size
*
hidden_size
].
reshape
(
(
4
*
hidden_size
,
input_size
),
dtype
=
self
.
dtype
)
(
hidden_size
,
hidden_size
)).
transpose
()
self
.
weight_hh
=
np
.
ones
((
4
*
hidden_size
,
offset
+=
hidden_size
*
hidden_size
hidden_size
)).
astype
(
self
.
dtype
)
self
.
parameters
[
'weight_ih'
]
=
self
.
weight_ih
bi_1
=
w
[
offset
:
offset
+
hidden_size
]
self
.
parameters
[
'weight_hh'
]
=
self
.
weight_hh
offset
+=
hidden_size
if
bias
:
bf_1
=
w
[
offset
:
offset
+
hidden_size
]
self
.
bias_ih
=
np
.
ones
((
4
*
hidden_size
)).
astype
(
self
.
dtype
)
offset
+=
hidden_size
self
.
bias_hh
=
np
.
ones
((
4
*
hidden_size
)).
astype
(
self
.
dtype
)
bc_1
=
w
[
offset
:
offset
+
hidden_size
]
self
.
parameters
[
'bias_ih'
]
=
self
.
bias_ih
offset
+=
hidden_size
self
.
parameters
[
'bias_hh'
]
=
self
.
bias_hh
bo_1
=
w
[
offset
:
offset
+
hidden_size
]
else
:
offset
+=
hidden_size
self
.
bias_ih
=
None
self
.
bias_hh
=
None
bi_2
=
w
[
offset
:
offset
+
hidden_size
]
offset
+=
hidden_size
def
init_state
(
self
,
inputs
):
bf_2
=
w
[
offset
:
offset
+
hidden_size
]
batch_size
=
inputs
.
shape
[
0
]
offset
+=
hidden_size
init_h
=
np
.
zeros
((
batch_size
,
self
.
hidden_size
),
dtype
=
inputs
.
dtype
)
bc_2
=
w
[
offset
:
offset
+
hidden_size
]
init_c
=
np
.
zeros
((
batch_size
,
self
.
hidden_size
),
dtype
=
inputs
.
dtype
)
offset
+=
hidden_size
return
init_h
,
init_c
bo_2
=
w
[
offset
:
offset
+
hidden_size
]
def
forward
(
self
,
inputs
,
hx
=
None
):
def
sigmoid
(
x
):
if
hx
is
None
:
y
=
np
.
copy
(
x
)
hx
=
self
.
init_state
(
inputs
)
y
[
x
<
SIGMOID_THRESHOLD_MIN
]
=
SIGMOID_THRESHOLD_MIN
pre_hidden
,
pre_cell
=
hx
y
[
x
>
SIGMOID_THRESHOLD_MAX
]
=
SIGMOID_THRESHOLD_MAX
gates
=
np
.
matmul
(
inputs
,
self
.
weight_ih
.
T
)
return
1.
/
(
1.
+
np
.
exp
(
-
y
))
if
self
.
bias_ih
is
not
None
:
gates
=
gates
+
self
.
bias_ih
def
tanh
(
x
):
gates
+=
np
.
matmul
(
pre_hidden
,
self
.
weight_hh
.
T
)
y
=
-
2.
*
x
if
self
.
bias_hh
is
not
None
:
y
[
y
>
EXP_MAX_INPUT
]
=
EXP_MAX_INPUT
gates
=
gates
+
self
.
bias_hh
return
(
2.
/
(
1.
+
np
.
exp
(
y
)))
-
1.
chunked_gates
=
np
.
split
(
gates
,
4
,
-
1
)
output
=
[]
pre_h
=
np
.
zeros
((
1
,
batch_size
,
hidden_size
),
dtype
=
input
.
dtype
)
i
=
1.0
/
(
1.0
+
np
.
exp
(
-
chunked_gates
[
0
]))
pre_c
=
np
.
zeros
((
1
,
batch_size
,
hidden_size
),
dtype
=
input
.
dtype
)
f
=
1.0
/
(
1.0
+
np
.
exp
(
-
chunked_gates
[
1
]))
o
=
1.0
/
(
1.0
+
np
.
exp
(
-
chunked_gates
[
3
]))
for
i
in
range
(
seq_len
):
c
=
f
*
pre_cell
+
i
*
np
.
tanh
(
chunked_gates
[
2
])
emb_1
=
input
[
i
]
h
=
o
*
np
.
tanh
(
c
)
input_gate
=
sigmoid
(
return
h
,
(
h
,
c
)
np
.
matmul
(
emb_1
,
wi
)
+
np
.
matmul
(
pre_h
,
ri
)
+
bi_1
+
bi_2
)
forget_gate
=
sigmoid
(
np
.
matmul
(
emb_1
,
wf
)
+
np
.
matmul
(
pre_h
,
rf
)
+
bf_1
+
bf_2
)
def
sequence_mask
(
lengths
,
max_len
=
None
):
output_gate
=
sigmoid
(
if
max_len
is
None
:
np
.
matmul
(
emb_1
,
wo
)
+
np
.
matmul
(
pre_h
,
ro
)
+
bo_1
+
bo_2
)
max_len
=
np
.
max
(
lengths
)
c_t_temp
=
tanh
(
else
:
np
.
matmul
(
emb_1
,
wc
)
+
np
.
matmul
(
pre_h
,
rc
)
+
bc_1
+
bc_2
)
assert
max_len
>=
np
.
max
(
lengths
)
new_c
=
input_gate
*
c_t_temp
+
forget_gate
*
pre_c
return
np
.
arange
(
max_len
)
<
np
.
expand_dims
(
lengths
,
-
1
)
new_h
=
output_gate
*
tanh
(
new_c
)
pre_h
=
new_h
def
update_state
(
mask
,
new
,
old
):
pre_c
=
new_c
if
not
isinstance
(
old
,
(
tuple
,
list
)):
return
np
.
where
(
mask
,
new
,
old
)
output
.
append
(
new_h
)
else
:
return
tuple
(
map
(
lambda
x
,
y
:
np
.
where
(
mask
,
x
,
y
),
new
,
old
))
output
=
np
.
concatenate
(
output
,
-
1
)
output
=
output
.
reshape
((
batch_size
,
-
1
,
hidden_size
))
output
=
output
.
transpose
((
1
,
0
,
2
))
def
rnn
(
cell
,
inputs
,
return
output
,
pre_h
,
pre_c
initial_states
,
sequence_length
=
None
,
time_major
=
False
,
is_reverse
=
False
):
if
not
time_major
:
inputs
=
np
.
transpose
(
inputs
,
[
1
,
0
,
2
])
if
is_reverse
:
inputs
=
np
.
flip
(
inputs
,
0
)
if
sequence_length
is
None
:
mask
=
None
else
:
mask
=
np
.
transpose
(
sequence_mask
(
sequence_length
),
[
1
,
0
])
mask
=
np
.
expand_dims
(
mask
,
-
1
)
if
is_reverse
:
mask
=
np
.
flip
(
mask
,
0
)
time_steps
=
inputs
.
shape
[
0
]
state
=
initial_states
outputs
=
[]
for
t
in
range
(
time_steps
):
x_t
=
inputs
[
t
]
if
mask
is
not
None
:
m_t
=
mask
[
t
]
y
,
new_state
=
cell
(
x_t
,
state
)
y
=
np
.
where
(
m_t
,
y
,
0.
)
outputs
.
append
(
y
)
state
=
update_state
(
m_t
,
new_state
,
state
)
else
:
y
,
new_state
=
cell
(
x_t
,
state
)
outputs
.
append
(
y
)
state
=
new_state
outputs
=
np
.
stack
(
outputs
)
final_state
=
state
if
is_reverse
:
outputs
=
np
.
flip
(
outputs
,
0
)
if
not
time_major
:
outputs
=
np
.
transpose
(
outputs
,
[
1
,
0
,
2
])
return
outputs
,
final_state
def
birnn
(
cell_fw
,
cell_bw
,
inputs
,
initial_states
,
sequence_length
=
None
,
time_major
=
False
):
states_fw
,
states_bw
=
initial_states
outputs_fw
,
states_fw
=
rnn
(
cell_fw
,
inputs
,
states_fw
,
sequence_length
,
time_major
=
time_major
)
outputs_bw
,
states_bw
=
rnn
(
cell_bw
,
inputs
,
states_bw
,
sequence_length
,
time_major
=
time_major
,
is_reverse
=
True
)
outputs
=
np
.
concatenate
((
outputs_fw
,
outputs_bw
),
-
1
)
final_states
=
(
states_fw
,
states_bw
)
return
outputs
,
final_states
def
flatten
(
nested
):
return
list
(
_flatten
(
nested
))
def
_flatten
(
nested
):
for
item
in
nested
:
if
isinstance
(
item
,
(
list
,
tuple
)):
for
subitem
in
_flatten
(
item
):
yield
subitem
else
:
yield
item
def
unstack
(
array
,
axis
=
0
):
num
=
array
.
shape
[
axis
]
sub_arrays
=
np
.
split
(
array
,
num
,
axis
)
return
[
np
.
squeeze
(
sub_array
,
axis
)
for
sub_array
in
sub_arrays
]
def
dropout
(
array
,
p
=
0.0
):
if
p
==
0.0
:
return
array
mask
=
(
np
.
random
.
uniform
(
size
=
array
.
shape
)
<
(
1
-
p
)).
astype
(
array
.
dtype
)
return
array
*
(
mask
/
(
1
-
p
))
def
split_states
(
states
,
bidirectional
=
False
,
state_components
=
1
):
if
state_components
==
1
:
states
=
unstack
(
states
)
if
not
bidirectional
:
return
states
else
:
return
list
(
zip
(
states
[::
2
],
states
[
1
::
2
]))
else
:
assert
len
(
states
)
==
state_components
states
=
tuple
([
unstack
(
item
)
for
item
in
states
])
if
not
bidirectional
:
return
list
(
zip
(
*
states
))
else
:
states
=
list
(
zip
(
*
states
))
return
list
(
zip
(
states
[::
2
],
states
[
1
::
2
]))
def
concat_states
(
states
,
bidirectional
=
False
,
state_components
=
1
):
if
state_components
==
1
:
return
np
.
stack
(
flatten
(
states
))
else
:
states
=
flatten
(
states
)
componnets
=
[]
for
i
in
range
(
state_components
):
componnets
.
append
(
states
[
i
::
state_components
])
return
[
np
.
stack
(
item
)
for
item
in
componnets
]
class
RNN
(
LayerMixin
):
def
__init__
(
self
,
cell
,
is_reverse
=
False
,
time_major
=
False
):
super
(
RNN
,
self
).
__init__
()
self
.
cell
=
cell
if
not
hasattr
(
self
.
cell
,
"call"
):
# for non-dygraph mode, `rnn` api uses cell.call
self
.
cell
.
call
=
self
.
cell
.
forward
self
.
is_reverse
=
is_reverse
self
.
time_major
=
time_major
def
forward
(
self
,
inputs
,
initial_states
=
None
,
sequence_length
=
None
):
final_outputs
,
final_states
=
rnn
(
self
.
cell
,
inputs
,
initial_states
=
initial_states
,
sequence_length
=
sequence_length
,
time_major
=
self
.
time_major
,
is_reverse
=
self
.
is_reverse
)
return
final_outputs
,
final_states
class
BiRNN
(
LayerMixin
):
def
__init__
(
self
,
cell_fw
,
cell_bw
,
time_major
=
False
):
super
(
BiRNN
,
self
).
__init__
()
self
.
cell_fw
=
cell_fw
self
.
cell_bw
=
cell_bw
self
.
time_major
=
time_major
def
forward
(
self
,
inputs
,
initial_states
=
None
,
sequence_length
=
None
,
**
kwargs
):
if
isinstance
(
initial_states
,
(
list
,
tuple
)):
assert
len
(
initial_states
)
==
2
,
\
"length of initial_states should be 2 when it is a list/tuple"
else
:
initial_states
=
[
initial_states
,
initial_states
]
outputs
,
final_states
=
birnn
(
self
.
cell_fw
,
self
.
cell_bw
,
inputs
,
initial_states
,
sequence_length
,
self
.
time_major
)
return
outputs
,
final_states
class
RNNMixin
(
LayerListMixin
):
def
forward
(
self
,
inputs
,
initial_states
=
None
,
sequence_length
=
None
):
batch_index
=
1
if
self
.
time_major
else
0
batch_size
=
inputs
.
shape
[
batch_index
]
dtype
=
inputs
.
dtype
if
initial_states
is
None
:
state_shape
=
(
self
.
num_layers
*
self
.
num_directions
,
batch_size
,
self
.
hidden_size
)
if
self
.
state_components
==
1
:
initial_states
=
np
.
zeros
(
state_shape
,
dtype
)
else
:
initial_states
=
tuple
([
np
.
zeros
(
state_shape
,
dtype
)
for
_
in
range
(
self
.
state_components
)
])
states
=
split_states
(
initial_states
,
self
.
num_directions
==
2
,
self
.
state_components
)
final_states
=
[]
for
i
,
rnn_layer
in
enumerate
(
self
):
if
i
>
0
:
inputs
=
dropout
(
inputs
,
self
.
dropout
)
outputs
,
final_state
=
rnn_layer
(
inputs
,
states
[
i
],
sequence_length
)
final_states
.
append
(
final_state
)
inputs
=
outputs
final_states
=
concat_states
(
final_states
,
self
.
num_directions
==
2
,
self
.
state_components
)
return
outputs
,
final_states
class
LSTM
(
RNNMixin
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_layers
=
1
,
direction
=
"forward"
,
dropout
=
0.
,
time_major
=
False
):
super
(
LSTM
,
self
).
__init__
()
if
direction
in
[
"forward"
,
"backward"
]:
is_reverse
=
direction
==
"backward"
cell
=
LSTMCell
(
input_size
,
hidden_size
)
self
.
append
(
RNN
(
cell
,
is_reverse
,
time_major
))
for
i
in
range
(
1
,
num_layers
):
cell
=
LSTMCell
(
hidden_size
,
hidden_size
)
self
.
append
(
RNN
(
cell
,
is_reverse
,
time_major
))
elif
direction
==
"bidirectional"
:
cell_fw
=
LSTMCell
(
input_size
,
hidden_size
)
cell_bw
=
LSTMCell
(
input_size
,
hidden_size
)
self
.
append
(
BiRNN
(
cell_fw
,
cell_bw
,
time_major
))
for
i
in
range
(
1
,
num_layers
):
cell_fw
=
LSTMCell
(
2
*
hidden_size
,
hidden_size
)
cell_bw
=
LSTMCell
(
2
*
hidden_size
,
hidden_size
)
self
.
append
(
BiRNN
(
cell_fw
,
cell_bw
,
time_major
))
else
:
raise
ValueError
(
"direction should be forward, backward or bidirectional, "
"received direction = {}"
.
format
(
direction
))
self
.
input_size
=
input_size
self
.
hidden_size
=
hidden_size
self
.
dropout
=
dropout
self
.
num_directions
=
2
if
direction
==
"bidirectional"
else
1
self
.
time_major
=
time_major
self
.
num_layers
=
num_layers
self
.
state_components
=
2
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
"core is not compiled with CUDA"
)
class
TestCUDNNLstmOp
(
OpTest
):
class
TestCUDNNLstmOp
(
OpTest
):
#
TODO(GaoWei8):when input dtype is fp64, precision threshold should be removed.
#
TODO(GaoWei8): Need to satisfy the result through the new interface
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"cudnn_lstm"
self
.
op_type
=
"cudnn_lstm"
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float64
self
.
sequence_length
=
np
.
array
([
12
,
11
,
10
,
9
,
8
],
dtype
=
np
.
int32
)
self
.
num_layers
=
1
seq_length
=
20
seq_length
=
12
batch_size
=
5
batch_size
=
5
hidden_size
=
20
input_size
=
21
hidden_size
=
21
input_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
input_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
hidden_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
hidden_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
weight_size
=
input_weight_size
+
hidden_weight_size
weight_size
=
input_weight_size
+
hidden_weight_size
weight_size
+=
hidden_size
*
8
weight_size
+=
hidden_size
*
8
weight_size
*=
self
.
num_layers
input
=
np
.
random
.
uniform
(
input
=
np
.
random
.
uniform
(
low
=-
0.1
,
high
=
0.1
,
size
=
(
seq_length
,
batch_size
,
low
=-
0.1
,
high
=
0.1
,
hidden_size
)).
astype
(
self
.
dtype
)
size
=
(
seq_length
,
batch_size
,
input_size
)).
astype
(
self
.
dtype
)
flat_w
=
np
.
random
.
uniform
(
input
[
11
][
1
:][:]
=
0
low
=-
0.1
,
high
=
0.1
,
size
=
(
weight_size
)).
astype
(
self
.
dtype
)
input
[
10
][
2
:][:]
=
0
input
[
9
][
3
:][:]
=
0
output
,
last_hidden
,
last_cell
=
lstm_naive
(
input
,
flat_w
)
input
[
8
][
4
:][:]
=
0
init_h
=
np
.
zeros
((
1
,
batch_size
,
hidden_size
),
dtype
=
np
.
float64
)
rnn1
=
LSTM
(
init_c
=
np
.
zeros
((
1
,
batch_size
,
hidden_size
),
dtype
=
np
.
float64
)
input_size
,
hidden_size
,
self
.
num_layers
,
time_major
=
True
,
direction
=
"forward"
)
output
,
(
last_hidden
,
last_cell
)
=
rnn1
(
input
,
sequence_length
=
self
.
sequence_length
)
flat_w
=
np
.
ones
((
weight_size
)).
astype
(
self
.
dtype
)
init_h
=
np
.
zeros
((
self
.
num_layers
,
batch_size
,
hidden_size
)).
astype
(
self
.
dtype
)
init_c
=
np
.
zeros
((
self
.
num_layers
,
batch_size
,
hidden_size
)).
astype
(
self
.
dtype
)
state_out
=
np
.
ndarray
((
300
)).
astype
(
"uint8"
)
state_out
=
np
.
ndarray
((
300
)).
astype
(
"uint8"
)
self
.
inputs
=
{
self
.
inputs
=
{
...
@@ -152,9 +405,10 @@ class TestCUDNNLstmOp(OpTest):
...
@@ -152,9 +405,10 @@ class TestCUDNNLstmOp(OpTest):
self
.
attrs
=
{
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'dropout_prob'
:
0.0
,
'is_bidirec'
:
False
,
'is_bidirec'
:
False
,
'input_size'
:
hidden
_size
,
'input_size'
:
input
_size
,
'hidden_size'
:
hidden_size
,
'hidden_size'
:
hidden_size
,
'num_layers'
:
1
,
'num_layers'
:
1
,
'sequence_length'
:
self
.
sequence_length
.
tolist
()
}
}
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
output
,
'Out'
:
output
,
...
@@ -164,19 +418,33 @@ class TestCUDNNLstmOp(OpTest):
...
@@ -164,19 +418,33 @@ class TestCUDNNLstmOp(OpTest):
'StateOut'
:
state_out
'StateOut'
:
state_out
}
}
def
set_attrs
(
self
):
pass
def
test_output_with_place
(
self
):
def
test_output_with_place
(
self
):
# depend on the scope structure
place
=
core
.
CUDAPlace
(
0
)
place
=
core
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
self
.
check_output_with_place
(
place
,
no_check_set
=
[
'Reserve'
,
'StateOut'
])
place
,
no_check_set
=
[
'Reserve'
,
'StateOut'
])
def
test_grad_with_place
(
self
):
def
test_grad_with_place
(
self
):
# depend on the scope structure
place
=
core
.
CUDAPlace
(
0
)
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
self
.
check_grad_with_place
(
place
,
place
,
set
([
'Input'
,
'W'
,
'InitH'
,
'InitC'
]),
set
([
'Input'
,
'W'
,
'InitH'
,
'InitC'
]),
[
'Out'
,
'LastH'
,
'LastC'
],
[
'Out'
,
'LastH'
,
'LastC'
])
max_relative_error
=
1e-4
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestCUDNNLstmOp2
(
TestCUDNNLstmOp
):
def
set_attrs
(
self
):
self
.
sequence_length
=
np
.
array
([],
dtype
=
np
.
int32
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestCUDNNLstmOp3
(
TestCUDNNLstmOp
):
def
set_attrs
(
self
):
self
.
num_layers
=
2
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
@@ -198,7 +466,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
...
@@ -198,7 +466,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
'float64'
,
0.0
)
'float64'
,
0.0
)
rnn_out
,
last_h
,
last_c
=
layers
.
lstm
(
input
,
init_h
,
init_c
,
seq_len
,
rnn_out
,
last_h
,
last_c
=
layers
.
lstm
(
input
,
init_h
,
init_c
,
seq_len
,
hidden_size
,
num_layers
,
hidden_size
,
num_layers
,
dropout_prob
)
dropout_prob
,
False
,
True
)
exe
=
fluid
.
Executor
(
fluid
.
CUDAPlace
(
0
))
exe
=
fluid
.
Executor
(
fluid
.
CUDAPlace
(
0
))
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
input_i
=
np
.
random
.
uniform
(
input_i
=
np
.
random
.
uniform
(
...
@@ -208,12 +476,6 @@ class TestCUDNNlstmAPI(unittest.TestCase):
...
@@ -208,12 +476,6 @@ class TestCUDNNlstmAPI(unittest.TestCase):
feed
=
{
'input'
:
input_i
},
feed
=
{
'input'
:
input_i
},
fetch_list
=
[
rnn_out
,
last_h
,
last_c
,
'cudnn_lstm_0.w_0'
])
fetch_list
=
[
rnn_out
,
last_h
,
last_c
,
'cudnn_lstm_0.w_0'
])
output
,
last_hidden
,
last_cell
=
lstm_naive
(
input_i
,
out
[
3
])
self
.
assertTrue
(
np
.
allclose
(
output
,
out
[
0
],
atol
=
1e-5
))
self
.
assertTrue
(
np
.
allclose
(
last_hidden
,
out
[
1
],
atol
=
1e-5
))
self
.
assertTrue
(
np
.
allclose
(
last_cell
,
out
[
2
],
atol
=
1e-5
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录