Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
aca05d59
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aca05d59
编写于
10月 23, 2018
作者:
S
shippingwang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'release/1.0.0' of
https://github.com/PaddlePaddle/Paddle
into release/1.0.0
上级
e8fbf82d
587f3dd3
变更
35
展开全部
隐藏空白更改
内联
并排
Showing
35 changed file
with
1571 addition
and
426 deletion
+1571
-426
paddle/fluid/API.spec
paddle/fluid/API.spec
+5
-5
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+12
-4
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
+12
-0
paddle/fluid/inference/api/demo_ci/run.sh
paddle/fluid/inference/api/demo_ci/run.sh
+37
-6
paddle/fluid/inference/api/demo_ci/vis_demo.cc
paddle/fluid/inference/api/demo_ci/vis_demo.cc
+31
-15
paddle/fluid/operators/adadelta_op.cc
paddle/fluid/operators/adadelta_op.cc
+12
-0
paddle/fluid/operators/adadelta_op.h
paddle/fluid/operators/adadelta_op.h
+11
-0
paddle/fluid/operators/adagrad_op.h
paddle/fluid/operators/adagrad_op.h
+20
-13
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+9
-16
paddle/fluid/operators/adamax_op.cc
paddle/fluid/operators/adamax_op.cc
+10
-0
paddle/fluid/operators/adamax_op.h
paddle/fluid/operators/adamax_op.h
+11
-0
paddle/fluid/operators/decayed_adagrad_op.cc
paddle/fluid/operators/decayed_adagrad_op.cc
+10
-0
paddle/fluid/operators/decayed_adagrad_op.h
paddle/fluid/operators/decayed_adagrad_op.h
+11
-0
paddle/fluid/operators/ftrl_op.cc
paddle/fluid/operators/ftrl_op.cc
+10
-0
paddle/fluid/operators/ftrl_op.h
paddle/fluid/operators/ftrl_op.h
+11
-0
paddle/fluid/operators/math/algorithm.h
paddle/fluid/operators/math/algorithm.h
+44
-0
paddle/fluid/operators/momentum_op.cc
paddle/fluid/operators/momentum_op.cc
+45
-13
paddle/fluid/operators/momentum_op.cu
paddle/fluid/operators/momentum_op.cu
+3
-61
paddle/fluid/operators/momentum_op.h
paddle/fluid/operators/momentum_op.h
+311
-14
paddle/fluid/operators/rmsprop_op.cc
paddle/fluid/operators/rmsprop_op.cc
+5
-0
paddle/fluid/operators/rmsprop_op.h
paddle/fluid/operators/rmsprop_op.h
+234
-41
paddle/fluid/operators/sgd_op.cc
paddle/fluid/operators/sgd_op.cc
+16
-13
paddle/fluid/operators/sgd_op.cu
paddle/fluid/operators/sgd_op.cu
+6
-0
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+16
-16
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+140
-18
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+13
-2
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+5
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+226
-80
python/paddle/fluid/layers/ops.py
python/paddle/fluid/layers/ops.py
+12
-5
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+1
-1
python/paddle/fluid/nets.py
python/paddle/fluid/nets.py
+18
-8
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+12
-0
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+19
-2
python/paddle/fluid/tests/unittests/test_momentum_op.py
python/paddle/fluid/tests/unittests/test_momentum_op.py
+94
-0
python/paddle/fluid/tests/unittests/test_rmsprop_op.py
python/paddle/fluid/tests/unittests/test_rmsprop_op.py
+139
-92
未找到文件。
paddle/fluid/API.spec
浏览文件 @
aca05d59
...
...
@@ -61,12 +61,12 @@ paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'
], varargs=None, keywords=None, defaults=(3, 1
, None, None, None, None))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'
, 'name'], varargs=None, keywords=None, defaults=(3, 1, None
, None, None, None, None))
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', '
param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, Fals
e))
paddle.fluid.layers.softmax ArgSpec(args=['input', '
param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None,
True, None))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', '
use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, Non
e))
paddle.fluid.layers.softmax ArgSpec(args=['input', '
use_cudnn', 'name'], varargs=None, keywords=None, defaults=(
True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
...
...
@@ -95,8 +95,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples'
], varargs=None, keywords=None, defaults=(
None, None, None, None))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr'
], varargs=None, keywords=None, defaults=(
None, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples'
, 'name'], varargs=None, keywords=None, defaults=(None,
None, None, None, None))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr'
, 'name'], varargs=None, keywords=None, defaults=(None,
None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
aca05d59
...
...
@@ -19,8 +19,18 @@ cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api)
add_subdirectory
(
api
)
set
(
STATIC_INFERENCE_APIS paddle_fluid_api paddle_inference_api analysis_predictor
)
set
(
SHARED_INFERENCE_SRCS
io.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api_impl.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/analysis_predictor.cc
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
set
(
STATIC_INFERENCE_APIS
${
STATIC_INFERENCE_APIS
}
paddle_inference_tensorrt_subgraph_engine
)
set
(
SHARED_INFERENCE_SRCS
${
SHARED_INFERENCE_SRCS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api_tensorrt_subgraph_engine.cc
)
endif
()
# Create static library
cc_library
(
paddle_fluid DEPS
${
fluid_modules
}
paddle_fluid_api paddle_inference_api analysis_predictor
)
cc_library
(
paddle_fluid DEPS
${
fluid_modules
}
${
STATIC_INFERENCE_APIS
}
)
if
(
NOT APPLE
)
# TODO(liuyiqu): Temporarily disable the link flag because it is not supported on Mac.
set
(
LINK_FLAGS
"-Wl,--retain-symbols-file
${
CMAKE_CURRENT_SOURCE_DIR
}
/paddle_fluid.sym"
)
...
...
@@ -28,9 +38,7 @@ if(NOT APPLE)
endif
()
# Create shared library
cc_library
(
paddle_fluid_shared SHARED
SRCS io.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api_impl.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/analysis_predictor.cc
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
paddle_fluid_api
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
...
...
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
浏览文件 @
aca05d59
...
...
@@ -3,6 +3,7 @@ project(cpp_inference_demo CXX C)
option
(
WITH_MKL
"Compile demo with MKL/OpenBlas support, default use MKL."
ON
)
option
(
WITH_GPU
"Compile demo with GPU/CPU, default use CPU."
OFF
)
option
(
WITH_STATIC_LIB
"Compile demo with static/shared library, default use static."
ON
)
option
(
USE_TENSORRT
"Compile demo with TensorRT."
OFF
)
macro
(
safe_set_static_flag
)
foreach
(
flag_var
...
...
@@ -60,6 +61,13 @@ endif(NOT WIN32)
include_directories
(
"
${
PADDLE_LIB
}
/third_party/boost"
)
include_directories
(
"
${
PADDLE_LIB
}
/third_party/eigen3"
)
if
(
NOT WIN32
)
if
(
USE_TENSORRT AND WITH_GPU
)
include_directories
(
"
${
TENSORRT_INCLUDE_DIR
}
"
)
link_directories
(
"
${
TENSORRT_LIB_DIR
}
"
)
endif
()
endif
(
NOT WIN32
)
if
(
NOT WIN32
)
link_directories
(
"
${
PADDLE_LIB
}
/third_party/install/snappy/lib"
)
link_directories
(
"
${
PADDLE_LIB
}
/third_party/install/snappystream/lib"
)
...
...
@@ -112,6 +120,10 @@ endif(NOT WIN32)
if
(
WITH_GPU
)
if
(
NOT WIN32
)
if
(
USE_TENSORRT
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/libnvinfer
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/libnvinfer_plugin
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/libcudart
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
else
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudart
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
...
...
paddle/fluid/inference/api/demo_ci/run.sh
浏览文件 @
aca05d59
...
...
@@ -2,6 +2,12 @@ set -x
PADDLE_ROOT
=
$1
TURN_ON_MKL
=
$2
# use MKL or Openblas
TEST_GPU_CPU
=
$3
# test both GPU/CPU mode or only CPU mode
DATA_DIR
=
$4
# dataset
TENSORRT_INCLUDE_DIR
=
$5
# TensorRT header file dir, default to /usr/local/TensorRT/include
TENSORRT_LIB_DIR
=
$6
# TensorRT lib file dir, default to /usr/local/TensorRT/lib
cd
`
dirname
$0
`
current_dir
=
`
pwd
`
if
[
$2
==
ON
]
;
then
# You can export it yourself if you move the install path
MKL_LIB
=
${
PADDLE_ROOT
}
/build/fluid_install_dir/third_party/install/mklml/lib
...
...
@@ -13,6 +19,11 @@ else
use_gpu_list
=
'false'
fi
USE_TENSORRT
=
OFF
if
[
[
-d
"
$TENSORRT_INCLUDE_DIR
"
]
-a
[
-d
"
$TENSORRT_LIB_DIR
"
]
]
;
then
USE_TENSORRT
=
ON
fi
PREFIX
=
inference-vis-demos%2F
URL_ROOT
=
http://paddlemodels.cdn.bcebos.com/
${
PREFIX
}
...
...
@@ -29,15 +40,15 @@ function download() {
fi
cd
..
}
mkdir
-p
data
cd
data
mkdir
-p
$DATA_DIR
cd
$DATA_DIR
vis_demo_list
=
'se_resnext50 ocr mobilenet'
for
vis_demo_name
in
$vis_demo_list
;
do
download
$vis_demo_name
done
cd
..
# compile and test the demo
cd
$current_dir
mkdir
-p
build
cd
build
...
...
@@ -73,9 +84,9 @@ for WITH_STATIC_LIB in ON OFF; do
for
use_gpu
in
$use_gpu_list
;
do
for
vis_demo_name
in
$vis_demo_list
;
do
./vis_demo
\
--modeldir
=
../data
/
$vis_demo_name
/model
\
--data
=
../data
/
$vis_demo_name
/data.txt
\
--refer
=
../data
/
$vis_demo_name
/result.txt
\
--modeldir
=
$DATA_DIR
/
$vis_demo_name
/model
\
--data
=
$DATA_DIR
/
$vis_demo_name
/data.txt
\
--refer
=
$DATA_DIR
/
$vis_demo_name
/result.txt
\
--use_gpu
=
$use_gpu
if
[
$?
-ne
0
]
;
then
echo
"vis demo
$vis_demo_name
runs fail."
...
...
@@ -83,5 +94,25 @@ for WITH_STATIC_LIB in ON OFF; do
fi
done
done
# --------tensorrt mobilenet------
if
[
$USE_TENSORRT
==
ON
-a
$TEST_GPU_CPU
==
ON
]
;
then
rm
-rf
*
cmake ..
-DPADDLE_LIB
=
${
PADDLE_ROOT
}
/build/fluid_install_dir/
\
-DWITH_MKL
=
$TURN_ON_MKL
\
-DDEMO_NAME
=
vis_demo
\
-DWITH_GPU
=
$TEST_GPU_CPU
\
-DWITH_STATIC_LIB
=
$WITH_STATIC_LIB
\
-DUSE_TENSORRT
=
$USE_TENSORRT
\
-DTENSORRT_INCLUDE_DIR
=
$TENSORRT_INCLUDE_DIR
\
-DTENSORRT_LIB_DIR
=
$TENSORRT_LIB_DIR
make
-j
./vis_demo
\
--modeldir
=
$DATA_DIR
/mobilenet/model
\
--data
=
$DATA_DIR
/mobilenet/data.txt
\
--refer
=
$DATA_DIR
/mobilenet/result.txt
\
--use_gpu
=
true
\
--use_trt
=
true
fi
done
set
+x
paddle/fluid/inference/api/demo_ci/vis_demo.cc
浏览文件 @
aca05d59
...
...
@@ -33,6 +33,7 @@ DEFINE_string(
"path of data; each line is a record, format is "
"'<space splitted floats as data>
\t
<space splitted ints as shape'"
);
DEFINE_bool
(
use_gpu
,
false
,
"Whether use gpu."
);
DEFINE_bool
(
use_trt
,
false
,
"Whether use trt."
);
namespace
paddle
{
namespace
demo
{
...
...
@@ -100,20 +101,32 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) {
/*
* Use the native fluid engine to run inference for the demo.
*/
void
Main
(
bool
use_gpu
)
{
NativeConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
use_gpu
;
config
.
device
=
0
;
if
(
FLAGS_use_gpu
)
{
void
Main
(
bool
use_gpu
,
bool
use_trt
)
{
std
::
unique_ptr
<
PaddlePredictor
>
predictor
;
if
(
!
use_trt
)
{
NativeConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
use_gpu
;
config
.
device
=
0
;
if
(
FLAGS_use_gpu
)
{
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
}
VLOG
(
3
)
<<
"init predictor"
;
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
}
else
{
paddle
::
contrib
::
MixedRTConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
use_gpu
=
true
;
config
.
device
=
0
;
config
.
max_batch_size
=
1
;
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
predictor
=
CreatePaddlePredictor
<
paddle
::
contrib
::
MixedRTConfig
>
(
config
);
}
VLOG
(
3
)
<<
"init predictor"
;
auto
predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
VLOG
(
3
)
<<
"begin to process data"
;
// Just a single batch of data.
std
::
string
line
;
...
...
@@ -131,7 +144,7 @@ void Main(bool use_gpu) {
VLOG
(
3
)
<<
"run executor"
;
std
::
vector
<
PaddleTensor
>
output
;
predictor
->
Run
({
input
},
&
output
);
predictor
->
Run
({
input
},
&
output
,
1
);
VLOG
(
3
)
<<
"output.size "
<<
output
.
size
();
auto
&
tensor
=
output
.
front
();
...
...
@@ -146,9 +159,12 @@ void Main(bool use_gpu) {
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
paddle
::
demo
::
Main
(
false
/* use_gpu*/
);
if
(
FLAGS_use_gpu
)
{
paddle
::
demo
::
Main
(
true
/*use_gpu*/
);
if
(
FLAGS_use_gpu
&&
FLAGS_use_trt
)
{
paddle
::
demo
::
Main
(
true
/*use_gpu*/
,
true
);
}
else
if
(
FLAGS_use_gpu
)
{
paddle
::
demo
::
Main
(
true
/*use_gpu*/
,
false
);
}
else
{
paddle
::
demo
::
Main
(
false
/*use_gpu*/
,
false
/*use_tensorrt*/
);
}
return
0
;
}
paddle/fluid/operators/adadelta_op.cc
浏览文件 @
aca05d59
...
...
@@ -18,6 +18,7 @@ namespace paddle {
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
AdadeltaOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
"Input(AvgSquaredGrad) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"AvgSquaredUpdate"
),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of AdadeltaOp should not be null."
);
...
...
@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"AvgSquaredGradOut"
,
param_dim
);
ctx
->
SetOutputDim
(
"AvgSquaredUpdateOut"
,
param_dim
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
...
...
paddle/fluid/operators/adadelta_op.h
浏览文件 @
aca05d59
...
...
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class
AdadeltaOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
avg_squared_grad_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"AvgSquaredGradOut"
);
...
...
paddle/fluid/operators/adagrad_op.h
浏览文件 @
aca05d59
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -21,25 +22,31 @@ namespace operators {
template
<
typename
DeviceContext
,
typename
T
>
struct
SparseAdagradFunctor
{
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
);
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
AdagradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
auto
*
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
param_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
moment_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
param
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
));
...
...
@@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
));
auto
moment
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
));
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out_tensor
);
auto
moment_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
moment_out_tensor
);
auto
*
place
=
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
*
place
=
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
moment_out
.
device
(
*
place
)
=
moment
+
grad
*
grad
;
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
moment_out_tensor
->
numel
());
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
auto
*
lr
=
learning_rate
->
data
<
T
>
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
param_out
.
device
(
*
place
)
=
param
-
lr
[
0
]
*
grad
/
(
moment_out
.
sqrt
()
+
epsilon
);
}
else
{
...
...
@@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
lr
.
broadcast
(
m_dsize
)
*
grad
/
(
moment_out
.
sqrt
()
+
epsilon
);
}
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
param_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_tensor
,
param_out_tensor
);
auto
*
moment_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
);
auto
*
moment_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
);
PADDLE_ENFORCE_EQ
(
moment_tensor
,
moment_out_tensor
);
SparseAdagradFunctor
<
DeviceContext
,
T
>
functor
;
...
...
paddle/fluid/operators/adam_op.h
浏览文件 @
aca05d59
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
...
...
@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
row_numel_
(
row_numel
),
row_count_
(
row_count
)
{}
inline
HOSTDEVICE
int64_t
BinarySearchInRows
(
int64_t
row
)
const
{
int64_t
beg
=
0
,
end
=
row_count_
-
1
;
while
(
beg
<=
end
)
{
auto
mid
=
((
beg
+
end
)
>>
1
);
if
(
rows_
[
mid
]
==
row
)
return
mid
;
else
if
(
rows_
[
mid
]
<
row
)
beg
=
mid
+
1
;
else
end
=
mid
-
1
;
}
return
-
1
;
}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
int64_t
row
=
i
/
row_numel_
;
auto
row_idx
=
BinarySearchInRows
(
row
);
auto
row_idx
=
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_count_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
// The following code is the same as dense
...
...
@@ -244,6 +231,12 @@ template <typename DeviceContext, typename T>
class
AdamOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
using
paddle
::
framework
::
LoDTensor
;
using
paddle
::
operators
::
detail
::
Ref
;
...
...
paddle/fluid/operators/adamax_op.cc
浏览文件 @
aca05d59
...
...
@@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
"Input(LearningRate) of AdamaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Beta1Pow"
),
"Input(Beta1Pow) of AdamaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of AdamaxOp should not be null."
);
...
...
paddle/fluid/operators/adamax_op.h
浏览文件 @
aca05d59
...
...
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class
AdamaxOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
auto
inf_norm_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"InfNormOut"
);
...
...
paddle/fluid/operators/decayed_adagrad_op.cc
浏览文件 @
aca05d59
...
...
@@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LearningRate"
),
"Input(LearningRate) of DecayedAdagradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of DecayedAdagradOp should not be null."
);
...
...
paddle/fluid/operators/decayed_adagrad_op.h
浏览文件 @
aca05d59
...
...
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class
DecayedAdagradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
...
...
paddle/fluid/operators/ftrl_op.cc
浏览文件 @
aca05d59
...
...
@@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel {
"Input(Grad) of FTRL should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LearningRate"
),
"Input(LearningRate) of FTRL should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of FTRL should not be null."
);
...
...
paddle/fluid/operators/ftrl_op.h
浏览文件 @
aca05d59
...
...
@@ -28,6 +28,17 @@ template <typename DeviceContext, typename T>
class
FTRLOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
*
param_out
=
ctx
.
Output
<
Tensor
>
(
"ParamOut"
);
auto
*
sq_accum_out
=
ctx
.
Output
<
Tensor
>
(
"SquaredAccumOut"
);
auto
*
lin_accum_out
=
ctx
.
Output
<
Tensor
>
(
"LinearAccumOut"
);
...
...
paddle/fluid/operators/math/algorithm.h
0 → 100644
浏览文件 @
aca05d59
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstdint> // for int64_t
#include <numeric>
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
HOSTDEVICE
inline
int64_t
BinarySearch
(
const
T
*
x
,
int64_t
num
,
const
T
&
val
)
{
int64_t
beg
=
0
,
end
=
num
-
1
;
while
(
beg
<=
end
)
{
auto
mid
=
((
beg
+
end
)
>>
1
);
if
(
x
[
mid
]
==
val
)
return
mid
;
else
if
(
x
[
mid
]
<
val
)
beg
=
mid
+
1
;
else
end
=
mid
-
1
;
}
return
-
1
;
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/momentum_op.cc
浏览文件 @
aca05d59
...
...
@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(param) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
@@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel {
"Input(velocity) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LearningRate"
),
"Input(LearningRate) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of Momentum should not be null."
);
...
...
@@ -40,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel {
"Output(VelocityOut) of Momentum should not be null."
);
auto
param_dim
=
ctx
->
GetInputDim
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Grad"
),
"Param and Grad input of MomentumOp should have the same dimension."
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Velocity"
),
"Param and Velocity of MomentumOp should have the same dimension."
);
if
(
ctx
->
GetInputsVarType
(
"Grad"
)[
0
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Grad"
),
"Param and Grad input of MomentumOp should have the same dimension."
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Velocity"
),
"Param and Velocity of MomentumOp should have the same dimension."
);
}
PADDLE_ENFORCE_EQ
(
framework
::
product
(
ctx
->
GetInputDim
(
"LearningRate"
)),
1
,
"Learning_rate should be a scalar"
);
...
...
@@ -53,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"VelocityOut"
,
param_dim
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Param"
));
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
// Variable-type inference for the momentum op.
//
// Every variable listed under the "ParamOut" output is given the same
// variable type as the first "Param" input. Only SELECTED_ROWS and
// LOD_TENSOR are accepted; any other Param type raises via PADDLE_THROW
// while processing the outputs.
class MomentumOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    auto input_var = op_desc.Input("Param")[0];
    // The Param type is loop-invariant: the loop only ever assigns this very
    // type to the outputs, so it is looked up once instead of up to twice per
    // output variable.
    const auto in_var_type =
        block->FindRecursiveOrCreateVar(input_var).GetType();
    const bool supported =
        in_var_type == framework::proto::VarType::SELECTED_ROWS ||
        in_var_type == framework::proto::VarType::LOD_TENSOR;

    for (auto &out_var : op_desc.Output("ParamOut")) {
      // Checked inside the loop so that, as before, nothing is thrown when
      // the op has no ParamOut entries.
      if (!supported) {
        PADDLE_THROW(
            "Only support LodTensor and SelectedRows, Unexpected Input Type.");
      }
      block->FindRecursiveOrCreateVar(out_var).SetType(in_var_type);
    }
  }
};
class
MomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -110,6 +139,9 @@ $$
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
momentum
,
ops
::
MomentumOp
,
ops
::
MomentumOpMaker
);
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
float
>
,
ops
::
MomentumOpKernel
<
double
>
);
REGISTER_OPERATOR
(
momentum
,
ops
::
MomentumOp
,
ops
::
MomentumOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
MomentumOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/momentum_op.cu
浏览文件 @
aca05d59
...
...
@@ -15,65 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace
paddle
{
namespace
operators
{
// Element-wise momentum update as a grid-stride CUDA kernel.
//
// For each i in [0, num):
//   v_out[i] = v[i] * mu + g[i]
//   p_out[i] = p[i] - (g[i] + v_out[i] * mu) * lr   if use_nesterov
//   p_out[i] = p[i] - lr * v_out[i]                 otherwise
// where lr is read from learning_rate[0].
//
// The element index and stride are kept in int64_t. `num` is an int64_t, and
// the previous `int` loop variable (fed by 32-bit products of the built-in
// block/grid dimensions) could wrap for very large tensors.
template <typename T>
__global__ void MomentumKernel(const T* p, const T* g, const T* v,
                               const T* learning_rate, const T mu,
                               const int64_t num, bool use_nesterov, T* p_out,
                               T* v_out) {
  T lr = learning_rate[0];
  // Widen before multiplying so the products are computed in 64 bits.
  const int64_t first =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  if (use_nesterov) {
    for (int64_t i = first; i < num; i += stride) {
      T g_val = g[i];
      T v_new = v[i] * mu + g_val;
      v_out[i] = v_new;
      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
    }
  } else {
    for (int64_t i = first; i < num; i += stride) {
      T v_new = v[i] * mu + g[i];
      v_out[i] = v_new;
      p_out[i] = p[i] - lr * v_new;
    }
  }
}
// CUDA op kernel for momentum: gathers the op's tensors and attributes and
// launches MomentumKernel on the execution context's CUDA stream.
template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Outputs.
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
    // Inputs.
    auto param = ctx.Input<framework::Tensor>("Param");
    auto velocity = ctx.Input<framework::Tensor>("Velocity");
    auto grad = ctx.Input<framework::Tensor>("Grad");
    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");

    // Output buffers are obtained before the input pointers are read.
    T* param_out_data = param_out->mutable_data<T>(ctx.GetPlace());
    T* velocity_out_data = velocity_out->mutable_data<T>(ctx.GetPlace());

    // Attributes: "mu" is stored as float and converted to T.
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto* param_data = param->data<T>();
    auto* velocity_data = velocity->data<T>();
    auto* grad_data = grad->data<T>();
    auto* lr_data = learning_rate->data<T>();

    // 512 threads per block; enough blocks to give every element one thread.
    int block = 512;
    int grid = (param->numel() + block - 1) / block;
    MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
        param_data, grad_data, velocity_data, lr_data, mu, param->numel(),
        use_nesterov, param_out_data, velocity_out_data);
  }
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
momentum
,
ops
::
MomentumOpCUDAKernel
<
float
>
,
ops
::
MomentumOpCUDAKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/momentum_op.h
浏览文件 @
aca05d59
...
...
@@ -13,29 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
MomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
using
framework
::
Tensor
;
using
framework
::
SelectedRows
;
struct
NoNesterov
;
struct
UseNesterov
;
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
template
<
typename
T
>
class
CPUDenseMomentumFunctor
{
private:
const
Tensor
*
param
;
const
Tensor
*
grad
;
const
Tensor
*
velocity
;
const
Tensor
*
learning_rate
;
const
T
mu
;
const
T
use_nesterov
;
Tensor
*
param_out
;
Tensor
*
velocity_out
;
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
public:
CPUDenseMomentumFunctor
(
const
Tensor
*
param
,
const
Tensor
*
grad
,
const
Tensor
*
velocity
,
const
Tensor
*
learning_rate
,
const
T
mu
,
const
bool
use_nesterov
,
Tensor
*
param_out
,
Tensor
*
velocity_out
)
:
param
(
param
),
grad
(
grad
),
velocity
(
velocity
),
learning_rate
(
learning_rate
),
mu
(
mu
),
use_nesterov
(
use_nesterov
),
param_out
(
param_out
),
velocity_out
(
velocity_out
)
{}
inline
void
operator
()()
{
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
...
...
@@ -53,5 +72,283 @@ class MomentumOpKernel : public framework::OpKernel<T> {
}
};
template
<
typename
T
,
typename
UpdateMethod
>
class
DenseMomentumFunctor
;
// NOTE(dzh): for performance, avoid an if/else inside the kernel by
// implementing the GPU UseNesterov/NoNesterov updates as two separate
// functors.
// Per-element Nesterov momentum update for a dense gradient.
//
// Calling the functor with index i computes
//   v_out[i] = v[i] * mu + g[i]
//   p_out[i] = p[i] - (g[i] + v_out[i] * mu) * lr
// with lr taken from learning_rate[0]. `num` is stored but not consulted by
// operator(); the caller decides which indices to visit.
template <typename T>
class DenseMomentumFunctor<T, UseNesterov> {
 private:
  const T* p_;
  const T* g_;
  const T* v_;
  const T* lr_;
  const T mu_;
  const int64_t num_;
  T* p_out_;
  T* v_out_;

 public:
  DenseMomentumFunctor(const T* p, const T* g, const T* v,
                       const T* learning_rate, const T mu, const int64_t num,
                       T* p_out, T* v_out)
      : p_(p),
        g_(g),
        v_(v),
        lr_(learning_rate),
        mu_(mu),
        num_(num),
        p_out_(p_out),
        v_out_(v_out) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // Load every input into a local before any output is stored.
    const T param = p_[i];
    const T grad = g_[i];
    const T rate = lr_[0];
    const T velocity = v_[i];

    T new_velocity = velocity * mu_ + grad;
    T new_param = param - (grad + new_velocity * mu_) * rate;

    // Store the results back to memory.
    v_out_[i] = new_velocity;
    p_out_[i] = new_param;
  }
};
// Per-element plain (non-Nesterov) momentum update for a dense gradient.
//
// Calling the functor with index i computes
//   v_out[i] = v[i] * mu + g[i]
//   p_out[i] = p[i] - lr * v_out[i]
// with lr taken from learning_rate[0]. `num` is stored but not consulted by
// operator(); the caller decides which indices to visit.
template <typename T>
class DenseMomentumFunctor<T, NoNesterov> {
 private:
  const T* p_;
  const T* g_;
  const T* v_;
  const T* lr_;
  const T mu_;
  const int64_t num_;
  T* p_out_;
  T* v_out_;

 public:
  DenseMomentumFunctor(const T* p, const T* g, const T* v,
                       const T* learning_rate, const T mu, const int64_t num,
                       T* p_out, T* v_out)
      : p_(p),
        g_(g),
        v_(v),
        lr_(learning_rate),
        mu_(mu),
        num_(num),
        p_out_(p_out),
        v_out_(v_out) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // Load every input into a local before any output is stored.
    const T param = p_[i];
    const T grad = g_[i];
    const T rate = lr_[0];
    const T velocity = v_[i];

    T new_velocity = velocity * mu_ + grad;
    T new_param = param - rate * new_velocity;

    // Store the results back to memory.
    v_out_[i] = new_velocity;
    p_out_[i] = new_param;
  }
};
template
<
typename
T
,
typename
UpdateMethod
>
class
SparseMomentumFunctor
;
// Per-element Nesterov momentum update driven by a row-sparse gradient.
//
// Element i of the parameter belongs to row i / row_numel. That row id is
// looked up in `rows` (row_height entries, searched with math::BinarySearch,
// so they are expected to be sorted ascending). When the row is present the
// matching value from `g` is used as the gradient; otherwise the gradient is
// treated as zero. The update is then
//   v_out[i] = v[i] * mu + g
//   p_out[i] = p[i] - (g + v_out[i] * mu) * lr
// with lr taken from lr[0].
//
// operator() is const, like the dense momentum functors: it only reads the
// members and writes through the stored output pointers.
template <typename T>
class SparseMomentumFunctor<T, UseNesterov> {
 private:
  const T* p_;
  const T* g_;
  const T* v_;
  const T* lr_;
  const T mu_;
  const int64_t* rows_;
  const int64_t row_numel_;
  // Number of entries in rows_ (passed to BinarySearch as the array length).
  const int64_t row_height_;
  T* p_out_;
  T* v_out_;

