Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
a2be4b4d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a2be4b4d
编写于
4月 23, 2019
作者:
C
chengduo
提交者:
GitHub
4月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fuse momenutum ops (#16745)
* Add fuse momenutum ops
上级
03d469ad
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
363 addition
and
282 deletion
+363
-282
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+3
-1
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+19
-15
paddle/fluid/framework/details/fuse_adam_op_pass.cc
paddle/fluid/framework/details/fuse_adam_op_pass.cc
+175
-162
paddle/fluid/framework/details/fuse_momentum_op_pass.cc
paddle/fluid/framework/details/fuse_momentum_op_pass.cc
+94
-0
paddle/fluid/framework/details/fuse_optimizer_op_pass.cc
paddle/fluid/framework/details/fuse_optimizer_op_pass.cc
+10
-10
paddle/fluid/framework/details/fuse_sgd_op_pass.cc
paddle/fluid/framework/details/fuse_sgd_op_pass.cc
+40
-39
paddle/fluid/framework/details/fuse_sgd_op_pass.h
paddle/fluid/framework/details/fuse_sgd_op_pass.h
+0
-50
python/paddle/fluid/tests/unittests/test_fuse_optimizer_pass.py
.../paddle/fluid/tests/unittests/test_fuse_optimizer_pass.py
+22
-5
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
a2be4b4d
...
@@ -14,6 +14,7 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
...
@@ -14,6 +14,7 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library
(
alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper
)
cc_library
(
alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper
)
cc_library
(
record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper
)
cc_library
(
record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper
)
...
@@ -126,4 +127,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
...
@@ -126,4 +127,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass
)
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
record_skip_memory_opt_vars_pass
)
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
a2be4b4d
...
@@ -57,7 +57,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -57,7 +57,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
if
(
strategy_
.
enable_sequential_execution_
)
{
if
(
strategy_
.
enable_sequential_execution_
)
{
VLOG
(
10
)
<<
"Add sequential_execution_pass"
;
VLOG
(
5
)
<<
"Add sequential_execution_pass"
;
AppendPass
(
"sequential_execution_pass"
);
AppendPass
(
"sequential_execution_pass"
);
}
}
...
@@ -68,7 +68,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -68,7 +68,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add op fusion.
// Add op fusion.
if
(
strategy
.
fuse_relu_depthwise_conv_
)
{
if
(
strategy
.
fuse_relu_depthwise_conv_
)
{
VLOG
(
10
)
<<
"Add fuse_relu_depthwise_conv_pass"
;
VLOG
(
5
)
<<
"Add fuse_relu_depthwise_conv_pass"
;
AppendPass
(
"fuse_relu_depthwise_conv_pass"
);
AppendPass
(
"fuse_relu_depthwise_conv_pass"
);
}
}
...
@@ -80,19 +80,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -80,19 +80,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add automatically inplace.
// Add automatically inplace.
if
(
strategy_
.
enable_inplace_
)
{
if
(
strategy_
.
enable_inplace_
)
{
VLOG
(
10
)
<<
"Add inplace_pass"
;
VLOG
(
5
)
<<
"Add inplace_pass"
;
AppendPass
(
"inplace_pass"
);
AppendPass
(
"inplace_pass"
);
}
}
if
(
strategy_
.
fuse_elewise_add_act_ops_
)
{
if
(
strategy_
.
fuse_elewise_add_act_ops_
)
{
VLOG
(
10
)
<<
"Add fuse_elewise_add_act_pass"
;
VLOG
(
5
)
<<
"Add fuse_elewise_add_act_pass"
;
AppendPass
(
"fuse_elewise_add_act_pass"
);
AppendPass
(
"fuse_elewise_add_act_pass"
);
}
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
VLOG
(
10
)
<<
"Add alloc_continuous_space_for_grad_pass"
;
VLOG
(
5
)
<<
"Add alloc_continuous_space_for_grad_pass"
;
AppendPass
(
"alloc_continuous_space_for_grad_pass"
);
AppendPass
(
"alloc_continuous_space_for_grad_pass"
);
}
}
...
@@ -107,10 +107,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -107,10 +107,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
// Currently, only one type of optimization algorithm can be fused.
VLOG
(
10
)
<<
"Add fuse_adam_op_pass"
;
VLOG
(
5
)
<<
"Add fuse_adam_op_pass"
;
AppendPass
(
"fuse_adam_op_pass"
);
AppendPass
(
"fuse_adam_op_pass"
);
VLOG
(
10
)
<<
"Add fuse_sgd_op_pass"
;
VLOG
(
5
)
<<
"Add fuse_sgd_op_pass"
;
AppendPass
(
"fuse_sgd_op_pass"
);
AppendPass
(
"fuse_sgd_op_pass"
);
VLOG
(
5
)
<<
"Add fuse_momentum_op_pass"
;
AppendPass
(
"fuse_momentum_op_pass"
);
}
}
}
}
...
@@ -139,7 +141,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -139,7 +141,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
// , so fetchlist should be set persistable before call the Run interface.
if
(
strategy_
.
memory_optimize_
)
{
if
(
strategy_
.
memory_optimize_
)
{
VLOG
(
10
)
<<
"Add memory_optimize_pass"
;
VLOG
(
5
)
<<
"Add memory_optimize_pass"
;
AppendPass
(
"memory_optimize_pass"
);
AppendPass
(
"memory_optimize_pass"
);
}
}
...
@@ -147,7 +149,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -147,7 +149,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// all original and fused operators. But no operators can be enabled this
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
// attr if putting it after MultiDevPass.
if
(
strategy_
.
cache_runtime_context_
)
{
if
(
strategy_
.
cache_runtime_context_
)
{
VLOG
(
10
)
<<
"Add runtime_context_cache_pass"
;
VLOG
(
5
)
<<
"Add runtime_context_cache_pass"
;
AppendPass
(
"runtime_context_cache_pass"
);
AppendPass
(
"runtime_context_cache_pass"
);
}
}
...
@@ -161,7 +163,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -161,7 +163,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG
(
10
)
<<
"Add fuse_all_reduce_op_pass"
;
VLOG
(
5
)
<<
"Add fuse_all_reduce_op_pass"
;
AppendPass
(
"fuse_all_reduce_op_pass"
);
AppendPass
(
"fuse_all_reduce_op_pass"
);
}
}
...
@@ -182,12 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -182,12 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if
(
!
strategy_
.
enable_parallel_graph_
&&
if
(
!
strategy_
.
enable_parallel_graph_
&&
(
SeqOnlyAllReduceOps
(
strategy_
)
||
(
SeqOnlyAllReduceOps
(
strategy_
)
||
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
))
{
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
))
{
VLOG
(
10
)
<<
"Add all_reduce_deps_pass"
;
VLOG
(
5
)
<<
"Add all_reduce_deps_pass"
;
AppendPass
(
"all_reduce_deps_pass"
);
AppendPass
(
"all_reduce_deps_pass"
);
}
}
if
(
strategy_
.
remove_unnecessary_lock_
)
{
if
(
strategy_
.
remove_unnecessary_lock_
)
{
VLOG
(
10
)
<<
"Add modify_op_lock_and_record_event_pass"
;
VLOG
(
5
)
<<
"Add modify_op_lock_and_record_event_pass"
;
AppendPass
(
"modify_op_lock_and_record_event_pass"
);
AppendPass
(
"modify_op_lock_and_record_event_pass"
);
}
}
...
@@ -202,16 +204,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -202,16 +204,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if
(
strategy_
.
async_mode_
)
{
if
(
strategy_
.
async_mode_
)
{
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
}
else
if
(
strategy_
.
is_distribution_
)
{
}
else
if
(
strategy_
.
is_distribution_
)
{
VLOG
(
10
)
VLOG
(
5
)
<<
"Add dist_multi_devices_pass, multi device parameter server mode"
;
<<
"Add dist_multi_devices_pass, multi device parameter server mode"
;
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
}
else
{
}
else
{
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
VLOG
(
10
)
<<
"Add all_reduce_mode_multi_devices_pass"
;
VLOG
(
5
)
<<
"Add all_reduce_mode_multi_devices_pass"
;
multi_devices_pass
=
multi_devices_pass
=
AppendPass
(
"all_reduce_mode_multi_devices_pass"
).
get
();
AppendPass
(
"all_reduce_mode_multi_devices_pass"
).
get
();
}
else
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
}
else
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
VLOG
(
10
)
<<
"Add reduce_mode_multi_devices_pass"
;
VLOG
(
5
)
<<
"Add reduce_mode_multi_devices_pass"
;
multi_devices_pass
=
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
}
else
{
}
else
{
PADDLE_THROW
(
"Unknown reduce strategy."
);
PADDLE_THROW
(
"Unknown reduce strategy."
);
...
@@ -277,6 +279,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -277,6 +279,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
}
else
if
(
pass
->
Type
()
==
"alloc_continuous_space_for_grad_pass"
||
}
else
if
(
pass
->
Type
()
==
"alloc_continuous_space_for_grad_pass"
||
pass
->
Type
()
==
"fuse_adam_op_pass"
||
pass
->
Type
()
==
"fuse_adam_op_pass"
||
pass
->
Type
()
==
"fuse_sgd_op_pass"
||
pass
->
Type
()
==
"fuse_sgd_op_pass"
||
pass
->
Type
()
==
"fuse_momentum_op_pass"
||
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
...
@@ -341,6 +344,7 @@ USE_PASS(alloc_continuous_space_for_grad_pass);
...
@@ -341,6 +344,7 @@ USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS
(
graph_to_program_pass
);
USE_PASS
(
graph_to_program_pass
);
USE_PASS
(
fuse_adam_op_pass
);
USE_PASS
(
fuse_adam_op_pass
);
USE_PASS
(
fuse_sgd_op_pass
);
USE_PASS
(
fuse_sgd_op_pass
);
USE_PASS
(
fuse_momentum_op_pass
);
USE_PASS
(
fuse_all_reduce_op_pass
);
USE_PASS
(
fuse_all_reduce_op_pass
);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
expected_kernel_cache_pass
);
USE_PASS
(
expected_kernel_cache_pass
);
...
...
paddle/fluid/framework/details/fuse_adam_op_pass.cc
浏览文件 @
a2be4b4d
...
@@ -11,9 +11,15 @@
...
@@ -11,9 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_adam_op_pass.h"
#include <algorithm>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -21,13 +27,15 @@ namespace paddle {
...
@@ -21,13 +27,15 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
const
std
::
string
FuseAdamOpPass
::
GetOpType
()
const
{
return
"adam"
;
}
class
FuseAdamOpPass
:
public
FuseOptimizerOpPass
{
private:
const
std
::
string
GetOpType
()
const
{
return
"adam"
;
}
const
std
::
vector
<
std
::
string
>
FuseAdamOpPass
::
GetAuxiliaryVarNames
()
const
{
const
std
::
vector
<
std
::
string
>
GetAuxiliaryVarNames
()
const
{
return
{
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
};
return
{
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
};
}
}
void
FuseAdamOpPass
::
FuseOptimizerOps
(
void
FuseOptimizerOps
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
aux_var_set
,
&
aux_var_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
...
@@ -37,9 +45,9 @@ void FuseAdamOpPass::FuseOptimizerOps(
...
@@ -37,9 +45,9 @@ void FuseAdamOpPass::FuseOptimizerOps(
adam_ops
,
graph
);
adam_ops
,
graph
);
FuseScaleOps
(
aux_var_set
.
at
(
"Beta2Pow"
),
fused_vars_name
.
at
(
"Beta2Pow"
),
FuseScaleOps
(
aux_var_set
.
at
(
"Beta2Pow"
),
fused_vars_name
.
at
(
"Beta2Pow"
),
adam_ops
,
graph
);
adam_ops
,
graph
);
}
}
void
FuseAdamOpPass
::
FuseAdamOps
(
void
FuseAdamOps
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
ir
::
Graph
*
graph
)
const
{
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
ir
::
Graph
*
graph
)
const
{
...
@@ -67,14 +75,15 @@ void FuseAdamOpPass::FuseAdamOps(
...
@@ -67,14 +75,15 @@ void FuseAdamOpPass::FuseAdamOps(
PADDLE_ENFORCE_EQ
(
min_row_size_to_use_multithread
,
PADDLE_ENFORCE_EQ
(
min_row_size_to_use_multithread
,
boost
::
get
<
int64_t
>
(
adam_op
->
Op
()
->
GetAttr
(
boost
::
get
<
int64_t
>
(
adam_op
->
Op
()
->
GetAttr
(
"min_row_size_to_use_multithread"
)));
"min_row_size_to_use_multithread"
)));
PADDLE_ENFORCE_EQ
(
op_role
,
boost
::
get
<
int
>
(
adam_op
->
Op
()
->
GetAttr
(
PADDLE_ENFORCE_EQ
(
op_role
,
boost
::
get
<
int
>
(
adam_op
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
())));
OpProtoAndCheckerMaker
::
OpRoleAttrName
())));
}
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// NOTE: fused_var is only exist in scope, so the graph doesn't have
//
node.
// fused_var
node.
VLOG
(
10
)
<<
"Insert adam to graph "
;
VLOG
(
7
)
<<
"Insert adam to graph "
;
OpDesc
adam_desc
(
adam_ops
[
0
]
->
Op
()
->
Block
());
OpDesc
adam_desc
(
adam_ops
[
0
]
->
Op
()
->
Block
());
adam_desc
.
SetType
(
"adam"
);
adam_desc
.
SetType
(
"adam"
);
adam_desc
.
SetInput
(
kParam
,
{
fused_vars_name
.
at
(
kParam
)});
adam_desc
.
SetInput
(
kParam
,
{
fused_vars_name
.
at
(
kParam
)});
...
@@ -100,9 +109,9 @@ void FuseAdamOpPass::FuseAdamOps(
...
@@ -100,9 +109,9 @@ void FuseAdamOpPass::FuseAdamOps(
auto
adam_node
=
graph
->
CreateOpNode
(
&
adam_desc
);
auto
adam_node
=
graph
->
CreateOpNode
(
&
adam_desc
);
InserInputAndOutputForOptOps
(
adam_ops
,
adam_node
);
InserInputAndOutputForOptOps
(
adam_ops
,
adam_node
);
}
}
void
FuseAdamOpPass
::
FuseScaleOps
(
const
std
::
vector
<
std
::
string
>
&
beta_name
,
void
FuseScaleOps
(
const
std
::
vector
<
std
::
string
>
&
beta_name
,
const
std
::
string
&
fused_var_name
,
const
std
::
string
&
fused_var_name
,
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
ir
::
Graph
*
graph
)
const
{
ir
::
Graph
*
graph
)
const
{
...
@@ -117,7 +126,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
...
@@ -117,7 +126,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
auto
beta_pow_iter
=
std
::
find_if
(
auto
beta_pow_iter
=
std
::
find_if
(
adam_ops
[
i
]
->
inputs
.
begin
(),
adam_ops
[
i
]
->
inputs
.
end
(),
adam_ops
[
i
]
->
inputs
.
begin
(),
adam_ops
[
i
]
->
inputs
.
end
(),
[
&
beta_name
,
&
beta_1_pow_name
](
ir
::
Node
*
var_node
)
->
bool
{
[
&
beta_name
,
&
beta_1_pow_name
](
ir
::
Node
*
var_node
)
->
bool
{
return
var_node
->
Var
()
&&
var_node
->
Var
()
->
Name
()
==
beta_1_pow_name
;
return
var_node
->
Var
()
&&
var_node
->
Var
()
->
Name
()
==
beta_1_pow_name
;
});
});
PADDLE_ENFORCE
(
beta_pow_iter
!=
adam_ops
[
i
]
->
inputs
.
end
());
PADDLE_ENFORCE
(
beta_pow_iter
!=
adam_ops
[
i
]
->
inputs
.
end
());
...
@@ -144,18 +154,20 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
...
@@ -144,18 +154,20 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for
(
auto
&
scale_op
:
scale_ops
)
{
for
(
auto
&
scale_op
:
scale_ops
)
{
PADDLE_ENFORCE_EQ
(
scale
,
PADDLE_ENFORCE_EQ
(
scale
,
boost
::
get
<
float
>
(
scale_op
->
Op
()
->
GetAttr
(
"scale"
)));
boost
::
get
<
float
>
(
scale_op
->
Op
()
->
GetAttr
(
"scale"
)));
PADDLE_ENFORCE_EQ
(
bias
,
boost
::
get
<
float
>
(
scale_op
->
Op
()
->
GetAttr
(
"bias"
)));
PADDLE_ENFORCE_EQ
(
bias
,
boost
::
get
<
float
>
(
scale_op
->
Op
()
->
GetAttr
(
"bias"
)));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
bias_after_scale
,
bias_after_scale
,
boost
::
get
<
bool
>
(
scale_op
->
Op
()
->
GetAttr
(
"bias_after_scale"
)));
boost
::
get
<
bool
>
(
scale_op
->
Op
()
->
GetAttr
(
"bias_after_scale"
)));
PADDLE_ENFORCE_EQ
(
op_role
,
boost
::
get
<
int
>
(
scale_op
->
Op
()
->
GetAttr
(
PADDLE_ENFORCE_EQ
(
op_role
,
boost
::
get
<
int
>
(
scale_op
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
())));
OpProtoAndCheckerMaker
::
OpRoleAttrName
())));
}
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// NOTE: fused_var is only exist in scope, so the graph doesn't have
//
node.
// fused_var
node.
VLOG
(
10
)
<<
"Insert fused scale to graph."
;
VLOG
(
7
)
<<
"Insert fused scale to graph."
;
OpDesc
scale_desc
(
scale_ops
[
0
]
->
Op
()
->
Block
());
OpDesc
scale_desc
(
scale_ops
[
0
]
->
Op
()
->
Block
());
scale_desc
.
SetType
(
"scale"
);
scale_desc
.
SetType
(
"scale"
);
scale_desc
.
SetInput
(
"X"
,
{
fused_var_name
});
scale_desc
.
SetInput
(
"X"
,
{
fused_var_name
});
...
@@ -169,7 +181,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
...
@@ -169,7 +181,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for
(
auto
scale_op
:
scale_ops
)
{
for
(
auto
scale_op
:
scale_ops
)
{
// set inputs
// set inputs
scale_node
->
inputs
.
insert
(
scale_node
->
inputs
.
begin
(),
scale_node
->
inputs
.
insert
(
scale_node
->
inputs
.
begin
(),
scale_op
->
inputs
.
begin
(),
scale_op
->
inputs
.
end
());
scale_op
->
inputs
.
begin
(),
scale_op
->
inputs
.
end
());
for
(
auto
&
input
:
scale_op
->
inputs
)
{
for
(
auto
&
input
:
scale_op
->
inputs
)
{
std
::
replace
(
input
->
outputs
.
begin
(),
input
->
outputs
.
end
(),
scale_op
,
std
::
replace
(
input
->
outputs
.
begin
(),
input
->
outputs
.
end
(),
scale_op
,
scale_node
);
scale_node
);
...
@@ -188,8 +201,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
...
@@ -188,8 +201,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for
(
auto
&
scale_op
:
scale_ops
)
{
for
(
auto
&
scale_op
:
scale_ops
)
{
graph
->
RemoveNode
(
scale_op
);
graph
->
RemoveNode
(
scale_op
);
}
}
}
}
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/details/fuse_
adam_op_pass.h
→
paddle/fluid/framework/details/fuse_
momentum_op_pass.cc
浏览文件 @
a2be4b4d
...
@@ -12,44 +12,83 @@
...
@@ -12,44 +12,83 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/
details/multi_devices
_helper.h"
#include "paddle/fluid/framework/
ir/graph
_helper.h"
#include "paddle/fluid/framework/
ir/graph
.h"
#include "paddle/fluid/framework/
op_registry
.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
class
Fuse
Ada
mOpPass
:
public
FuseOptimizerOpPass
{
class
Fuse
Momentu
mOpPass
:
public
FuseOptimizerOpPass
{
private:
private:
virtual
const
std
::
string
GetOpType
()
const
;
virtual
const
std
::
string
GetOpType
()
const
{
return
"momentum"
;
}
virtual
const
std
::
vector
<
std
::
string
>
GetAuxiliaryVarNames
()
const
;
virtual
const
std
::
vector
<
std
::
string
>
GetAuxiliaryVarNames
()
const
{
return
{
"Velocity"
};
}
// Fuse
Adam Ops and Scale Ops which are used to update "Beta1Pow", "Beta2Pow"
// Fuse
Momentum Ops
virtual
void
FuseOptimizerOps
(
virtual
void
FuseOptimizerOps
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
ir
::
Graph
*
graph
)
const
;
const
std
::
vector
<
ir
::
Node
*>
&
momentum_ops
,
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_GT
(
momentum_ops
.
size
(),
static_cast
<
size_t
>
(
0
));
void
FuseAdamOps
(
// Check attributions
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
// NOTE: If new attribution is added, the following code maybe need change.
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
int
op_role
=
boost
::
get
<
int
>
(
momentum_ops
[
0
]
->
Op
()
->
GetAttr
(
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
ir
::
Graph
*
graph
)
const
;
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
float
mu
=
boost
::
get
<
float
>
(
momentum_ops
[
0
]
->
Op
()
->
GetAttr
(
"mu"
));
bool
use_nesterov
=
boost
::
get
<
bool
>
(
momentum_ops
[
0
]
->
Op
()
->
GetAttr
(
"use_nesterov"
));
for
(
auto
&
momentum_op
:
momentum_ops
)
{
PADDLE_ENFORCE_EQ
(
mu
,
boost
::
get
<
float
>
(
momentum_op
->
Op
()
->
GetAttr
(
"mu"
)));
PADDLE_ENFORCE_EQ
(
use_nesterov
,
boost
::
get
<
bool
>
(
momentum_op
->
Op
()
->
GetAttr
(
"use_nesterov"
)));
PADDLE_ENFORCE_EQ
(
op_role
,
boost
::
get
<
int
>
(
momentum_op
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node.
void
FuseScaleOps
(
const
std
::
vector
<
std
::
string
>
&
aux_var_set
,
VLOG
(
7
)
<<
"Insert momentum to graph "
;
const
std
::
string
&
fused_var_name
,
OpDesc
momentum_desc
(
momentum_ops
[
0
]
->
Op
()
->
Block
());
const
std
::
vector
<
ir
::
Node
*>
&
adam_ops
,
momentum_desc
.
SetType
(
"momentum"
);
ir
::
Graph
*
graph
)
const
;
momentum_desc
.
SetInput
(
kParam
,
{
fused_vars_name
.
at
(
kParam
)});
momentum_desc
.
SetInput
(
kGrad
,
{
fused_vars_name
.
at
(
kGrad
)});
momentum_desc
.
SetInput
(
"Velocity"
,
{
fused_vars_name
.
at
(
"Velocity"
)});
// TODO(zcd): The LearningRate should be equal.
momentum_desc
.
SetInput
(
kLearningRate
,
momentum_ops
[
0
]
->
Op
()
->
Input
(
kLearningRate
));
momentum_desc
.
SetOutput
(
"ParamOut"
,
{
fused_vars_name
.
at
(
kParam
)});
momentum_desc
.
SetOutput
(
"VelocityOut"
,
{
fused_vars_name
.
at
(
"Velocity"
)});
momentum_desc
.
SetAttr
(
"mu"
,
mu
);
momentum_desc
.
SetAttr
(
"use_nesterov"
,
use_nesterov
);
momentum_desc
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
op_role
);
auto
momentum_node
=
graph
->
CreateOpNode
(
&
momentum_desc
);
InserInputAndOutputForOptOps
(
momentum_ops
,
momentum_node
);
}
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_momentum_op_pass
,
paddle
::
framework
::
details
::
FuseMomentumOpPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/details/fuse_optimizer_op_pass.cc
浏览文件 @
a2be4b4d
...
@@ -42,14 +42,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
...
@@ -42,14 +42,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
&
aux_var_set
);
&
aux_var_set
);
}
}
VLOG
(
10
)
<<
"Find "
<<
fuse_op_type
<<
" operators: "
<<
opt_ops
.
size
();
VLOG
(
6
)
<<
"Find "
<<
fuse_op_type
<<
" operators: "
<<
opt_ops
.
size
();
if
(
opt_ops
.
size
()
==
0
)
{
if
(
opt_ops
.
size
()
==
0
)
{
return
;
return
;
}
}
if
(
result
.
Has
(
kFusedOptType
))
{
if
(
result
.
Has
(
kFusedOptType
))
{
VLOG
(
10
)
VLOG
(
6
)
<<
"Currently only support fusing one type optimizer op. Has fused "
<<
"Currently only support fusing one type optimizer op. Has fused "
<<
result
.
Get
<
FusedOptType
>
(
kFusedOptType
);
<<
result
.
Get
<
FusedOptType
>
(
kFusedOptType
);
return
;
return
;
}
else
{
}
else
{
...
@@ -70,7 +69,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
...
@@ -70,7 +69,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
for
(
auto
&
var_name
:
aux_var_names
)
{
for
(
auto
&
var_name
:
aux_var_names
)
{
auto
fused_var_name
=
prefix
+
"_"
+
fuse_op_type
+
"_"
+
var_name
+
"_"
+
auto
fused_var_name
=
prefix
+
"_"
+
fuse_op_type
+
"_"
+
var_name
+
"_"
+
aux_var_set
[
var_name
][
0
];
aux_var_set
[
var_name
][
0
];
VLOG
(
10
)
<<
fused_var_name
;
VLOG
(
6
)
<<
var_name
<<
": "
<<
fused_var_name
;
fused_vars_name
.
emplace
(
var_name
,
fused_var_name
);
fused_vars_name
.
emplace
(
var_name
,
fused_var_name
);
PADDLE_ENFORCE_EQ
(
fused_var_set
.
count
(
fused_var_name
),
0
);
PADDLE_ENFORCE_EQ
(
fused_var_set
.
count
(
fused_var_name
),
0
);
fused_var_set
.
insert
(
fused_var_name
);
fused_var_set
.
insert
(
fused_var_name
);
...
@@ -151,7 +150,7 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
...
@@ -151,7 +150,7 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
// Init Grads
// Init Grads
for
(
auto
it
=
local_scopes
.
rbegin
();
it
!=
local_scopes
.
rend
();
++
it
)
{
for
(
auto
it
=
local_scopes
.
rbegin
();
it
!=
local_scopes
.
rend
();
++
it
)
{
auto
&
scope
=
*
it
;
auto
&
scope
=
*
it
;
VLOG
(
10
)
<<
"Init
"
<<
fused_grad_name
;
VLOG
(
6
)
<<
"Init:
"
<<
fused_grad_name
;
PADDLE_ENFORCE
(
scope
->
FindVar
(
fused_grad_name
)
==
nullptr
,
PADDLE_ENFORCE
(
scope
->
FindVar
(
fused_grad_name
)
==
nullptr
,
"%s has existed in scope."
,
fused_grad_name
);
"%s has existed in scope."
,
fused_grad_name
);
scope
->
Var
(
fused_grad_name
)
->
GetMutable
<
LoDTensor
>
();
scope
->
Var
(
fused_grad_name
)
->
GetMutable
<
LoDTensor
>
();
...
@@ -211,13 +210,12 @@ void FuseOptimizerOpPass::RunInitOps(const std::vector<platform::Place> &places,
...
@@ -211,13 +210,12 @@ void FuseOptimizerOpPass::RunInitOps(const std::vector<platform::Place> &places,
void
FuseOptimizerOpPass
::
InitVars
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
void
FuseOptimizerOpPass
::
InitVars
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
string
&
fused_var_name
)
const
{
const
std
::
string
&
fused_var_name
)
const
{
VLOG
(
10
)
<<
"Init FusedVars."
;
// Alloc parameters and auxiliary vars in the respective scope.
// Alloc parameters and auxiliary vars in the respective scope.
size_t
idx
=
local_scopes
.
size
();
size_t
idx
=
local_scopes
.
size
();
for
(
auto
iter
=
local_scopes
.
rbegin
();
iter
!=
local_scopes
.
rend
();
for
(
auto
iter
=
local_scopes
.
rbegin
();
iter
!=
local_scopes
.
rend
();
++
iter
,
--
idx
)
{
++
iter
,
--
idx
)
{
auto
&
scope
=
*
iter
;
auto
&
scope
=
*
iter
;
VLOG
(
10
)
<<
"Init
"
<<
fused_var_name
;
VLOG
(
6
)
<<
"Init:
"
<<
fused_var_name
;
PADDLE_ENFORCE
(
scope
->
FindVar
(
fused_var_name
)
==
nullptr
,
PADDLE_ENFORCE
(
scope
->
FindVar
(
fused_var_name
)
==
nullptr
,
"%s has exist in scope[%d]"
,
fused_var_name
,
idx
);
"%s has exist in scope[%d]"
,
fused_var_name
,
idx
);
scope
->
Var
(
fused_var_name
)
->
GetMutable
<
LoDTensor
>
();
scope
->
Var
(
fused_var_name
)
->
GetMutable
<
LoDTensor
>
();
...
@@ -253,7 +251,7 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
...
@@ -253,7 +251,7 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
for
(
auto
&
var_name
:
aux_vars
.
second
)
{
for
(
auto
&
var_name
:
aux_vars
.
second
)
{
out
<<
var_name
<<
" "
;
out
<<
var_name
<<
" "
;
}
}
VLOG
(
10
)
<<
aux_vars
.
first
<<
": "
<<
out
.
str
();
VLOG
(
6
)
<<
aux_vars
.
first
<<
": "
<<
out
.
str
();
}
}
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
...
@@ -271,12 +269,14 @@ void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
...
@@ -271,12 +269,14 @@ void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
const
{
const
{
if
(
node
->
Op
()
->
Type
()
!=
op_type
)
return
;
if
(
node
->
Op
()
->
Type
()
!=
op_type
)
return
;
std
::
stringstream
out
;
for
(
auto
&
var_n
:
aux_vars_name
)
{
for
(
auto
&
var_n
:
aux_vars_name
)
{
auto
arg_names
=
node
->
Op
()
->
Input
(
var_n
);
auto
arg_names
=
node
->
Op
()
->
Input
(
var_n
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
static_cast
<
size_t
>
(
1
));
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
static_cast
<
size_t
>
(
1
));
(
*
aux_args_name
)[
var_n
].
emplace_back
(
arg_names
[
0
]);
(
*
aux_args_name
)[
var_n
].
emplace_back
(
arg_names
[
0
]);
VLOG
(
10
)
<<
var_n
<<
", "
<<
arg_names
[
0
]
;
out
<<
var_n
<<
", "
<<
arg_names
[
0
]
<<
"; "
;
}
}
VLOG
(
7
)
<<
out
.
str
();
ops
->
emplace_back
(
node
);
ops
->
emplace_back
(
node
);
}
}
...
...
paddle/fluid/framework/details/fuse_sgd_op_pass.cc
浏览文件 @
a2be4b4d
...
@@ -11,42 +11,43 @@
...
@@ -11,42 +11,43 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_sgd_op_pass.h"
#include <algorithm>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
const
std
::
string
FuseSgdOpPass
::
GetOpType
()
const
{
return
"sgd"
;
}
class
FuseSgdOpPass
:
public
FuseOptimizerOpPass
{
private:
virtual
const
std
::
string
GetOpType
()
const
{
return
"sgd"
;
}
const
std
::
vector
<
std
::
string
>
FuseSgdOpPass
::
GetAuxiliaryVarNames
()
const
{
virtual
const
std
::
vector
<
std
::
string
>
GetAuxiliaryVarNames
()
const
{
return
{};
return
{};
}
}
void
FuseSgdOpPass
::
FuseOptimizerOps
(
// Fuse Sgd Ops
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
virtual
void
FuseOptimizerOps
(
&
aux_var_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
vector
<
ir
::
Node
*>
&
sgd_ops
,
ir
::
Graph
*
graph
)
const
{
FuseSgdOps
(
aux_var_set
,
fused_vars_name
,
sgd_ops
,
graph
);
}
void
FuseSgdOpPass
::
FuseSgdOps
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
vector
<
ir
::
Node
*>
&
sgd_ops
,
ir
::
Graph
*
graph
)
const
{
const
std
::
vector
<
ir
::
Node
*>
&
sgd_ops
,
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_GT
(
sgd_ops
.
size
(),
static_cast
<
size_t
>
(
0
));
PADDLE_ENFORCE_GT
(
sgd_ops
.
size
(),
static_cast
<
size_t
>
(
0
));
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// NOTE: fused_var is only exist in scope, so the graph doesn't have
//
node.
// fused_var
node.
int
op_role
=
boost
::
get
<
int
>
(
int
op_role
=
boost
::
get
<
int
>
(
sgd_ops
[
0
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
sgd_ops
[
0
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
VLOG
(
10
)
<<
"Insert sgd to graph "
;
VLOG
(
7
)
<<
"Insert sgd to graph "
;
// Add fused scale
// Add fused scale
OpDesc
Sgd_desc
(
sgd_ops
[
0
]
->
Op
()
->
Block
());
OpDesc
Sgd_desc
(
sgd_ops
[
0
]
->
Op
()
->
Block
());
Sgd_desc
.
SetType
(
"sgd"
);
Sgd_desc
.
SetType
(
"sgd"
);
...
@@ -54,7 +55,7 @@ void FuseSgdOpPass::FuseSgdOps(
...
@@ -54,7 +55,7 @@ void FuseSgdOpPass::FuseSgdOps(
Sgd_desc
.
SetInput
(
kGrad
,
{
fused_vars_name
.
at
(
kGrad
)});
Sgd_desc
.
SetInput
(
kGrad
,
{
fused_vars_name
.
at
(
kGrad
)});
Sgd_desc
.
SetOutput
(
"ParamOut"
,
{
fused_vars_name
.
at
(
kParam
)});
Sgd_desc
.
SetOutput
(
"ParamOut"
,
{
fused_vars_name
.
at
(
kParam
)});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow
should be equal.
// TODO(zcd): The LearningRate
should be equal.
Sgd_desc
.
SetInput
(
kLearningRate
,
sgd_ops
[
0
]
->
Op
()
->
Input
(
kLearningRate
));
Sgd_desc
.
SetInput
(
kLearningRate
,
sgd_ops
[
0
]
->
Op
()
->
Input
(
kLearningRate
));
// NOTE: multi_devices_pass requires that every op should have a role.
// NOTE: multi_devices_pass requires that every op should have a role.
...
@@ -63,8 +64,8 @@ void FuseSgdOpPass::FuseSgdOps(
...
@@ -63,8 +64,8 @@ void FuseSgdOpPass::FuseSgdOps(
auto
sgd_node
=
graph
->
CreateOpNode
(
&
Sgd_desc
);
auto
sgd_node
=
graph
->
CreateOpNode
(
&
Sgd_desc
);
InserInputAndOutputForOptOps
(
sgd_ops
,
sgd_node
);
InserInputAndOutputForOptOps
(
sgd_ops
,
sgd_node
);
}
}
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/details/fuse_sgd_op_pass.h
已删除
100644 → 0
浏览文件 @
03d469ad
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
FuseSgdOpPass
:
public
FuseOptimizerOpPass
{
private:
virtual
const
std
::
string
GetOpType
()
const
;
virtual
const
std
::
vector
<
std
::
string
>
GetAuxiliaryVarNames
()
const
;
// Fuse Sgd Ops
virtual
void
FuseOptimizerOps
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
vector
<
ir
::
Node
*>
&
sgd_ops
,
ir
::
Graph
*
graph
)
const
;
void
FuseSgdOps
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
vars_set
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
fused_vars_name
,
const
std
::
vector
<
ir
::
Node
*>
&
sgd_ops
,
ir
::
Graph
*
graph
)
const
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_fuse_optimizer_pass.py
浏览文件 @
a2be4b4d
...
@@ -31,18 +31,17 @@ class TestFuseAdamOps(TestParallelExecutorBase):
...
@@ -31,18 +31,17 @@ class TestFuseAdamOps(TestParallelExecutorBase):
if
use_cuda
and
not
core
.
is_compiled_with_cuda
():
if
use_cuda
and
not
core
.
is_compiled_with_cuda
():
return
return
img
,
label
=
init_data
()
img
,
label
=
init_data
()
feed_dict
=
{
"image"
:
img
,
"label"
:
label
}
not_fuse_op_first_loss
,
not_fuse_op_last_loss
=
self
.
check_network_convergence
(
not_fuse_op_first_loss
,
not_fuse_op_last_loss
=
self
.
check_network_convergence
(
model
,
model
,
feed_dict
=
{
"image"
:
img
,
feed_dict
=
feed_dict
,
"label"
:
label
},
use_cuda
=
use_cuda
,
use_cuda
=
use_cuda
,
fuse_all_optimizer_ops
=
False
,
fuse_all_optimizer_ops
=
False
,
memory_opt
=
False
,
# avoid the gradient's name changed in Python side.
memory_opt
=
False
,
# avoid the gradient's name changed in Python side.
optimizer
=
optimizer
)
optimizer
=
optimizer
)
fuse_op_first_loss
,
fuse_op_last_loss
=
self
.
check_network_convergence
(
fuse_op_first_loss
,
fuse_op_last_loss
=
self
.
check_network_convergence
(
model
,
model
,
feed_dict
=
{
"image"
:
img
,
feed_dict
=
feed_dict
,
"label"
:
label
},
use_cuda
=
use_cuda
,
use_cuda
=
use_cuda
,
fuse_all_optimizer_ops
=
True
,
fuse_all_optimizer_ops
=
True
,
memory_opt
=
False
,
# avoid the gradient's name changed in Python side.
memory_opt
=
False
,
# avoid the gradient's name changed in Python side.
...
@@ -63,7 +62,7 @@ class TestFuseAdamOps(TestParallelExecutorBase):
...
@@ -63,7 +62,7 @@ class TestFuseAdamOps(TestParallelExecutorBase):
class
TestFuseSGDOps
(
TestFuseAdamOps
):
class
TestFuseSGDOps
(
TestFuseAdamOps
):
def
sgd_optimizer
(
self
,
learning_rate
=
1e-
4
):
def
sgd_optimizer
(
self
,
learning_rate
=
1e-
3
):
return
fluid
.
optimizer
.
SGD
(
learning_rate
=
learning_rate
)
return
fluid
.
optimizer
.
SGD
(
learning_rate
=
learning_rate
)
def
test_simple_fc_with_fuse_op
(
self
):
def
test_simple_fc_with_fuse_op
(
self
):
...
@@ -79,5 +78,23 @@ class TestFuseSGDOps(TestFuseAdamOps):
...
@@ -79,5 +78,23 @@ class TestFuseSGDOps(TestFuseAdamOps):
fc_with_batchnorm
,
False
,
optimizer
=
self
.
sgd_optimizer
)
fc_with_batchnorm
,
False
,
optimizer
=
self
.
sgd_optimizer
)
class
TestFuseMomentumOps
(
TestFuseAdamOps
):
def
momentum_optimizer
(
self
,
learning_rate
=
1e-3
):
return
fluid
.
optimizer
.
Momentum
(
learning_rate
=
learning_rate
,
momentum
=
0.1
)
def
test_simple_fc_with_fuse_op
(
self
):
self
.
_compare_fused_optimizer_ops
(
simple_fc_net
,
True
,
optimizer
=
self
.
momentum_optimizer
)
self
.
_compare_fused_optimizer_ops
(
simple_fc_net
,
False
,
optimizer
=
self
.
momentum_optimizer
)
def
test_batchnorm_fc_with_fuse_op
(
self
):
self
.
_compare_fused_optimizer_ops
(
fc_with_batchnorm
,
True
,
optimizer
=
self
.
momentum_optimizer
)
self
.
_compare_fused_optimizer_ops
(
fc_with_batchnorm
,
False
,
optimizer
=
self
.
momentum_optimizer
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录