Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
bfa7b3ee
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bfa7b3ee
编写于
6年前
作者:
Q
Qiao Longfei
提交者:
GitHub
6年前
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12907 from jacquesqiao/cherry-pick-add-dependency-to-send-recv
Cherry pick add dependency to send recv
上级
482d297b
86776a1f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
107 addition
and
40 deletion
+107
-40
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+22
-4
paddle/fluid/framework/ir/node.cc
paddle/fluid/framework/ir/node.cc
+1
-1
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+1
-1
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+2
-0
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+5
-9
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+2
-0
paddle/fluid/pybind/const_value.cc
paddle/fluid/pybind/const_value.cc
+4
-1
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+6
-0
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+19
-4
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+45
-20
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
bfa7b3ee
...
...
@@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
// Create RPC related op handles that connects its in ops and out ops.
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
// FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
// put them into transpiler.
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
...
...
@@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver
// instances.
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
&&
node
->
inputs
[
0
]
->
Name
().
find
(
".block"
)
==
std
::
string
::
npos
)
{
std
::
vector
<
std
::
string
>
input_var_names
;
for
(
ir
::
Node
*
n
:
node
->
inputs
)
{
input_var_names
.
push_back
(
n
->
Name
());
}
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
auto
send_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
send_param_grad
.
size
(),
2U
);
op_dev_id
=
GetAppropriateDeviceID
({
send_param_grad
[
1
]});
VLOG
(
10
)
<<
"send grad "
<<
input_var_names
[
0
]
<<
" origin "
<<
send_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
for
(
auto
&
varname
:
input_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
send_param_grad
[
1
],
op_dev_id
);
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
std
::
vector
<
std
::
string
>
output_var_names
;
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
output_var_names
.
push_back
(
n
->
Name
());
}
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
// FIXME(typhoonzero): assume each recv op output one param
// Use the same place as send.
if
(
recv_param_grad
.
size
()
==
2U
)
{
op_dev_id
=
GetVarDeviceID
(
*
result
,
recv_param_grad
[
1
]);
VLOG
(
10
)
<<
"recv param "
<<
recv_param_grad
[
0
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
}
else
{
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
}
for
(
auto
&
varname
:
output_var_names
)
{
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/framework/ir/node.cc
浏览文件 @
bfa7b3ee
...
...
@@ -17,7 +17,7 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
namespace
ir
{
const
char
Node
::
kControlDepVarName
[]
=
"__control_var"
;
const
expr
char
Node
::
kControlDepVarName
[]
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
This diff is collapsed.
Click to expand it.
paddle/fluid/framework/ir/node.h
浏览文件 @
bfa7b3ee
...
...
@@ -27,7 +27,7 @@ namespace ir {
class
Node
{
public:
enum
class
Type
{
kOperation
,
kVariable
};
static
const
char
kControlDepVarName
[]
;
static
const
expr
char
kControlDepVarName
[]
=
"__control_var"
;
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/recv_op.cc
浏览文件 @
bfa7b3ee
...
...
@@ -57,6 +57,8 @@ class RecvOp : public framework::OperatorBase {
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Any) Dummy inputs, used for control dependency"
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Variables to get from server."
).
AsDuplicable
();
AddComment
(
R"DOC(
Recv operator
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
bfa7b3ee
...
...
@@ -37,22 +37,19 @@ class SendBarrierOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
VLOG
(
3
)
<<
"SendBarrierOp sync
_mode:"
<<
sync_mode
;
VLOG
(
3
)
<<
"SendBarrierOp sync
"
;
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
(),
"internal error in RPCClient"
);
if
(
sync_mode
)
{
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"send barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
(),
"internal error in RPCClient"
);
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"send barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
(),
"internal error in RPCClient"
);
}
};
...
...
@@ -70,7 +67,6 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
.
SetDefault
({
"127.0.0.1:6164"
});
AddAttr
<
bool
>
(
"sync_mode"
,
"work in sync_mode or not"
).
SetDefault
(
true
);
}
};
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/send_op.cc
浏览文件 @
bfa7b3ee
...
...
@@ -66,6 +66,8 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
{
AddInput
(
"X"
,
"(Tensor, SelectedRows) Input variables to be sent"
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Any) Dummy outputs, used for control dependency"
)
.
AsDuplicable
();
AddComment
(
R"DOC(
Send operator
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/pybind/const_value.cc
浏览文件 @
bfa7b3ee
...
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include <paddle/fluid/framework/op_proto_maker.h>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
...
...
@@ -24,6 +25,8 @@ void BindConstValue(pybind11::module* m) {
m
->
def
(
"kTempVarName"
,
[]
{
return
framework
::
kTempVarName
;
});
m
->
def
(
"kGradVarSuffix"
,
[]
{
return
framework
::
kGradVarSuffix
;
});
m
->
def
(
"kZeroVarSuffix"
,
[]
{
return
framework
::
kZeroVarSuffix
;
});
m
->
def
(
"kControlDepVarName"
,
[]
{
return
framework
::
ir
::
Node
::
kControlDepVarName
;
});
auto
op_proto_and_checker_maker
=
m
->
def_submodule
(
"op_proto_and_checker_maker"
);
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/framework.py
浏览文件 @
bfa7b3ee
...
...
@@ -49,6 +49,12 @@ EMPTY_VAR_NAME = core.kEmptyVarName()
TEMP_VAR_NAME
=
core
.
kTempVarName
()
GRAD_VAR_SUFFIX
=
core
.
kGradVarSuffix
()
ZERO_VAR_SUFFIX
=
core
.
kZeroVarSuffix
()
CONTROL_DEP_VAR_PREFIX
=
core
.
kControlDepVarName
()
def
generate_control_dev_var_name
():
import
random
return
CONTROL_DEP_VAR_PREFIX
+
"@"
+
str
(
random
.
random
())
def
grad_var_name
(
var_name
):
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/layers/io.py
浏览文件 @
bfa7b3ee
...
...
@@ -24,7 +24,7 @@ from .layer_function_generator import templatedoc
from
..
import
core
from
..executor
import
global_scope
from
..framework
import
convert_np_dtype_to_dtype_
,
default_main_program
,
\
default_startup_program
,
program_guard
,
Program
default_startup_program
,
program_guard
,
Program
,
Variable
from
..layer_helper
import
LayerHelper
from
..unique_name
import
generate
as
unique_name
...
...
@@ -209,7 +209,7 @@ class ListenAndServ(object):
})
def
Send
(
endpoints
,
send_vars
,
sync
=
True
):
def
Send
(
endpoints
,
send_vars
,
dummy_output
=
None
,
sync
=
True
):
"""
Send variables to the server side, and get vars from server
side when server have finished running server side program.
...
...
@@ -223,6 +223,13 @@ def Send(endpoints, send_vars, sync=True):
"""
assert
(
type
(
send_vars
)
==
list
)
if
dummy_output
is
None
:
dummy_output
=
[]
elif
isinstance
(
dummy_output
,
Variable
):
dummy_output
=
[
dummy_output
]
assert
(
type
(
dummy_output
)
==
list
)
epmap
=
endpoints
.
split
(
","
)
endpoints
=
list
(
set
(
epmap
))
...
...
@@ -232,6 +239,7 @@ def Send(endpoints, send_vars, sync=True):
helper
.
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
send_vars
},
outputs
=
{
"Out"
:
dummy_output
},
attrs
=
{
"endpoints"
:
endpoints
,
"epmap"
:
epmap
,
...
...
@@ -241,7 +249,7 @@ def Send(endpoints, send_vars, sync=True):
helper
.
append_op
(
type
=
"send_barrier"
,
attrs
=
{
"endpoints"
:
endpoints
})
def
Recv
(
endpoints
,
get_vars
,
sync
=
True
):
def
Recv
(
endpoints
,
get_vars
,
dummy_input
=
None
,
sync
=
True
):
"""
Receive variables from server side
...
...
@@ -256,13 +264,20 @@ def Recv(endpoints, get_vars, sync=True):
"""
assert
(
type
(
get_vars
)
==
list
)
if
dummy_input
is
None
:
dummy_input
=
[]
elif
isinstance
(
dummy_input
,
Variable
):
dummy_input
=
[
dummy_input
]
assert
(
type
(
dummy_input
)
==
list
)
epmap
=
endpoints
.
split
(
","
)
endpoints
=
list
(
set
(
epmap
))
helper
=
LayerHelper
(
"Recv"
,
**
locals
())
helper
.
append_op
(
type
=
"recv"
,
inputs
=
{
"X"
:
get_vars
},
inputs
=
{
"X"
:
dummy_input
},
outputs
=
{
"Out"
:
get_vars
},
attrs
=
{
"endpoints"
:
endpoints
,
"epmap"
:
epmap
})
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
bfa7b3ee
...
...
@@ -210,6 +210,11 @@ class DistributeTranspiler(object):
ps_dispatcher
=
self
.
config
.
split_method
(
self
.
pserver_endpoints
)
self
.
has_distributed_lookup_table
=
self
.
_has_distributed_lookup_table
()
self
.
param_name_to_grad_name
=
dict
()
self
.
grad_name_to_param_name
=
dict
()
for
param_var
,
grad_var
in
self
.
params_grads
:
self
.
param_name_to_grad_name
[
param_var
.
name
]
=
grad_var
.
name
self
.
grad_name_to_param_name
[
grad_var
.
name
]
=
param_var
.
name
# step 1: split and create vars, then put splited vars in dicts for later use.
self
.
_init_splited_vars
()
...
...
@@ -229,34 +234,43 @@ class DistributeTranspiler(object):
random
.
seed
(
self
.
origin_program
.
random_seed
)
random
.
shuffle
(
grad_var_mapping_items
)
for
orig_varname
,
splited_vars
in
grad_var_mapping_items
:
grad_name_to_send_dummy_out
=
dict
()
for
grad_varname
,
splited_vars
in
grad_var_mapping_items
:
eplist
=
ps_dispatcher
.
dispatch
(
splited_vars
)
if
not
self
.
config
.
slice_var_up
:
assert
(
len
(
splited_vars
)
==
1
)
splited_grad_varname
=
grad_varname
if
len
(
splited_vars
)
==
1
:
orig
_varname
=
splited_vars
[
0
].
name
splited_grad
_varname
=
splited_vars
[
0
].
name
index
=
find_op_by_output_arg
(
program
.
global_block
(),
orig
_varname
)
splited_grad
_varname
)
elif
len
(
splited_vars
)
>
1
:
orig_var
=
program
.
global_block
().
vars
[
orig
_varname
]
orig_var
=
program
.
global_block
().
vars
[
splited_grad
_varname
]
index
=
find_op_by_output_arg
(
program
.
global_block
(),
orig
_varname
)
splited_grad
_varname
)
self
.
_insert_split_op
(
program
,
orig_var
,
index
,
splited_vars
)
index
+=
1
else
:
AssertionError
(
"Can not insert the send op by original "
"variable name :"
,
orig_varname
)
"variable name :"
,
splited_grad_varname
)
dummy_output
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
grad_name_to_send_dummy_out
[
grad_varname
]
=
dummy_output
program
.
global_block
().
_insert_op
(
index
=
index
+
1
,
type
=
"send"
,
inputs
=
{
"X"
:
splited_vars
},
outputs
=
{},
outputs
=
{
"Out"
:
dummy_output
},
attrs
=
{
"epmap"
:
eplist
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
self
.
grad_name_to_param_name
[
grad_varname
],
grad_varname
],
"sync_mode"
:
not
self
.
sync_mode
,
})
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
...
...
@@ -268,7 +282,6 @@ class DistributeTranspiler(object):
outputs
=
{},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"sync_mode"
:
self
.
sync_mode
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
...
...
@@ -284,19 +297,25 @@ class DistributeTranspiler(object):
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
for
param_
varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
eps
=
[]
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
eps
.
append
(
eplist
[
index
])
grad_send_dummy_out
=
grad_name_to_send_dummy_out
[
self
.
param_name_to_grad_name
[
param_varname
]]
program
.
global_block
().
append_op
(
type
=
"recv"
,
inputs
=
{},
inputs
=
{
"X"
:
[
grad_send_dummy_out
]
},
outputs
=
{
"Out"
:
splited_var
},
attrs
=
{
"epmap"
:
eps
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
param_varname
,
self
.
param_name_to_grad_name
[
param_varname
]
],
"sync_mode"
:
not
self
.
sync_mode
})
if
self
.
sync_mode
:
...
...
@@ -309,10 +328,10 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
for
param_
varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
if
len
(
splited_var
)
<=
1
:
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
orig_param
=
program
.
global_block
().
vars
[
param_
varname
]
program
.
global_block
().
append_op
(
type
=
"concat"
,
inputs
=
{
"X"
:
splited_var
},
...
...
@@ -380,7 +399,7 @@ class DistributeTranspiler(object):
op
=
startup_program
.
global_block
().
append_op
(
type
=
"recv"
,
inputs
=
{},
inputs
=
{
"X"
:
[]
},
outputs
=
{
"Out"
:
splited_var
},
attrs
=
{
"epmap"
:
eps
,
...
...
@@ -786,19 +805,21 @@ class DistributeTranspiler(object):
self
.
config
.
min_block_size
)
assert
(
len
(
grad_blocks
)
==
len
(
param_blocks
))
# origin_
varname -> [splited_var
]
# origin_
param_name -> [splited_param_vars
]
self
.
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
self
.
origin_program
,
param_blocks
)
# origin_grad_name -> [splited_grad_vars]
self
.
grad_var_mapping
=
self
.
_create_vars_from_blocklist
(
self
.
origin_program
,
grad_blocks
,
add_trainer_suffix
=
self
.
trainer_num
>
1
)
# dict(grad_splited_var -> param_splited_var)
self
.
grad_param_mapping
=
collections
.
OrderedDict
()
for
g
,
p
in
zip
(
grad_blocks
,
param_blocks
):
g_name
,
g_bid
,
_
=
g
.
split
(
":"
)
p_name
,
p_bid
,
_
=
p
.
split
(
":"
)
self
.
grad_param_mapping
[
self
.
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
self
.
param_var_mapping
[
p_name
][
int
(
p_bid
)]
self
.
param_var_mapping
[
p_name
][
int
(
p_bid
)]
# create mapping of endpoint -> split var to create pserver side program
self
.
param_grad_ep_mapping
=
collections
.
OrderedDict
()
...
...
@@ -919,11 +940,15 @@ class DistributeTranspiler(object):
index
=
op_index
+
2
,
type
=
"send"
,
inputs
=
{
'X'
:
self
.
trainer_side_table_grad_list
},
outputs
=
{},
outputs
=
{
'Out'
:
[]
},
attrs
=
{
"sync_mode"
:
True
,
"epmap"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
self
.
grad_name_to_param_name
[
table_grad_name
],
table_grad_name
]
})
break
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录