 public:
  SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
                        const T mu, const int64_t* rows, int64_t row_numel,
                        int64_t row_height, T* p_out, T* v_out)
      : p_(p),
        g_(g),
        v_(v),
        lr_(lr),
        mu_(mu),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        p_out_(p_out),
        v_out_(v_out) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // Position of this element's row inside the sparse gradient, or -1.
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    // Rows absent from the sparse gradient contribute a zero gradient.
    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
                       : static_cast<T>(0);
    // Load the remaining inputs into locals before storing any output.
    const T p = p_[i];
    const T lr = lr_[0];
    const T v = v_[i];
    T v_out = v * mu_ + g;
    T p_out = p - (g + v_out * mu_) * lr;
    // Store the results back to memory.
    v_out_[i] = v_out;
    p_out_[i] = p_out;
  }
};
// Per-element plain (non-Nesterov) momentum update driven by a row-sparse
// gradient.
//
// Element i of the parameter belongs to row i / row_numel. That row id is
// looked up in `rows` (row_height entries, searched with math::BinarySearch,
// so they are expected to be sorted ascending). When the row is present the
// matching value from `g` is used as the gradient; otherwise the gradient is
// treated as zero. The update is then
//   v_out[i] = v[i] * mu + g
//   p_out[i] = p[i] - v_out[i] * lr
// with lr taken from lr[0].
//
// operator() is const, like the dense momentum functors: it only reads the
// members and writes through the stored output pointers.
template <typename T>
class SparseMomentumFunctor<T, NoNesterov> {
 private:
  const T* p_;
  const T* g_;
  const T* v_;
  const T* lr_;
  const T mu_;
  const int64_t* rows_;
  const int64_t row_numel_;
  // Number of entries in rows_ (passed to BinarySearch as the array length).
  const int64_t row_height_;
  T* p_out_;
  T* v_out_;

