Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
781d2844
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
781d2844
编写于
11月 01, 2019
作者:
1
123malin
提交者:
GitHub
11月 01, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize decay (#20816) (#20952)
* update pserver decay blocks * update distributed notify handler
上级
55c2329a
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
399 addition
and
194 deletion
+399
-194
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+9
-1
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+15
-3
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+16
-19
paddle/fluid/operators/distributed/communicator_test.cc
paddle/fluid/operators/distributed/communicator_test.cc
+2
-2
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+28
-14
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+2
-15
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+9
-11
paddle/fluid/operators/distributed/parameter_send.cc
paddle/fluid/operators/distributed/parameter_send.cc
+36
-16
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+1
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+16
-2
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+2
-1
paddle/fluid/operators/distributed/rpc_common.h
paddle/fluid/operators/distributed/rpc_common.h
+12
-2
paddle/fluid/operators/distributed_ops/distributed_notify_op.cc
.../fluid/operators/distributed_ops/distributed_notify_op.cc
+0
-84
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+2
-1
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+31
-7
python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py
...fluid/tests/unittests/test_dist_transpiler_async_decay.py
+143
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+75
-15
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
781d2844
...
...
@@ -62,8 +62,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
auto
merge_add
=
boost
::
get
<
bool
>
(
node
->
Op
()
->
GetNullableAttr
(
"merge_add"
));
if
(
!
merge_add
)
{
merge_add
=
FLAGS_communicator_is_sgd_optimizer
;
}
auto
use_send_handler
=
boost
::
get
<
bool
>
(
node
->
Op
()
->
GetNullableAttr
(
"use_send_handler"
));
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
);
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
,
merge_add
,
use_send_handler
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
node
->
Name
()
==
"recv"
)
{
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
781d2844
...
...
@@ -130,8 +130,15 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
auto
height_section
=
boost
::
get
<
std
::
vector
<
int64_t
>>
(
op
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
auto
merge_add
=
boost
::
get
<
bool
>
(
op
->
GetNullableAttr
(
"merge_add"
));
if
(
!
merge_add
)
{
merge_add
=
FLAGS_communicator_is_sgd_optimizer
;
}
auto
use_send_handler
=
boost
::
get
<
bool
>
(
op
->
GetNullableAttr
(
"use_send_handler"
));
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
);
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
,
merge_add
,
use_send_handler
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
op
->
Type
()
==
"recv"
)
{
...
...
@@ -208,12 +215,17 @@ void AsyncCommunicator::SendThread() {
}
}
auto
before_merge
=
GetCurrentUS
();
MergeVars
(
var_name
,
vars
,
send_scope_
.
get
());
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
if
(
ctx
.
use_send_handler
)
{
MergeVars
<
float
>
(
var_name
,
vars
,
send_scope_
.
get
(),
ctx
.
merge_add
);
}
else
{
MergeVars
<
int64_t
>
(
var_name
,
vars
,
send_scope_
.
get
(),
ctx
.
merge_add
);
}
auto
after_merge
=
GetCurrentUS
();
VLOG
(
3
)
<<
"merge "
<<
merged_var_num
<<
" "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
if
(
!
FLAGS_communicator_fake_rpc
)
{
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
}
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
781d2844
...
...
@@ -107,21 +107,21 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
inline
void
MergeVars
(
const
std
::
string
&
var_name
,
const
std
::
vector
<
std
::
shared_ptr
<
Variable
>>&
vars
,
Scope
*
scope
)
{
Scope
*
scope
,
bool
merge_add
=
true
)
{
PADDLE_ENFORCE
(
!
vars
.
empty
(),
"should have value to merge!"
);
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
var0
=
vars
[
0
];
auto
*
out_var
=
scope
->
Var
(
var_name
);
if
(
var0
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
dims
=
var0
->
Get
<
framework
::
LoDTensor
>
().
dims
();
VLOG
(
3
)
<<
"merge "
<<
var_name
<<
" LoDTensor dims "
<<
dims
;
VLOG
(
3
)
<<
"merge "
<<
var_name
<<
" LoDTensor dims "
<<
dims
<<
"; merge add: "
<<
merge_add
;
// init output tensor
auto
*
out_t
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_t
->
mutable_data
<
float
>
(
dims
,
cpu_place
);
out_t
->
mutable_data
<
T
>
(
dims
,
cpu_place
);
// check the input dims
for
(
auto
&
var
:
vars
)
{
auto
&
var_t
=
var
->
Get
<
framework
::
LoDTensor
>
();
...
...
@@ -130,44 +130,41 @@ inline void MergeVars(const std::string& var_name,
// set output tensor to 0.
auto
cpu_ctx
=
paddle
::
platform
::
CPUDeviceContext
();
math
::
SetConstant
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
constant_functor
;
constant_functor
(
cpu_ctx
,
out_t
,
static_cast
<
float
>
(
0
));
math
::
SetConstant
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
constant_functor
;
constant_functor
(
cpu_ctx
,
out_t
,
static_cast
<
T
>
(
0
));
// sum all vars to out
auto
result
=
EigenVector
<
float
>::
Flatten
(
*
out_t
);
auto
result
=
EigenVector
<
T
>::
Flatten
(
*
out_t
);
for
(
auto
&
var
:
vars
)
{
auto
&
in_t
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
in
=
EigenVector
<
float
>::
Flatten
(
in_t
);
auto
in
=
EigenVector
<
T
>::
Flatten
(
in_t
);
result
.
device
(
*
cpu_ctx
.
eigen_device
())
=
result
+
in
;
}
if
(
!
FLAGS_communicator_is_sgd_optimizer
)
{
if
(
!
merge_add
)
{
result
.
device
(
*
cpu_ctx
.
eigen_device
())
=
result
/
static_cast
<
float
>
(
vars
.
size
());
result
/
static_cast
<
T
>
(
vars
.
size
());
}
}
else
if
(
var0
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
slr0
=
var0
->
Get
<
framework
::
SelectedRows
>
();
auto
*
out_slr
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
out_slr
->
mutable_rows
()
->
clear
();
out_slr
->
mutable_value
()
->
mutable_data
<
float
>
({{}},
cpu_place
);
out_slr
->
mutable_value
()
->
mutable_data
<
T
>
({{}},
cpu_place
);
std
::
vector
<
const
paddle
::
framework
::
SelectedRows
*>
inputs
;
inputs
.
reserve
(
vars
.
size
());
for
(
auto
&
var
:
vars
)
{
inputs
.
push_back
(
&
var
->
Get
<
framework
::
SelectedRows
>
());
}
auto
dev_ctx
=
paddle
::
platform
::
CPUDeviceContext
();
if
(
FLAGS_communicator_is_sgd_optimizer
)
{
math
::
scatter
::
MergeAdd
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
merge_add
;
if
(
merge_add
)
{
math
::
scatter
::
MergeAdd
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
merge_add
;
merge_add
(
dev_ctx
,
inputs
,
out_slr
);
}
else
{
math
::
scatter
::
MergeAverage
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
math
::
scatter
::
MergeAverage
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
merge_average
;
merge_average
(
dev_ctx
,
inputs
,
out_slr
);
}
VLOG
(
3
)
<<
"merge "
<<
var_name
<<
" SelectedRows height: "
<<
slr0
.
height
()
<<
" dims: "
<<
slr0
.
value
().
dims
();
<<
" dims: "
<<
slr0
.
value
().
dims
()
<<
"; merge add: "
<<
merge_add
;
}
else
{
PADDLE_THROW
(
"unsupported var type!"
);
}
...
...
paddle/fluid/operators/distributed/communicator_test.cc
浏览文件 @
781d2844
...
...
@@ -47,7 +47,7 @@ TEST(communicator, merge_lod_tensors) {
scope
.
reset
(
new
framework
::
Scope
());
scope
->
Var
(
out_name
);
for
(
auto
i
=
0
;
i
<
10
;
++
i
)
{
MergeVars
(
out_name
,
in_vars
,
scope
.
get
());
MergeVars
<
float
>
(
out_name
,
in_vars
,
scope
.
get
());
}
auto
&
out_tensor
=
scope
->
FindVar
(
out_name
)
->
Get
<
LoDTensor
>
();
auto
*
out_data
=
out_tensor
.
data
<
float
>
();
...
...
@@ -86,7 +86,7 @@ TEST(communicator, merge_selected_rows) {
scope
.
reset
(
new
framework
::
Scope
());
scope
->
Var
(
out_name
);
for
(
auto
i
=
0
;
i
<
10
;
++
i
)
{
MergeVars
(
out_name
,
in_vars
,
scope
.
get
());
MergeVars
<
float
>
(
out_name
,
in_vars
,
scope
.
get
());
}
auto
&
out_slr
=
scope
->
FindVar
(
out_name
)
->
Get
<
SelectedRows
>
();
auto
&
out_t
=
out_slr
.
value
();
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
781d2844
...
...
@@ -438,26 +438,40 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
return
h
;
}
VarHandlePtr
GRPCClient
::
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
type
,
VarHandlePtr
GRPCClient
::
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
DistributeNotifyProcessor
*
s
=
new
DistributeNotifyProcessor
(
ch
);
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
kRequestNotify
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
LEARNING_RATE_DECAY_MESSAGE
,
nullptr
,
nullptr
));
SendProcessor
*
s
=
new
SendProcessor
(
ch
);
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
var_name_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
type
);
framework
::
AsyncIO
([
var_name_val
,
p_scope
,
p_ctx
,
s
,
method
,
h
,
this
]
{
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
);
::
grpc
::
ByteBuffer
req
;
SerializeToByteBuffer
(
var_name_val
,
var
,
*
p_ctx
,
&
req
,
""
,
trainer_id_
);
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
// stub context
s
->
response_call_back_
=
nullptr
;
platform
::
RecordRPCEvent
record_event
(
method
);
auto
rpc
=
s
->
stub_
->
AsyncDistributeNotify
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/DistributeNotify"
,
req
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
});
req_count_
++
;
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
781d2844
...
...
@@ -173,20 +173,6 @@ class CheckpointNotifyProcessor : public BaseProcessor {
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
class
DistributeNotifyProcessor
:
public
BaseProcessor
{
public:
explicit
DistributeNotifyProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
()
{
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
}
virtual
~
DistributeNotifyProcessor
()
{}
void
ProcessImpl
()
override
{}
sendrecv
::
VoidMessage
reply_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
class
GRPCClient
:
public
RPCClient
{
public:
GRPCClient
()
:
ok_
(
true
),
completed_
(
false
),
stopped_
(
false
)
{}
...
...
@@ -240,7 +226,8 @@ class GRPCClient : public RPCClient {
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
type
,
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendComplete
(
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
781d2844
...
...
@@ -400,33 +400,31 @@ class RequestNotify final : public RequestBase {
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
()));
request_handler
->
dev_ctx
(),
!
request_handler
->
sync_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kRequestNotify
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestNotify
()
{}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
auto
scope
=
request_
->
GetMutableLocalScope
();
std
::
string
varname
=
GetReqName
();
VLOG
(
4
)
<<
"RequestNotify var_name:"
<<
varname
;
std
::
string
varname
=
request_
->
Varname
();
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
request_
->
GetVar
();
int
trainer_id
=
request_
->
GetTrainerId
();
VLOG
(
4
)
<<
"RequestNotify notify: "
<<
varname
<<
", trainer id: "
<<
trainer_id
;
request_handler_
->
Handle
(
varname
,
scope
,
nullptr
,
nullptr
,
trainer_id
);
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
Finish
(
reply_
,
&
responder_
);
}
protected:
std
::
shared_ptr
<
GRPCVariableResponse
>
request_
;
sendrecv
::
VoidMessage
reply_
;
std
::
shared_ptr
<
GRPCVariableResponse
>
request_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
...
...
paddle/fluid/operators/distributed/parameter_send.cc
浏览文件 @
781d2844
...
...
@@ -116,7 +116,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
row_offset
+=
outs_dims
[
i
][
0
];
}
}
if
(
rpc_ctx
.
use_send_handler
)
{
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
auto
&
send_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
VLOG
(
4
)
<<
"send var name: "
<<
send_var_name
;
...
...
@@ -133,7 +133,27 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
<<
rpc_ctx
.
splited_var_names
[
i
];
}
}
}
else
{
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
for
(
size_t
j
=
0
;
j
<
rpc_ctx
.
epmap
.
size
();
j
++
)
{
auto
&
send_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
VLOG
(
4
)
<<
"send var name: "
<<
send_var_name
;
auto
&
endpoint
=
rpc_ctx
.
epmap
[
j
];
VLOG
(
4
)
<<
"send var endpoint: "
<<
endpoint
;
VLOG
(
4
)
<<
"need send: "
<<
NeedSend
(
*
local_scope
.
get
(),
send_var_name
);
if
(
NeedSend
(
*
local_scope
.
get
(),
send_var_name
))
{
VLOG
(
3
)
<<
"sending "
<<
send_var_name
<<
" to "
<<
endpoint
;
rets
.
push_back
(
rpc_client
->
AsyncDistributeNotify
(
endpoint
,
cpu_ctx
,
*
local_scope
.
get
(),
send_var_name
));
VLOG
(
4
)
<<
"send var "
<<
send_var_name
<<
" async handle done"
;
}
else
{
VLOG
(
3
)
<<
"don't send non-initialized variable: "
<<
rpc_ctx
.
splited_var_names
[
i
];
}
}
}
}
}
else
if
(
send_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
send_slr
=
send_var
->
Get
<
framework
::
SelectedRows
>
();
auto
abs_sections
=
ToAbsoluteSection
(
rpc_ctx
.
height_sections
);
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
781d2844
...
...
@@ -63,7 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_
MESSAGE "LRDECAY@RECV
"
#define LEARNING_RATE_DECAY_
COUNTER "@LR_DECAY_COUNTER@
"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
781d2844
...
...
@@ -262,11 +262,25 @@ bool RequestNotifyHandler::Handle(const std::string& varname,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestNotifyHandler"
<<
varname
;
if
(
varname
==
LEARNING_RATE_DECAY_MESSAGE
)
{
VLOG
(
4
)
<<
"RequestNotifyHandler: "
<<
varname
;
VLOG
(
3
)
<<
"async process var: "
<<
varname
<<
", trainer_id: "
<<
trainer_id
;
string
::
Piece
decay_piece
(
LEARNING_RATE_DECAY_COUNTER
);
string
::
Piece
var_name_piece
=
string
::
Piece
(
varname
);
if
(
string
::
Contains
(
var_name_piece
,
decay_piece
))
{
VLOG
(
3
)
<<
"LearningRate Decay Counter Update"
;
PADDLE_ENFORCE_NE
(
lr_decay_block_id
,
-
1
,
"when lr_decay_block_id = -1, there should be no RPC invoke."
);
auto
*
origin_var
=
scope_
->
FindVar
(
varname
);
auto
origin_var_tensor
=
origin_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
send_var
=
scope
->
FindVar
(
varname
);
auto
send_var_tensor
=
send_var
->
Get
<
framework
::
LoDTensor
>
();
int64_t
*
origin_value
=
origin_var_tensor
.
mutable_data
<
int64_t
>
(
origin_var_tensor
.
place
());
int64_t
*
send_value
=
send_var_tensor
.
mutable_data
<
int64_t
>
(
send_var_tensor
.
place
());
origin_value
[
0
]
+=
send_value
[
0
];
executor_
->
RunPreparedContext
(
lr_decay_prepared_ctx_
.
get
(),
scope_
);
}
return
true
;
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
781d2844
...
...
@@ -81,7 +81,8 @@ class RPCClient {
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
type
,
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncSendComplete
(
...
...
paddle/fluid/operators/distributed/rpc_common.h
浏览文件 @
781d2844
...
...
@@ -27,12 +27,15 @@ struct RpcContext {
RpcContext
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
std
::
string
>
&
emap
,
const
std
::
vector
<
int64_t
>
&
sections
,
int
id
)
const
std
::
vector
<
int64_t
>
&
sections
,
int
id
,
bool
merge_add_
=
true
,
bool
use_send_handler_
=
true
)
:
var_name
(
name
),
splited_var_names
(
names
),
epmap
(
emap
),
height_sections
(
sections
),
trainer_id
(
id
)
{}
trainer_id
(
id
),
merge_add
(
merge_add_
),
use_send_handler
(
use_send_handler_
)
{}
RpcContext
(
const
RpcContext
&
ctx
)
{
var_name
=
ctx
.
var_name
;
...
...
@@ -40,6 +43,8 @@ struct RpcContext {
epmap
=
ctx
.
epmap
;
height_sections
=
ctx
.
height_sections
;
trainer_id
=
ctx
.
trainer_id
;
merge_add
=
ctx
.
merge_add
;
use_send_handler
=
ctx
.
use_send_handler
;
}
std
::
string
var_name
;
...
...
@@ -47,6 +52,8 @@ struct RpcContext {
std
::
vector
<
std
::
string
>
epmap
;
std
::
vector
<
int64_t
>
height_sections
;
int
trainer_id
;
bool
merge_add
;
bool
use_send_handler
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RpcContext
&
rpc_ctx
)
{
...
...
@@ -70,6 +77,9 @@ inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os
<<
section
<<
", "
;
}
os
<<
"]
\n
"
;
os
<<
"merge add: "
<<
rpc_ctx
.
merge_add
;
os
<<
"; send handler: "
<<
rpc_ctx
.
use_send_handler
<<
"
\n
"
;
os
<<
"}"
;
return
os
;
}
...
...
paddle/fluid/operators/distributed_ops/distributed_notify_op.cc
已删除
100644 → 0
浏览文件 @
55c2329a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
operators
{
class
DistributedNotifyOp
:
public
framework
::
OperatorBase
{
public:
DistributedNotifyOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
string
type
=
Attr
<
std
::
string
>
(
"type"
);
int
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
rpc_client
->
AsyncDistributeNotify
(
epmap
[
i
],
type
);
VLOG
(
4
)
<<
"distribute notify sending : "
<<
type
<<
" to "
<<
epmap
[
i
];
}
PADDLE_ENFORCE_EQ
(
rpc_client
->
Wait
(),
true
,
"internal error in RPCClient"
);
}
};
class
DistributedNotifyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order"
)
.
SetDefault
({
"127.0.0.1:6164"
});
AddAttr
<
std
::
string
>
(
"type"
,
"(string, default '') indicate the action type"
);
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
DistributeNotify operator
This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC"
);
}
};
class
DistributedNotifyOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
distributed_notify
,
ops
::
DistributedNotifyOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
DistributedNotifyOpMaker
,
ops
::
DistributedNotifyOpShapeInference
);
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
781d2844
...
...
@@ -383,7 +383,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestGetNoBarrier
,
request_get_no_barrier_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestNotify
,
request_notify_handler_
.
get
(),
1
);
request_notify_handler_
.
get
(),
FLAGS_rpc_send_thread_num
);
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
781d2844
...
...
@@ -45,6 +45,7 @@ class SendOp : public framework::OperatorBase {
auto
send_varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"send_varnames"
);
auto
height_sections
=
Attr
<
std
::
vector
<
int64_t
>>
(
"sections"
);
auto
use_send_handler
=
Attr
<
bool
>
(
"use_send_handler"
);
if
(
send_varnames
.
size
()
>
0
)
{
if
(
ins
.
size
()
>
1
)
{
...
...
@@ -62,6 +63,7 @@ class SendOp : public framework::OperatorBase {
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
if
(
use_send_handler
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
...
...
@@ -71,6 +73,19 @@ class SendOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
}
else
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
j
=
0
;
j
<
epmap
.
size
();
j
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
j
];
rets
.
push_back
(
rpc_client
->
AsyncDistributeNotify
(
epmap
[
j
],
ctx
,
scope
,
ins
[
i
]));
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
}
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
VLOG
(
7
)
<<
"before sync_send "
<<
ins
[
i
]
<<
"from "
<<
epmap
[
i
];
PADDLE_ENFORCE_NE
(
rets
[
i
]
->
Wait
(),
0U
,
"internal error in RPCClient"
);
...
...
@@ -113,6 +128,15 @@ This operator will send variables to listen_and_serve op at the parameter server
"Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]"
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"merge_add"
,
"(bool, default 0)"
"merge method, true represent add, false represent average"
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_send_handler"
,
"(bool, default 1)"
"if it's true, use send handler, other wise, use notify handler"
)
.
SetDefault
(
true
);
}
};
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py
0 → 100644
浏览文件 @
781d2844
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
gc
import
paddle.fluid
as
fluid
class
TranspilerAsyncLRDecayTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
trainer_id
=
0
self
.
trainers
=
2
self
.
pservers
=
2
# NOTE: we do not actually bind this port
self
.
pserver_eps
=
"127.0.0.1:6174,127.0.0.1:6175"
self
.
pserver1_ep
=
"127.0.0.1:6174"
self
.
pserver2_ep
=
"127.0.0.1:6175"
self
.
sync_mode
=
False
self
.
transpiler
=
None
def
net_conf
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1000
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1000
,
act
=
None
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'fc_w'
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
'fc_b'
))
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
fluid
.
layers
.
exponential_decay
(
learning_rate
=
0.1
,
decay_steps
=
100
,
decay_rate
=
0.99
,
staircase
=
True
))
sgd_optimizer
.
minimize
(
avg_cost
)
def
get_main_program
(
self
):
main
=
fluid
.
Program
()
main
.
random_seed
=
1
with
fluid
.
program_guard
(
main
):
self
.
net_conf
()
self
.
origin_prog
=
main
.
clone
()
return
main
def
get_trainer
(
self
,
config
=
None
):
src
=
fluid
.
default_startup_program
().
clone
()
t
=
self
.
_transpiler_instance
(
config
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_startup
=
fluid
.
default_startup_program
()
assert
(
src
.
num_blocks
==
1
)
assert
(
trainer_startup
.
num_blocks
==
src
.
num_blocks
)
return
trainer_main
,
trainer_startup
def
get_pserver
(
self
,
ep
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
pserver
=
t
.
get_pserver_program
(
ep
)
startup
=
t
.
get_startup_program
(
ep
,
pserver
)
return
pserver
,
startup
def
_transpiler_instance
(
self
,
config
=
None
,
sync_mode
=
True
):
if
not
self
.
transpiler
:
main
=
self
.
get_main_program
()
self
.
transpiler
=
fluid
.
DistributeTranspiler
(
config
=
config
)
self
.
transpiler
.
transpile
(
self
.
trainer_id
,
program
=
main
,
pservers
=
self
.
pserver_eps
,
trainers
=
self
.
trainers
,
sync_mode
=
sync_mode
)
return
self
.
transpiler
def
transpiler_test_impl
(
self
):
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
sync_mode
=
False
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
sync_mode
=
False
)
trainer
,
trainer_startup
=
self
.
get_trainer
()
src
=
[
op
.
type
for
op
in
trainer_startup
.
global_block
().
ops
]
dst
=
[
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
\
'uniform_random'
,
'recv'
,
'recv'
,
'fetch_barrier'
,
'concat'
]
self
.
assertEqual
(
src
,
dst
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'split_byref'
,
'send'
,
'send'
,
'recv'
,
'recv'
,
'concat'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
4
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
# block1: sum,cast,scale,floor,fill_constant,elementwise_pow,scale
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
1
].
ops
],
[
"sum"
,
"cast"
,
"scale"
,
"floor"
,
"fill_constant"
,
"elementwise_pow"
,
"scale"
])
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sgd"
])
# confirm startup program
self
.
assertEqual
([
op
.
type
for
op
in
startup
.
global_block
().
ops
],
[
"fill_constant"
,
"fill_constant"
,
"fill_constant"
,
"fill_constant"
,
"uniform_random"
])
def
test_transpiler
(
self
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main
,
startup
):
self
.
transpiler_test_impl
()
# NOTE: run gc.collect to eliminate pybind side objects to
# prevent random double-deallocate when inherited in python.
del
self
.
transpiler
del
main
del
startup
gc
.
collect
()
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
781d2844
...
...
@@ -41,7 +41,7 @@ import logging
import
numpy
as
np
from
.ps_dispatcher
import
RoundRobin
,
PSDispatcher
from
..
import
core
,
framework
,
unique_name
from
..
import
core
,
framework
,
unique_name
,
initializer
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
Block
,
Parameter
,
grad_var_name
from
.details
import
wait_server_ready
,
UnionFind
,
VarStruct
,
VarsDistributed
...
...
@@ -304,6 +304,7 @@ class DistributeTranspiler(object):
PRINT_LOG
=
True
assert
(
self
.
config
.
min_block_size
>=
8192
)
assert
(
self
.
config
.
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
self
.
counter_var
=
None
def
_transpile_nccl2
(
self
,
trainer_id
,
...
...
@@ -631,6 +632,7 @@ class DistributeTranspiler(object):
np
.
random
.
shuffle
(
grad_var_mapping_items
)
self
.
grad_name_to_send_dummy_out
=
dict
()
for
grad_varname
,
splited_vars
in
grad_var_mapping_items
:
eplist
=
ps_dispatcher
.
dispatch
(
splited_vars
)
...
...
@@ -720,6 +722,31 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input
.
append
(
send_barrier_out
)
else
:
lr_ops
=
self
.
_get_lr_ops
()
if
len
(
lr_ops
)
>
0
and
self
.
counter_var
:
decay_dummy_output
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
if
self
.
config
.
runtime_split_send_recv
:
## async mode, using communicator to merge and send
send_varnames
=
[
self
.
counter_var
.
name
]
else
:
send_varnames
=
[]
sections
=
[]
program
.
global_block
().
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
self
.
counter_var
},
outputs
=
{
"Out"
:
decay_dummy_output
},
attrs
=
{
"epmap"
:
pserver_endpoints
,
"sections"
:
sections
,
"send_varnames"
:
send_varnames
,
"merge_add"
:
True
,
"use_send_handler"
:
False
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
self
.
counter_var
.
name
,
self
.
counter_var
.
name
]
})
# step 3: insert recv op to receive parameters from parameter server
recv_vars
=
[]
...
...
@@ -821,19 +848,6 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME
:
DIST_OP_ROLE_ATTR_VALUE
})
if
not
self
.
sync_mode
:
lr_ops
=
self
.
_get_lr_ops
()
if
len
(
lr_ops
)
>
0
:
program
.
global_block
().
append_op
(
type
=
"distributed_notify"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"epmap"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
"type"
:
"LRDECAY@RECV"
})
self
.
_get_trainer_startup_program
(
recv_vars
=
recv_vars
,
eplist
=
eplist
)
if
self
.
has_distributed_lookup_table
:
...
...
@@ -2380,11 +2394,57 @@ class DistributeTranspiler(object):
def
_get_lr_ops
(
self
):
lr_ops
=
[]
block
=
self
.
origin_program
.
global_block
()
for
op
in
block
.
ops
:
for
index
,
op
in
enumerate
(
block
.
ops
)
:
role_id
=
int
(
op
.
attr
(
RPC_OP_ROLE_ATTR_NAME
))
if
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
or
\
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
|
\
int
(
OPT_OP_ROLE_ATTR_VALUE
):
if
self
.
sync_mode
==
False
and
op
.
type
==
'increment'
:
inputs
=
self
.
_get_input_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
op
)
outputs
=
self
.
_get_output_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
op
)
for
key
in
outputs
:
counter_var
=
outputs
[
key
]
all_trainer_counter_inputs
=
[
self
.
origin_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
counter_var
.
name
,
id_
),
type
=
counter_var
.
type
,
shape
=
counter_var
.
shape
,
dtype
=
counter_var
.
dtype
,
persistable
=
counter_var
.
persistable
)
for
id_
in
range
(
self
.
trainer_num
)
]
for
i
,
op
in
enumerate
(
self
.
startup_program
.
global_block
()
.
ops
):
if
op
.
type
==
'fill_constant'
:
for
key
in
op
.
output_names
:
if
len
(
op
.
output
(
key
))
==
1
and
op
.
output
(
key
)[
0
]
==
counter_var
.
name
:
self
.
startup_program
.
global_block
().
ops
[
i
].
_set_attr
(
'value'
,
float
(
0.0
-
self
.
trainer_num
))
for
var
in
all_trainer_counter_inputs
:
if
var
.
name
==
"%s.trainer_%d"
%
(
counter_var
.
name
,
self
.
trainer_id
):
self
.
counter_var
=
var
self
.
startup_program
.
global_block
().
create_var
(
name
=
var
.
name
,
type
=
var
.
type
,
dtype
=
var
.
dtype
,
shape
=
var
.
shape
,
persistable
=
var
.
persistable
,
initializer
=
initializer
.
Constant
(
1
))
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
block
.
_remove_op
(
index
)
op
=
block
.
_insert_op
(
index
,
type
=
'sum'
,
inputs
=
{
'X'
:
all_trainer_counter_inputs
},
outputs
=
outputs
,
attrs
=
{
op_role_attr_name
:
LR_SCHED_OP_ROLE_ATTR_VALUE
})
lr_ops
.
append
(
op
)
log
(
"append lr op: "
,
op
.
type
)
return
lr_ops
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录