 public:
  SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
                        const T mu, const int64_t* rows, int64_t row_numel,
                        int64_t row_height, T* p_out, T* v_out)
      : p_(p),
        g_(g),
        v_(v),
        lr_(lr),
        mu_(mu),
        rows_(rows),
        row_numel_(row_numel),
        row_height_(row_height),
        p_out_(p_out),
        v_out_(v_out) {}

  inline HOSTDEVICE void operator()(size_t i) const {
    // Position of this element's row inside the sparse gradient, or -1.
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
    // Rows absent from the sparse gradient contribute a zero gradient.
    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
                       : static_cast<T>(0);
    // Load the remaining inputs into locals before storing any output.
    const T p = p_[i];
    const T lr = lr_[0];
    const T v = v_[i];
    T v_out = v * mu_ + g;
    T p_out = p - v_out * lr;
    // Store the results back to memory.
    v_out_[i] = v_out;
    p_out_[i] = p_out;
  }
};
// Momentum optimizer kernel, parameterized on the device context.
//
// The update path is chosen from the runtime type of the "Grad" variable:
//   * LoDTensor    - dense update: CPUDenseMomentumFunctor on a CPU place,
//                    DenseMomentumFunctor driven by ForRange on a GPU place.
//   * SelectedRows - sparse update: duplicate rows are first combined with
//                    MergeAdd, then SparseMomentumFunctor is driven by
//                    ForRange over every parameter element.
//   * anything else raises via PADDLE_THROW.
template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto* velocity = ctx.Input<framework::Tensor>("Velocity");
    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");

    // Request the output buffers up front, before any branch is taken.
    param_out->mutable_data<T>(ctx.GetPlace());
    velocity_out->mutable_data<T>(ctx.GetPlace());

    auto* grad_var = ctx.InputVar("Grad");

    // ---- Dense gradient -------------------------------------------------
    if (grad_var->IsType<framework::LoDTensor>()) {
      auto grad = ctx.Input<framework::Tensor>("Grad");

      if (platform::is_cpu_place(ctx.GetPlace())) {
        CPUDenseMomentumFunctor<T> cpu_update(param, grad, velocity,
                                              learning_rate, mu, use_nesterov,
                                              param_out, velocity_out);
        cpu_update();
        return;
      }
      if (!platform::is_gpu_place(ctx.GetPlace())) {
        // Neither CPU nor GPU: the dense path performs no update.
        return;
      }

      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
          param->numel());
      if (use_nesterov) {
        DenseMomentumFunctor<T, UseNesterov> gpu_update(
            param->data<T>(), grad->data<T>(), velocity->data<T>(),
            learning_rate->data<T>(), mu, param->numel(),
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<T>(ctx.GetPlace()));
        for_range(gpu_update);
      } else {
        DenseMomentumFunctor<T, NoNesterov> gpu_update(
            param->data<T>(), grad->data<T>(), velocity->data<T>(),
            learning_rate->data<T>(), mu, param->numel(),
            param_out->mutable_data<T>(ctx.GetPlace()),
            velocity_out->mutable_data<T>(ctx.GetPlace()));
        for_range(gpu_update);
      }
      return;
    }

    // ---- Unsupported gradient type --------------------------------------
    if (!grad_var->IsType<framework::SelectedRows>()) {
      PADDLE_THROW(
          string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
                          "gradient, but the received Variable Type is %s",
                          grad_var->Type().name()));
    }

    // ---- Sparse gradient (SelectedRows) ---------------------------------
    auto grad = ctx.Input<framework::SelectedRows>("Grad");
    // A sparse gradient may carry no rows at all; nothing to apply then.
    if (grad->rows().size() == 0) {
      VLOG(3) << "Grad SelectedRows contains no data!";
      return;
    }

    // MergeAdd writes its result into a fresh variable created in the
    // execution scope.
    auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
                            .Var()
                            ->GetMutable<framework::SelectedRows>();
    math::scatter::MergeAdd<DeviceContext, T> merge_func;
    merge_func(ctx.template device_context<DeviceContext>(), *grad,
               merged_grad);

    // Row ids: the CUDA-side pointer on a GPU place, the plain data pointer
    // otherwise (and always in builds without PADDLE_WITH_CUDA).
    const int64_t* rows = nullptr;
#ifdef PADDLE_WITH_CUDA
    if (platform::is_gpu_place(ctx.GetPlace())) {
      rows = merged_grad->rows().CUDAData(ctx.GetPlace());
    } else {
      rows = merged_grad->rows().data();
    }
#else
    rows = merged_grad->rows().data();
#endif

    // Number of values stored per merged row.
    int64_t row_numel =
        merged_grad->value().numel() / merged_grad->rows().size();

    platform::ForRange<DeviceContext> for_range(
        static_cast<const DeviceContext&>(ctx.device_context()),
        param->numel());
    if (use_nesterov) {
      SparseMomentumFunctor<T, UseNesterov> sparse_update(
          param->data<T>(), merged_grad->value().data<T>(),
          velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
          static_cast<int64_t>(merged_grad->rows().size()),
          param_out->mutable_data<T>(ctx.GetPlace()),
          velocity_out->mutable_data<T>(ctx.GetPlace()));
      for_range(sparse_update);
    } else {
      SparseMomentumFunctor<T, NoNesterov> sparse_update(
          param->data<T>(), merged_grad->value().data<T>(),
          velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
          static_cast<int64_t>(merged_grad->rows().size()),
          param_out->mutable_data<T>(ctx.GetPlace()),
          velocity_out->mutable_data<T>(ctx.GetPlace()));
      for_range(sparse_update);
    }
  }
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/rmsprop_op.cc
浏览文件 @
aca05d59
...
...
@@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel {
"Input(Grad) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Moment"
),
"Input(Moment) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(param_out) of RmspropOp should not be null."
);
...
...
paddle/fluid/operators/rmsprop_op.h
浏览文件 @
aca05d59
...
...
@@ -13,66 +13,259 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
// Gradient accessor for a dense gradient buffer: element `idx` of the
// gradient is simply grad[idx].
template <typename T>
struct DenseRmspropGradFunctor {
  inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}

  // Returns the gradient value stored at position `idx`.
  HOSTDEVICE inline T operator()(int64_t idx) const {
    const T value = grad_[idx];
    return value;
  }

  const T *grad_;  // dense gradient values, indexed directly by element
};
// Gradient accessor for a row-sparse gradient.
//
// Element `idx` belongs to row idx / row_numel. That row id is searched for
// in `rows` (row_count entries) with math::BinarySearch; if it is present the
// corresponding stored value is returned, otherwise the gradient is zero.
template <typename T>
struct SparseRmspropGradFunctor {
  inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows,
                                  int64_t row_numel, int64_t row_count)
      : grad_(grad),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count) {}

  HOSTDEVICE inline T operator()(int64_t idx) const {
    // Position of the element's row among the stored rows, or -1 if absent.
    auto found = math::BinarySearch(rows_, row_count_, idx / row_numel_);
    if (found >= 0) {
      return grad_[found * row_numel_ + idx % row_numel_];
    }
    return 0;
  }

  const T *grad_;        // values of the stored rows, row_numel_ per row
  const int64_t *rows_;  // ids of the stored rows
  int64_t row_numel_;    // number of values per row
  int64_t row_count_;    // number of entries in rows_
};
// In-place, per-element RMSProp update (uncentered variant).
//
// For element idx, with g obtained from the gradient functor:
//   ms   = rho * ms + (1 - rho) * g * g
//   mom  = momentum * mom + lr[0] * g / sqrt(ms + epsilon)
//   param -= mom
// The parameter, mean-square and moment buffers are all updated in place.
template <typename T, typename GradFunctor>
struct UncenteredRmspropFunctor {
  UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho,
                           T epsilon, T momentum,
                           const GradFunctor &grad_functor)
      : param_(param),
        ms_(ms),
        mom_(mom),
        lr_(lr),
        rho_(rho),
        epsilon_(epsilon),
        momentum_(momentum),
        grad_functor_(grad_functor) {}

  HOSTDEVICE inline void operator()(int64_t idx) const {
    T grad = grad_functor_(idx);
    // Running average of the squared gradient.
    T new_ms = rho_ * ms_[idx] + (1 - rho_) * grad * grad;
    // Moment: decayed previous moment plus the scaled gradient step.
    T new_mom =
        momentum_ * mom_[idx] + lr_[0] * grad / sqrt(new_ms + epsilon_);
    param_[idx] -= new_mom;
    ms_[idx] = new_ms;
    mom_[idx] = new_mom;
  }

  T *param_;     // parameters, updated in place
  T *ms_;        // mean-square accumulator, updated in place
  T *mom_;       // moment accumulator, updated in place
  const T *lr_;  // learning rate; only element 0 is read
  T rho_;
  T epsilon_;
  T momentum_;
  GradFunctor grad_functor_;  // maps an element index to its gradient value
};
// In-place, per-element RMSProp update (centered variant).
//
// For element idx, with g obtained from the gradient functor:
//   ms   = rho * ms + (1 - rho) * g * g
//   mg   = rho * mean_grad + (1 - rho) * g
//   mom  = momentum * mom + lr[0] * g / sqrt(ms - mg * mg + epsilon)
//   param -= mom
// The parameter, mean-square, moment and mean-gradient buffers are all
// updated in place.
template <typename T, typename GradFunctor>
struct CenteredRmspropFunctor {
  CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr,
                         T rho, T epsilon, T momentum,
                         const GradFunctor &grad_functor)
      : param_(param),
        ms_(ms),
        mom_(mom),
        mean_grad_(mean_grad),
        lr_(lr),
        rho_(rho),
        epsilon_(epsilon),
        momentum_(momentum),
        grad_functor_(grad_functor) {}

  HOSTDEVICE inline void operator()(int64_t idx) const {
    T grad = grad_functor_(idx);
    // Running averages of the squared gradient and of the gradient itself.
    T new_ms = rho_ * ms_[idx] + (1 - rho_) * grad * grad;
    T new_mg = rho_ * mean_grad_[idx] + (1 - rho_) * grad;
    // The squared mean gradient is subtracted from the mean square before
    // the square root is taken.
    T new_mom = momentum_ * mom_[idx] +
                lr_[0] * grad / sqrt(new_ms - new_mg * new_mg + epsilon_);
    param_[idx] -= new_mom;
    ms_[idx] = new_ms;
    mom_[idx] = new_mom;
    mean_grad_[idx] = new_mg;
  }

  T *param_;      // parameters, updated in place
  T *ms_;         // mean-square accumulator, updated in place
  T *mom_;        // moment accumulator, updated in place
  T *mean_grad_;  // mean-gradient accumulator, updated in place
  const T *lr_;   // learning rate; only element 0 is read
  T rho_;
  T epsilon_;
  T momentum_;
  GradFunctor grad_functor_;  // maps an element index to its gradient value
};
template
<
typename
DeviceContext
,
typename
T
>
class
RmspropOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
param_out
=
ctx
.
Output
<
Tensor
>
(
"ParamOut"
);
auto
*
moment_out
=
ctx
.
Output
<
Tensor
>
(
"MomentOut"
);
auto
*
mean_square_out
=
ctx
.
Output
<
Tensor
>
(
"MeanSquareOut"
);
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
LoDTensor
=
framework
::
LoDTensor
;
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
auto
*
param_out
=
ctx
.
Output
<
LoDTensor
>
(
"ParamOut"
);
auto
*
moment_out
=
ctx
.
Output
<
LoDTensor
>
(
"MomentOut"
);
auto
*
mean_square_out
=
ctx
.
Output
<
LoDTensor
>
(
"MeanSquareOut"
);
auto
grad
=
ctx
.
Input
<
Tensor
>
(
"Grad"
);
auto
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
auto
rho
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"decay"
));
auto
momentum
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
bool
centered
=
ctx
.
Attr
<
bool
>
(
"centered"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
p_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"Param"
);
auto
&
ms_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"MeanSquare"
);
auto
&
lr_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"LearningRate"
);
auto
&
mom_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"Moment"
);
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
float
rho
=
ctx
.
Attr
<
float
>
(
"decay"
);
float
momentum
=
ctx
.
Attr
<
float
>
(
"momentum"
);
bool
centered
=
ctx
.
Attr
<
bool
>
(
"centered"
);
PADDLE_ENFORCE_EQ
(
&
p_tensor
,
param_out
,
"Param and ParamOut must be the same Tensor"
);
PADDLE_ENFORCE_EQ
(
&
mom_tensor
,
moment_out
,
"Moment and MomentOut must be the same Tensor"
);
PADDLE_ENFORCE_EQ
(
&
ms_tensor
,
mean_square_out
,
"MeanSquare and MeanSquareOut must be the same Tensor"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
size_t
limit
=
static_cast
<
size_t
>
(
ms_tensor
.
numel
());
if
(
grad_var
->
IsType
<
LoDTensor
>
())
{
auto
&
grad_tensor
=
grad_var
->
Get
<
LoDTensor
>
();
if
(
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
)
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
lr_value
=
lr_tensor
.
data
<
T
>
()[
0
];
auto
p
=
EigenVector
<
T
>::
Flatten
(
p_tensor
);
auto
ms
=
EigenVector
<
T
>::
Flatten
(
ms_tensor
);
auto
g
=
EigenVector
<
T
>::
Flatten
(
grad_tensor
);
auto
mom
=
EigenVector
<
T
>::
Flatten
(
mom_tensor
);
auto
p_out
=
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
mom_out
=
EigenVector
<
T
>::
Flatten
(
*
moment_out
);
auto
ms_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_square_out
);
ms_out
.
device
(
place
)
=
rho
*
ms
+
(
1
-
rho
)
*
g
*
g
;
if
(
centered
)
{
auto
&
mg_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"MeanGrad"
);
auto
mg
=
EigenVector
<
T
>::
Flatten
(
mg_tensor
);
auto
*
mean_grad_out
=
ctx
.
Output
<
LoDTensor
>
(
"MeanGradOut"
);
PADDLE_ENFORCE
(
&
mg_tensor
,
mean_grad_out
,
"MeanGrad and MeanGradOut must be the same Tensor"
);
auto
mg_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_grad_out
);
mg_out
.
device
(
place
)
=
rho
*
mg
+
(
1
-
rho
)
*
g
;
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr_value
*
g
/
(
ms_out
-
mg_out
.
square
()
+
epsilon
).
sqrt
();
}
else
{
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr_value
*
g
/
(
ms_out
+
epsilon
).
sqrt
();
}
p_out
.
device
(
place
)
=
p
-
mom_out
;
}
else
{
DenseRmspropGradFunctor
<
T
>
grad_func
(
grad_tensor
.
data
<
T
>
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
limit
);
if
(
centered
)
{
auto
&
mg_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"MeanGrad"
);
auto
*
mean_grad_out
=
ctx
.
Output
<
LoDTensor
>
(
"MeanGradOut"
);
PADDLE_ENFORCE
(
&
mg_tensor
,
mean_grad_out
,
"MeanGrad and MeanGradOut must be the same Tensor"
);
for_range
(
CenteredRmspropFunctor
<
T
,
DenseRmspropGradFunctor
<
T
>>
(
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_grad_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
}
else
{
for_range
(
UncenteredRmspropFunctor
<
T
,
DenseRmspropGradFunctor
<
T
>>
(
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
}
}
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
grad
=
grad_var
->
Get
<
framework
::
SelectedRows
>
();
auto
*
merged_grad
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
math
::
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
merge_func
(
dev_ctx
,
grad
,
merged_grad
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
limit
);
const
int64_t
*
rows
;
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
rows
=
merged_grad
->
rows
().
CUDAData
(
ctx
.
GetPlace
());
}
else
{
#endif
rows
=
merged_grad
->
rows
().
data
();
#ifdef PADDLE_WITH_CUDA
}
#endif
auto
&
merged_tensor
=
merged_grad
->
value
();
int64_t
row_count
=
merged_grad
->
rows
().
size
();
int64_t
row_numel
=
merged_tensor
.
numel
()
/
row_count
;
SparseRmspropGradFunctor
<
T
>
grad_func
(
merged_tensor
.
data
<
T
>
(),
rows
,
row_numel
,
row_count
);
auto
p
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"Param"
));
auto
ms
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"MeanSquare"
));
auto
lr
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"LearningRate"
));
auto
g
=
EigenVector
<
T
>::
Flatten
(
*
grad
);
auto
mom
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"Moment"
));
auto
p_out
=
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
mom_out
=
EigenVector
<
T
>::
Flatten
(
*
moment_out
);
auto
ms_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_square_out
);
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
static_cast
<
int
>
(
grad
->
numel
()));
ms_out
.
device
(
place
)
=
rho
*
ms
+
(
1
-
rho
)
*
g
*
g
;
if
(
centered
)
{
auto
mg
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"MeanGrad"
));
auto
*
mean_grad_out
=
ctx
.
Output
<
Tensor
>
(
"MeanGradOut"
);
mean_grad_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
mg_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_grad_out
);
mg_out
.
device
(
place
)
=
rho
*
mg
+
(
1
-
rho
)
*
g
;
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr
.
broadcast
(
grad_dsize
)
*
g
/
(
ms_out
-
mg_out
.
square
()
+
epsilon
).
sqrt
();
if
(
centered
)
{
auto
&
mg_tensor
=
*
ctx
.
Input
<
LoDTensor
>
(
"MeanGrad"
);
auto
*
mean_grad_out
=
ctx
.
Output
<
LoDTensor
>
(
"MeanGradOut"
);
PADDLE_ENFORCE
(
&
mg_tensor
,
mean_grad_out
,
"MeanGrad and MeanGradOut must be the same Tensor"
);
for_range
(
CenteredRmspropFunctor
<
T
,
SparseRmspropGradFunctor
<
T
>>
(
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_grad_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
}
else
{
for_range
(
UncenteredRmspropFunctor
<
T
,
SparseRmspropGradFunctor
<
T
>>
(
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
}
}
else
{
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr
.
broadcast
(
grad_dsize
)
*
g
/
(
ms_out
+
epsilon
).
sqrt
();
PADDLE_THROW
(
"RMSProp only supports LoDTensor or SelectedRows gradient"
);
}
p_out
.
device
(
place
)
=
p
-
mom_out
;
}
};
...
...
paddle/fluid/operators/sgd_op.cc
浏览文件 @
aca05d59
...
...
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
@@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Param"
));
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
...
...
@@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel {
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var
=
op_desc
.
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
op_desc
.
Output
(
"ParamOut"
))
{
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var_n
=
op_desc
.
Input
(
"Param"
)[
0
];
auto
in_var_type
=
block
->
FindRecursiveOrCreateVar
(
input_var_n
).
GetType
();
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s"
,
input_var_n
,
in_var_type
);
for
(
auto
&
out_var_n
:
op_desc
.
Output
(
"ParamOut"
))
{
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_n
);
if
(
out_var
.
GetType
()
!=
in_var_type
)
{
out_var
.
SetType
(
in_var_type
);
}
}
}
...
...
paddle/fluid/operators/sgd_op.cu
浏览文件 @
aca05d59
...
...
@@ -57,6 +57,12 @@ template <typename T>
class
SGDOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
...
...
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
aca05d59
...
...
@@ -23,14 +23,14 @@ namespace operators {
template
<
typename
T
>
class
CPUUniformRandomKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
Tensor
*
tensor
=
nullptr
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
Tensor
*
tensor
=
nullptr
;
auto
out_var
=
ctx
.
OutputVar
(
"Out"
);
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
else
if
(
out_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
shape
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"shape"
);
auto
*
selected_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
selected_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
tensor
=
selected_rows
->
mutable_value
();
tensor
->
Resize
(
framework
::
make_ddim
(
shape
));
selected_rows
->
mutable_rows
()
->
reserve
(
shape
[
0
]);
...
...
@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
"uniform_random_op's output only"
"supports SelectedRows and LoDTensor"
);
}
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
...
...
@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
Attrs
().
Get
<
float
>
(
"min"
)
<
ctx
->
Attrs
().
Get
<
float
>
(
"max"
),
"uniform_random's min must less then max"
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
std
::
vector
<
int64_t
>
temp
;
temp
.
reserve
(
shape
.
size
());
for
(
auto
dim
:
shape
)
{
...
...
@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
ctx
.
Attr
<
int
>
(
"dtype"
)),
ctx
.
GetPlace
());
...
...
@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"Out"
).
front
();
if
(
block
->
FindRecursiveOrCreateVar
(
out_var_name
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
var_data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"dtype"
)));
auto
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
if
(
out_var
.
GetType
()
!=
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
out_var
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
out_var
.
SetDataType
(
var_data_type
);
}
};
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
aca05d59
...
...
@@ -156,7 +156,50 @@ PYBIND11_PLUGIN(core) {
.
def
(
"_get_double_element"
,
TensorGetElement
<
double
>
)
.
def
(
"_dtype"
,
[](
Tensor
&
self
)
{
return
ToDataType
(
self
.
type
());
});
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
)
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
,
R"DOC(
LoDTensor is a Tensor with optional LoD information.
np.array(lod_tensor) can convert LoDTensor to numpy array.
lod_tensor.lod() can retrieve the LoD information.
LoD is short for Level of Details and is usually used for varied sequence
length. You can skip the following comment if you don't need optional LoD.
For example:
A LoDTensor X can look like the example below. It contains 2 sequences.
The first has length 2 and the second has length 3, as described by x.lod.
The first tensor dimension 5=2+3 is calculated from LoD if it's available.
It means the total number of sequence elements. In X, each element has 2
columns, hence [5, 2].
x.lod = [[2, 3]]
x.data = [[1, 2], [3, 4], // seq 1
[5, 6], [7, 8], [9, 10]] // seq 2
x.shape = [5, 2]
LoD can have multiple levels (for example, a paragraph can have multiple
sentences and a sentence can have multiple words). In the following
LoDTensor Y, the lod_level is 2. It means there are 2 sequences, the
first sequence length is 2 (has 2 sub-sequences), the second one's
length is 1. The first sequence's 2 sub-sequences have length 2 and 2,
respectively. And the second sequence's 1 sub-sequence has length 3.
y.lod = [[2 1], [2 2 3]]
y.shape = [2+2+3, ...]
Note:
In above description, LoD is length-based. In Paddle internal
implementation, lod is offset-based. Hence, internally,
y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based
equivalent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]).
Sometimes LoD is called recursive_sequence_length to be more
self-explanatory. In this case, it must be length-based. For historical
reasons, when LoD is called lod in public API, it might be offset-based.
Users should be careful about it.
)DOC"
)
.
def_buffer
(
[](
Tensor
&
self
)
->
py
::
buffer_info
{
return
CastToPyBuffer
(
self
);
})
.
def
(
"__init__"
,
...
...
@@ -596,26 +639,58 @@ All parameter, weight, gradient are variables in Paddle.
// -- python binds for parallel executor.
py
::
class_
<
ParallelExecutor
>
pe
(
m
,
"ParallelExecutor"
);
py
::
class_
<
ExecutionStrategy
>
exec_strategy
(
pe
,
"ExecutionStrategy"
);
py
::
class_
<
ExecutionStrategy
>
exec_strategy
(
pe
,
"ExecutionStrategy"
,
R"DOC(
ExecutionStrategy allows the user to more precisely control how to run
the program in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 4
train_exe = fluid.ParallelExecutor(use_cuda=True,
loss_name=loss.name,
exec_strategy=exec_strategy)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
)DOC"
);
exec_strategy
.
def
(
py
::
init
())
.
def_property
(
"num_threads"
,
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
num_threads_
;
},
[](
ExecutionStrategy
&
self
,
size_t
num_threads
)
{
self
.
num_threads_
=
num_threads
;
})
},
R"DOC(The type is INT, num_threads represents the size of the thread pool that is
used to run the operators of the current program in ParallelExecutor.
If :math:`num\_threads=1`, all the operators will execute one by one,
but the order may differ between iterations.
If it is not set, it will be set in ParallelExecutor according to the
device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU,
:math:`num\_threads=CPU\_NUM*4`, the explanation of :math:`CPU\_NUM` is in ParallelExecutor.
If :math:`CPU\_NUM` is not set, ParallelExecutor will get the cpu count by calling
`multiprocessing.cpu_count()`. Default 0.)DOC"
)
.
def_property
(
"use_cuda"
,
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
use_cuda_
;
},
[](
ExecutionStrategy
&
self
,
bool
use_cuda
)
{
self
.
use_cuda_
=
use_cuda
;
})
})
// FIXME(chengduo): No doc is added for 'use_cuda' because it may
// confuse users: ParallelExecutor has a parameter named
// 'use_cuda' too, and in the current implementation ParallelExecutor's
// 'use_cuda' will overwrite ExecutionStrategy's 'use_cuda'.
.
def_property
(
"allow_op_delay"
,
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
allow_op_delay_
;
},
[](
ExecutionStrategy
&
self
,
bool
allow_op_delay
)
{
self
.
allow_op_delay_
=
allow_op_delay
;
})
},
R"DOC(The type is BOOL, allow_op_delay represents whether to delay the
communication operators to run, it may make the execution faster.
Note that in some models, allow_op_delay may cause the program to hang. Default False.)DOC"
)
.
def_property
(
"num_iteration_per_drop_scope"
,
[](
const
ExecutionStrategy
&
self
)
{
...
...
@@ -623,7 +698,19 @@ All parameter, weight, gradient are variables in Paddle.
},
[](
ExecutionStrategy
&
self
,
size_t
num_iteration_per_drop_scope
)
{
self
.
num_iteration_per_drop_scope_
=
num_iteration_per_drop_scope
;
});
},
R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to run before cleaning up the temp variables that
are generated during execution. It may make the execution faster,
because the temp variable's shape may be the same between two iterations. Default 100.
NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor
will clean up the temp variables at the end of the current iteration.
2. In some NLP models, it may cause the GPU memory to be insufficient,
in this case, you should reduce `num_iteration_per_drop_scope`.
)DOC"
);
exec_strategy
.
def_property
(
"use_experimental_executor"
,
[](
const
ExecutionStrategy
&
self
)
{
...
...
@@ -634,7 +721,22 @@ All parameter, weight, gradient are variables in Paddle.
:
ExecutionStrategy
::
kDefault
;
});
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
);
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
,
R"DOC(
BuildStrategy allows the user to more precisely control how to
build the SSA Graph in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
train_exe = fluid.ParallelExecutor(use_cuda=True,
loss_name=loss.name,
build_strategy=build_strategy)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
)DOC"
);
py
::
enum_
<
BuildStrategy
::
ReduceStrategy
>
(
build_strategy
,
"ReduceStrategy"
)
.
value
(
"Reduce"
,
BuildStrategy
::
ReduceStrategy
::
kReduce
)
...
...
@@ -652,31 +754,51 @@ All parameter, weight, gradient are variables in Paddle.
[](
const
BuildStrategy
&
self
)
{
return
self
.
reduce_
;
},
[](
BuildStrategy
&
self
,
BuildStrategy
::
ReduceStrategy
strategy
)
{
self
.
reduce_
=
strategy
;
})
},
R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor,
'AllReduce' and 'Reduce'. If you want that all the parameters'
optimization are done on all devices independently, you should choose 'AllReduce';
if you choose 'Reduce', all the parameters' optimization will be evenly distributed
to different devices, and then broadcast the optimized parameter to other devices.
In some models, `Reduce` is faster. Default 'AllReduce'. )DOC"
)
.
def_property
(
"gradient_scale_strategy"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
gradient_scale_
;
},
[](
BuildStrategy
&
self
,
BuildStrategy
::
GradientScaleStrategy
strategy
)
{
self
.
gradient_scale_
=
strategy
;
})
},
R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in
ParallelExecutor, 'CoeffNumDevice', 'One' and 'Customized'. By default,
ParallelExecutor sets the :math:`loss@grad` according to the number of devices.
If you want to customize :math:`loss@grad`, you can choose 'Customized'.
Default 'CoeffNumDevice'.)DOC"
)
.
def_property
(
"debug_graphviz_path"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
debug_graphviz_path_
;
},
[](
BuildStrategy
&
self
,
const
std
::
string
&
path
)
{
self
.
debug_graphviz_path_
=
path
;
})
},
R"DOC(The type is STR, debug_graphviz_path indicates the path to which
the SSA Graph is written as a file in graphviz format.
It is useful for debugging. Default "")DOC"
)
.
def_property
(
"enable_data_balance"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_data_balance_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_data_balance_
=
b
;
})
.
def_property
(
"fuse_elewise_add_act_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_elewise_add_act_ops_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
fuse_elewise_add_act_ops_
=
b
;
});
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_data_balance_
=
b
;
})
// FIXME(chengduo): enable_data_balance seems not important
.
def_property
(
"fuse_elewise_add_act_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_elewise_add_act_ops_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
fuse_elewise_add_act_ops_
=
b
;
},
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicates whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC"
);
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
const
std
::
unordered_set
<
std
::
string
>
&
,
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
aca05d59
...
...
@@ -654,11 +654,21 @@ function gen_fluid_inference_lib() {
if
[[
${
WITH_C_API
:-
OFF
}
==
"OFF"
&&
${
WITH_INFERENCE
:-
ON
}
==
"ON"
]]
;
then
cat
<<
EOF
========================================
Deploy
ing fluid inference library ...
Generat
ing fluid inference library ...
========================================
EOF
cmake ..
-DWITH_DISTRIBUTE
=
OFF
make
-j
`
nproc
`
inference_lib_dist
fi
}
function
tar_fluid_inference_lib
()
{
if
[[
${
WITH_C_API
:-
OFF
}
==
"OFF"
&&
${
WITH_INFERENCE
:-
ON
}
==
"ON"
]]
;
then
cat
<<
EOF
========================================
Taring fluid inference library ...
========================================
EOF
cd
${
PADDLE_ROOT
}
/build
cp
-r
fluid_install_dir fluid
tar
-czf
fluid.tgz fluid
...
...
@@ -673,7 +683,7 @@ function test_fluid_inference_lib() {
========================================
EOF
cd
${
PADDLE_ROOT
}
/paddle/fluid/inference/api/demo_ci
./run.sh
${
PADDLE_ROOT
}
${
WITH_MKL
:-
ON
}
${
WITH_GPU
:-
OFF
}
./run.sh
${
PADDLE_ROOT
}
${
WITH_MKL
:-
ON
}
${
WITH_GPU
:-
OFF
}
${
INFERENCE_DEMO_INSTALL_DIR
}
${
TENSORRT_INCLUDE_DIR
:-
/usr/local/TensorRT/include
}
${
TENSORRT_LIB_DIR
:-
/usr/local/TensorRT/lib
}
./clean.sh
fi
}
...
...
@@ -722,6 +732,7 @@ function main() {
fluid_inference_lib
)
cmake_gen
${
PYTHON_ABI
:-
""
}
gen_fluid_inference_lib
tar_fluid_inference_lib
test_fluid_inference_lib
;;
check_style
)
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
aca05d59
...
...
@@ -55,7 +55,11 @@ def data(name,
Args:
name(str): The name/alias of the function
shape(list): Tuple declaring the shape.
append_batch_size(bool): Whether or not to append the data as a batch.
append_batch_size(bool):
1. If true, it prepends -1 to the shape.
For example if shape=[1], the resulting shape is [-1, 1].
2. If shape contains -1, such as shape=[1, -1],
append_batch_size will be enforced to be False (ineffective).
dtype(int|float): The type of data : float32, float_16, int etc
type(VarType): The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
aca05d59
此差异已折叠。
点击以展开。
python/paddle/fluid/layers/ops.py
浏览文件 @
aca05d59
...
...
@@ -14,6 +14,8 @@
from
__future__
import
print_function
from
.layer_function_generator
import
generate_layer_fn
,
generate_layer_fn_noattr
from
..
import
core
from
..framework
import
convert_np_dtype_to_dtype_
__activations_noattr__
=
[
'sigmoid'
,
...
...
@@ -58,8 +60,11 @@ _uniform_random_ = generate_layer_fn('uniform_random')
def
uniform_random
(
shape
,
dtype
=
None
,
min
=
None
,
max
=
None
,
seed
=
None
):
locals_var
=
locals
().
keys
()
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
...
...
@@ -78,8 +83,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def
hard_shrink
(
x
,
threshold
=
None
):
locals_var
=
locals
().
keys
()
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
...
...
@@ -99,12 +105,12 @@ _cum_sum_ = generate_layer_fn('cumsum')
def
cumsum
(
x
,
axis
=
None
,
exclusive
=
None
,
reverse
=
None
):
locals_var
=
locals
().
keys
()
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
return
_cum_sum_
(
**
kwargs
)
...
...
@@ -121,8 +127,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def
thresholded_relu
(
x
,
threshold
=
None
):
locals_var
=
locals
().
keys
()
kwargs
=
dict
()
for
name
in
locals
()
:
for
name
in
locals
_var
:
val
=
locals
()[
name
]
if
val
is
not
None
:
kwargs
[
name
]
=
val
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
aca05d59
...
...
@@ -111,7 +111,7 @@ def create_global_var(shape,
force_cpu
=
False
,
name
=
None
):
"""
Create a new
variabl
e in the global block(block 0).
Create a new
tensor variable with valu
e in the global block(block 0).
Args:
shape(list[int]): shape of the variable
...
...
python/paddle/fluid/nets.py
浏览文件 @
aca05d59
...
...
@@ -64,23 +64,33 @@ def simple_img_conv_pool(input,
average-pooling. Default :math:`max`.
global_pooling (bool): Whether to use the global pooling. If global_pooling = true,
pool_size and pool_padding will be ignored. Default False
conv_stride (int|list|tuple): The stride size of the
C
onv2d Layer. If stride is a
conv_stride (int|list|tuple): The stride size of the
c
onv2d Layer. If stride is a
list or tuple, it must contain two integers, (conv_stride_H, conv_stride_W). Otherwise,
the conv_stride_H = conv_stride_W = conv_stride. Default: conv_stride = 1.
conv_padding (int|list|tuple): The padding size of the
C
onv2d Layer. If padding is
conv_padding (int|list|tuple): The padding size of the
c
onv2d Layer. If padding is
a list or tuple, it must contain two integers, (conv_padding_H, conv_padding_W).
Otherwise, the conv_padding_H = conv_padding_W = conv_padding. Default: conv_padding = 0.
conv_dilation (int|list|tuple): The dilation size of the
C
onv2d Layer. If dilation is
conv_dilation (int|list|tuple): The dilation size of the
c
onv2d Layer. If dilation is
a list or tuple, it must contain two integers, (conv_dilation_H, conv_dilation_W).
Otherwise, the conv_dilation_H = conv_dilation_W = conv_dilation. Default: conv_dilation = 1.
conv_groups (int): The groups number of the
C
onv2d Layer. According to grouped
conv_groups (int): The groups number of the
c
onv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1
param_attr (ParamAttr): The parameters to the Conv2d Layer. Default: None
bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None
act (str): Activation type for Conv2d. Default: None
connected to the second half of the input channels. Default: groups=1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(
\\
frac{2.0 }{filter\_elem\_num})^{0.5}`.
Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
act (str): Activation type for conv2d, if it is set to None, activation is not
appended. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
aca05d59
...
...
@@ -659,6 +659,9 @@ class AdamaxOptimizer(Optimizer):
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, AdamaxOptimizer doesn't support sparse gradient.
"""
_moment_acc_str
=
"moment"
_inf_norm_acc_str
=
"inf_norm"
...
...
@@ -778,6 +781,9 @@ class DecayedAdagradOptimizer(Optimizer):
optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, DecayedAdagradOptimizer doesn't support sparse gradient.
"""
_moment_acc_str
=
"moment"
...
...
@@ -858,6 +864,9 @@ class AdadeltaOptimizer(Optimizer):
optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, AdadeltaOptimizer doesn't support sparse gradient.
"""
_avg_squared_grad_acc_str
=
"_avg_squared_grad"
...
...
@@ -1126,6 +1135,9 @@ class FtrlOptimizer(Optimizer):
optimizer = fluid.optimizer.Ftrl(0.0001)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, FtrlOptimizer doesn't support sparse gradient.
"""
_squared_acc_str
=
"squared"
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
aca05d59
...
...
@@ -31,15 +31,32 @@ BuildStrategy = core.ParallelExecutor.BuildStrategy
class
ParallelExecutor
(
object
):
"""
ParallelExecutor can run program in parallel.
ParallelExecutor is designed for data parallelism, which focuses on distributing
the data across different nodes and every node operates on the data in parallel.
If you use ParallelExecutor to run the current program on GPU, the node means GPU
device, and ParallelExecutor will get the available GPU device automatically on
the current machine. If you use ParallelExecutor to run the current program on CPU,
the node means the CPU device, and you can specify the CPU device number by adding
'CPU_NUM' environment variable, for example 'CPU_NUM=4', if the environment variable
is not found, ParallelExecutor will call `multiprocessing.cpu_count` to get the number
of CPUs in the system.
Args:
use_cuda (bool): Whether to use CUDA or not.
loss_name (str): The loss name must be set in training. Default None.
main_program (Program): The program that needs to run, if not provided,
then default_main_program will be used. Default None.
share_vars_from(ParallelExecutor): If provi
ed
, it will share variables
share_vars_from(ParallelExecutor): If provi
ded
, it will share variables
from the specified ParallelExecutor. Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to control how to run
the program in ParallelExecutor, for example how many threads are used to
execute the program, how many iterations to clean up the temp variables
which are generated during execution. For more information, please refer
to fluid.ExecutionStrategy. Default None.
build_strategy(BuildStrategy): build_strategy is used to control how to
build the SSA Graph in ParallelExecutor by setting the property,
for example reduce_strategy, gradient_scale_strategy. For more information,
please refer to fluid.BuildStrategy. Default None.
num_trainers(int): If greater than 1, NCCL will be initialized with
multiple ranks of nodes, each node should have the same number of GPUs.
Distributed training will be enabled then. Default 1.
...
...
python/paddle/fluid/tests/unittests/test_momentum_op.py
浏览文件 @
aca05d59
...
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
op_test
import
OpTest
...
...
@@ -88,5 +90,97 @@ class TestMomentumOp2(OpTest):
self
.
check_output
()
class
TestSparseMomentumOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
use_nesterov
=
False
def
check_with_place
(
self
,
place
):
self
.
init_kernel
()
scope
=
core
.
Scope
()
# create and initialize Grad Variable
height
=
10
rows
=
[
0
,
4
,
7
]
row_numel
=
12
mu
=
1.0
use_nesterov
=
self
.
use_nesterov
# create and initialize Param Variable
param
=
scope
.
var
(
'Param'
).
get_tensor
()
param_array
=
np
.
full
((
height
,
row_numel
),
5.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
param_out
=
scope
.
var
(
"ParamOut"
).
get_tensor
()
param_out_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
param_out
.
set
(
param_out_array
,
place
)
grad_selected_rows
=
scope
.
var
(
'Grad'
).
get_selected_rows
()
grad_selected_rows
.
set_height
(
height
)
grad_selected_rows
.
set_rows
(
rows
)
grad_np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
grad_np_array
[
0
,
0
]
=
2.0
grad_np_array
[
2
,
8
]
=
4.0
grad_tensor
=
grad_selected_rows
.
get_tensor
()
grad_tensor
.
set
(
grad_np_array
,
place
)
velocity
=
scope
.
var
(
'Velocity'
).
get_tensor
()
velocity_np_array
=
np
.
ones
((
height
,
row_numel
)).
astype
(
"float32"
)
velocity
.
set
(
velocity_np_array
,
place
)
velocity_out
=
scope
.
var
(
'VelocityOut'
).
get_tensor
()
velocity_out_np_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
velocity_out
.
set
(
velocity_out_np_array
,
place
)
# create and initialize LearningRate Variable
lr
=
scope
.
var
(
'LearningRate'
).
get_tensor
()
lr_array
=
np
.
full
((
1
),
2.0
).
astype
(
"float32"
)
lr
.
set
(
lr_array
,
place
)
# create and run operator
op
=
Operator
(
"momentum"
,
Param
=
'Param'
,
Grad
=
'Grad'
,
Velocity
=
'Velocity'
,
ParamOut
=
'ParamOut'
,
VelocityOut
=
'VelocityOut'
,
LearningRate
=
'LearningRate'
,
mu
=
mu
,
use_nesterov
=
use_nesterov
)
op
.
run
(
scope
,
place
)
# get and compare result
param_out_np_array
=
np
.
array
(
param_out
)
velocity_out_np_array
=
np
.
array
(
velocity_out
)
# TODO(dzh): add a more suitable general numpy interface
# for sparse update.
_grad_np_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
for
i
in
range
(
len
(
rows
)):
_grad_np_array
[
rows
[
i
]]
=
grad_np_array
[
i
]
_velocity_out
=
mu
*
velocity_np_array
+
_grad_np_array
_param
=
param_array
if
use_nesterov
:
_param_out
=
_param
-
(
_grad_np_array
+
_velocity_out
*
mu
)
*
lr_array
else
:
_param_out
=
_param
-
lr_array
*
_velocity_out
self
.
assertTrue
((
_velocity_out
==
velocity_out_np_array
).
all
())
self
.
assertTrue
((
_param_out
==
param_out_np_array
).
all
())
def
init_kernel
(
self
):
pass
def
test_sparse_momentum
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
class
TestSparseMomentumOp2
(
TestSparseMomentumOp
):
def
init_kernel
(
self
):
self
.
use_nesterov
=
True
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_rmsprop_op.py
浏览文件 @
aca05d59
...
...
@@ -19,33 +19,76 @@ import unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
import
paddle.fluid
as
fluid
def
create_selected_rows_and_tensor
(
scope
,
place
,
height
,
row_num
,
embedding_size
):
sr
=
scope
.
var
(
"@selected_rows@"
).
get_selected_rows
()
tensor
=
scope
.
var
(
"grad"
).
get_tensor
()
rows
=
np
.
random
.
random_integers
(
low
=
0
,
high
=
height
-
1
,
size
=
[
row_num
,
]).
astype
(
'int64'
)
sr_val
=
np
.
random
.
random
(
size
=
[
row_num
,
embedding_size
]).
astype
(
'float32'
)
sr
.
set_height
(
height
)
sr
.
set_rows
(
rows
)
sr
.
get_tensor
().
set
(
sr_val
,
place
)
tensor_val
=
np
.
zeros
(
shape
=
[
height
,
embedding_size
],
dtype
=
'float32'
)
for
i
in
range
(
row_num
):
row
=
rows
[
i
]
tensor_val
[
row
,
:]
=
tensor_val
[
row
,
:]
+
sr_val
[
i
,
:]
tensor
.
set
(
tensor_val
,
place
)
return
tensor_val
,
sr_val
class
TestBase
(
unittest
.
TestCase
):
def
setup
(
self
,
centered
,
epsilon
=
1e-6
):
def
setup
(
self
,
place
,
is_sparse
,
centered
,
size
,
row_num
=
None
,
epsilon
=
1e-6
):
np
.
random
.
seed
(
5
)
# fix seed
self
.
scope
=
fluid
.
global_scope
()
self
.
place
=
place
self
.
param_name
=
"param"
self
.
param
=
np
.
random
.
random
(
(
123
,
321
)
).
astype
(
"float32"
)
self
.
param
=
np
.
random
.
random
(
size
).
astype
(
"float32"
)
self
.
mean_square_name
=
"mean_square"
self
.
mean_square
=
np
.
random
.
random
((
123
,
321
)).
astype
(
"float32"
)
self
.
mean_square
=
np
.
random
.
uniform
(
low
=
1
,
high
=
2
,
size
=
size
).
astype
(
"float32"
)
self
.
mean_grad_name
=
"mean_grad"
self
.
mean_grad
=
np
.
random
.
random
(
(
123
,
321
)
).
astype
(
"float32"
)
self
.
mean_grad
=
np
.
random
.
random
(
size
).
astype
(
"float32"
)
self
.
lr_name
=
"lr"
self
.
learning_rate
=
np
.
array
([
0.01
]).
astype
(
"float32"
)
self
.
grad_name
=
"grad"
self
.
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
"float32"
)
self
.
is_sparse
=
is_sparse
if
self
.
is_sparse
:
self
.
grad_sr_name
=
"@selected_rows@"
self
.
grad
,
self
.
grad_sr
=
create_selected_rows_and_tensor
(
self
.
scope
,
place
,
size
[
0
],
row_num
,
size
[
1
])
else
:
self
.
grad
=
np
.
random
.
random
(
size
).
astype
(
"float32"
)
grad_tensor
=
self
.
scope
.
var
(
self
.
grad_name
).
get_tensor
()
grad_tensor
.
set
(
self
.
grad
,
place
)
self
.
moment_name
=
"moment"
self
.
moment
=
np
.
zeros
((
123
,
321
)).
astype
(
"float32"
)
self
.
moment
=
np
.
random
.
uniform
(
low
=
0
,
high
=
1
,
size
=
size
).
astype
(
"float32"
)
self
.
epsilon
=
epsilon
self
.
decay
=
0.9
self
.
momentum
=
0.
0
self
.
momentum
=
0.
1
self
.
centered
=
centered
self
.
ms_out
=
self
.
decay
*
self
.
mean_square
+
(
1
-
self
.
decay
...
...
@@ -61,118 +104,122 @@ class TestBase(unittest.TestCase):
self
.
param_out
=
self
.
param
-
self
.
moment_out
def
check
(
self
,
actual_t
,
expect_t
,
place
,
out_name
,
atol
=
1e-5
,
equal_nan
=
False
):
self
.
assertTrue
(
np
.
allclose
(
actual_t
,
expect_t
,
atol
=
atol
,
equal_nan
=
equal_nan
),
"Output ("
+
out_name
+
") has diff at "
+
str
(
place
)
+
"
\n
Expect "
+
str
(
expect_t
)
+
"
\n
"
+
"But Got"
+
str
(
actual_t
))
class
TestRmspropOp
(
TestBase
):
def
check_with_place
(
self
,
place
,
centered
,
epsilon
):
self
.
setup
(
centered
,
epsilon
)
scope
=
core
.
Scope
()
# create and initialize Param Variable
param
=
scope
.
var
(
self
.
param_name
).
get_tensor
()
param
.
set
(
self
.
param
,
place
)
self
.
param_tensor
=
self
.
scope
.
var
(
self
.
param_name
).
get_tensor
()
self
.
param_tensor
.
set
(
self
.
param
,
place
)
mean_square
=
scope
.
var
(
self
.
mean_square_name
).
get_tensor
()
mean_square
.
set
(
self
.
mean_square
,
place
)
self
.
mean_square_tensor
=
self
.
scope
.
var
(
self
.
mean_square_name
).
get_tensor
()
self
.
mean_square_tensor
.
set
(
self
.
mean_square
,
place
)
lr
=
scope
.
var
(
self
.
lr_name
).
get_tensor
()
lr
=
s
elf
.
s
cope
.
var
(
self
.
lr_name
).
get_tensor
()
lr
.
set
(
self
.
learning_rate
,
place
)
grad
=
scope
.
var
(
self
.
grad
_name
).
get_tensor
()
grad
.
set
(
self
.
grad
,
place
)
self
.
moment_tensor
=
self
.
scope
.
var
(
self
.
moment
_name
).
get_tensor
()
self
.
moment_tensor
.
set
(
self
.
moment
,
place
)
moment
=
scope
.
var
(
self
.
moment_name
).
get_tensor
()
moment
.
set
(
self
.
moment
,
place
)
if
self
.
centered
:
self
.
mean_grad_tensor
=
self
.
scope
.
var
(
self
.
mean_grad_name
).
get_tensor
()
self
.
mean_grad_tensor
.
set
(
self
.
mean_grad
,
place
)
# create and run sgd operator
def
check
(
self
,
actual_t
,
expect_t
,
place
,
out_name
,
atol
=
1e-5
):
self
.
assertTrue
(
np
.
allclose
(
actual_t
,
expect_t
,
atol
=
atol
),
"Output ("
+
out_name
+
") has diff at "
+
str
(
place
)
+
"
\n
Expect "
+
str
(
expect_t
)
+
"
\n
"
+
"But Got"
+
str
(
actual_t
))
if
self
.
centered
:
mean_grad
=
scope
.
var
(
self
.
mean_grad_name
).
get_tensor
()
mean_grad
.
set
(
self
.
mean_grad
,
place
)
rmsprop_op
=
Operator
(
"rmsprop"
,
Param
=
self
.
param_name
,
Grad
=
self
.
grad_name
,
MeanSquare
=
self
.
mean_square_name
,
MeanGrad
=
self
.
mean_grad_name
,
Moment
=
self
.
moment_name
,
LearningRate
=
self
.
lr_name
,
ParamOut
=
self
.
param_name
,
MeanSquareOut
=
self
.
mean_square_name
,
MomentOut
=
self
.
moment_name
,
MeanGradOut
=
self
.
mean_grad_name
,
epsilon
=
self
.
epsilon
,
decay
=
self
.
decay
,
momentum
=
self
.
momentum
,
centered
=
True
)
else
:
rmsprop_op
=
Operator
(
"rmsprop"
,
Param
=
self
.
param_name
,
Grad
=
self
.
grad_name
,
MeanSquare
=
self
.
mean_square_name
,
Moment
=
self
.
moment_name
,
LearningRate
=
self
.
lr_name
,
ParamOut
=
self
.
param_name
,
MeanSquareOut
=
self
.
mean_square_name
,
MomentOut
=
self
.
moment_name
,
epsilon
=
self
.
epsilon
,
decay
=
self
.
decay
,
momentum
=
self
.
momentum
,
centered
=
False
)
rmsprop_op
.
run
(
scope
,
place
)
atol
=
1e-5
equal_nan
=
False
class
TestRmspropOp
(
TestBase
):
def
check_with_place
(
self
,
place
,
is_sparse
,
centered
,
size
,
row_num
=
None
,
epsilon
=
1e-6
):
self
.
setup
(
place
,
is_sparse
,
centered
,
size
,
row_num
,
epsilon
)
self
.
run_and_check
()
def
run_and_check
(
self
):
grad_name
=
self
.
grad_sr_name
if
self
.
is_sparse
else
self
.
grad_name
kwargs
=
{
'Param'
:
self
.
param_name
,
'Grad'
:
grad_name
,
'MeanSquare'
:
self
.
mean_square_name
,
'Moment'
:
self
.
moment_name
,
'LearningRate'
:
self
.
lr_name
,
'ParamOut'
:
self
.
param_name
,
'MeanSquareOut'
:
self
.
mean_square_name
,
'MomentOut'
:
self
.
moment_name
,
'epsilon'
:
self
.
epsilon
,
'decay'
:
self
.
decay
,
'momentum'
:
self
.
momentum
,
'centered'
:
self
.
centered
}
if
self
.
centered
:
atol
=
1e-3
equal_nan
=
True
kwargs
[
'MeanGrad'
]
=
self
.
mean_grad_name
kwargs
[
'MeanGradOut'
]
=
self
.
mean_grad_name
rmsprop_op
=
Operator
(
'rmsprop'
,
**
kwargs
)
atol
=
1e-6
rmsprop_op
.
run
(
self
.
scope
,
self
.
place
)
self
.
check
(
np
.
array
(
mean_square
),
self
.
ms_out
,
place
,
self
.
mean_square_name
)
np
.
array
(
self
.
mean_square_tensor
),
self
.
ms_out
,
self
.
place
,
self
.
mean_square_name
,
atol
=
atol
)
self
.
check
(
np
.
array
(
moment
),
np
.
array
(
self
.
moment_tensor
),
self
.
moment_out
,
place
,
self
.
place
,
self
.
moment_name
,
atol
=
atol
,
equal_nan
=
equal_nan
)
atol
=
atol
)
self
.
check
(
np
.
array
(
param
),
np
.
array
(
self
.
param_tensor
),
self
.
param_out
,
place
,
self
.
place
,
self
.
param_name
,
atol
=
atol
,
equal_nan
=
equal_nan
)
atol
=
atol
)
if
self
.
centered
:
self
.
check
(
np
.
array
(
mean_grad
),
self
.
mg_out
,
place
,
self
.
mean_grad_name
)
np
.
array
(
self
.
mean_grad_tensor
),
self
.
mg_out
,
self
.
place
,
self
.
mean_grad_name
)
def
test_rmsprop
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
size
=
(
128
,
320
)
for
place
in
places
:
self
.
check_with_place
(
place
,
False
,
1e-6
)
self
.
check_with_place
(
place
,
False
,
1e-10
)
self
.
check_with_place
(
place
,
True
,
1e-6
)
self
.
check_with_place
(
place
,
True
,
1e-10
)
for
centered
in
[
False
,
True
]:
with
fluid
.
scope_guard
(
core
.
Scope
()):
self
.
check_with_place
(
place
,
is_sparse
=
False
,
centered
=
centered
,
size
=
size
)
with
fluid
.
scope_guard
(
core
.
Scope
()):
self
.
check_with_place
(
place
,
is_sparse
=
True
,
centered
=
centered
,
row_num
=
512
,
size
=
size
)
with
fluid
.
scope_guard
(
core
.
Scope
()):
self
.
check_with_place
(
place
,
is_sparse
=
True
,
centered
=
centered
,
row_num
=
60
,
size
=
size
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录