Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
caf9a09d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
caf9a09d
编写于
2月 11, 2018
作者:
Y
Yancey
提交者:
GitHub
2月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Merge selected rows with dynamic variable count (#8023)
* dynamic send/recv selected rows * update by comment * fix by comment
上级
4f4abfa3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
47 addition
and
25 deletion
+47
-25
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+16
-0
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+22
-2
paddle/fluid/operators/split_selected_rows_op.cc
paddle/fluid/operators/split_selected_rows_op.cc
+1
-22
paddle/fluid/operators/split_selected_rows_op.h
paddle/fluid/operators/split_selected_rows_op.h
+1
-0
paddle/fluid/operators/sum_op.h
paddle/fluid/operators/sum_op.h
+3
-1
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+4
-0
未找到文件。
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
caf9a09d
...
...
@@ -101,6 +101,9 @@ class ListenAndServOp : public framework::OperatorBase {
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
// Record the received sparse variables so that
// we can reset them after executing the optimize program.
std
::
vector
<
framework
::
Variable
*>
sparse_vars
;
while
(
!
exit_flag
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrive, just add suffix 0~n and merge the gradient.
...
...
@@ -143,6 +146,9 @@ class ListenAndServOp : public framework::OperatorBase {
PADDLE_THROW
(
"Can not find server side var"
);
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
sparse_vars
.
push_back
(
var
);
}
}
}
VLOG
(
3
)
<<
"recv "
<<
recv_var_cnt
<<
" parmeters for one barrier."
;
...
...
@@ -156,9 +162,19 @@ class ListenAndServOp : public framework::OperatorBase {
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
// Reset the received sparse variables; the sum operator does not
// sum input sparse variables whose rows are empty at the next
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator; we shouldn't
// have any hidden logic in the operator.
for
(
auto
&
var
:
sparse_vars
)
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
WaitClientGet
(
update_param_cnt
);
grads_counter_
.
clear
();
sparse_vars
.
clear
();
}
// while(true)
}
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
caf9a09d
...
...
@@ -24,6 +24,22 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
// Reports whether the variable named `varname` in `scope` holds initialized
// data.
//
// The variable is looked up with Scope::FindVar and must exist (enforced via
// PADDLE_ENFORCE_NOT_NULL). Two variable types are handled:
//   - LoDTensor:    the result of the tensor's own IsInitialized().
//   - SelectedRows: the result of IsInitialized() on its value().
// Any other type is reported through PADDLE_THROW.
static bool IsVariableInitialized(const framework::Scope& scope,
                                  const std::string& varname) {
  auto* var = scope.FindVar(varname);
  PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                          varname);

  if (var->IsType<framework::LoDTensor>()) {
    const auto& tensor = var->Get<framework::LoDTensor>();
    return tensor.IsInitialized();
  }

  if (var->IsType<framework::SelectedRows>()) {
    const auto& selected_rows = var->Get<framework::SelectedRows>();
    return selected_rows.value().IsInitialized();
  }

  PADDLE_THROW(
      "Variable type in send side should be in "
      "[LodTensor, SelectedRows]");
  // NOTE(review): presumably never reached because PADDLE_THROW raises; kept
  // so every control path has a return value — confirm against the macro.
  return false;
}
class
SendOp
:
public
framework
::
OperatorBase
{
public:
...
...
@@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase {
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
rpc_client
->
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
if
(
IsVariableInitialized
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
rpc_client
->
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
...
...
paddle/fluid/operators/split_selected_rows_op.cc
浏览文件 @
caf9a09d
...
...
@@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
SplitSelectedRowsOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input SelectedRows."
);
AddOutput
(
"Out"
,
"The outputs of input SelectedRows."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"The outputs of
the
input SelectedRows."
).
AsDuplicable
();
AddAttr
<
std
::
vector
<
int
>>
(
"height_sections"
,
"Height for each output SelectedRows."
)
.
SetDefault
(
std
::
vector
<
int
>
({}));
...
...
@@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"SplitSelectedRowsOp must has input X."
);
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
"SplitSelectedRowsOp must has output Out."
);
std
::
vector
<
int
>
height_sections
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"height_sections"
);
int64_t
n
=
ctx
->
Outputs
(
"Out"
).
size
();
std
::
vector
<
framework
::
DDim
>
outs_dims
;
outs_dims
.
reserve
(
n
);
// make output dims
for
(
int64_t
i
=
0
;
i
<
n
;
++
i
)
{
auto
dims
=
ctx
->
GetInputDim
(
"X"
);
if
(
height_sections
.
size
())
{
PADDLE_ENFORCE_EQ
(
height_sections
.
size
(),
static_cast
<
size_t
>
(
n
),
"The size of height section should be the same with height"
" section size."
);
dims
[
0
]
=
height_sections
[
i
];
}
outs_dims
.
push_back
(
dims
);
}
ctx
->
SetOutputsDim
(
"Out"
,
outs_dims
);
}
};
...
...
paddle/fluid/operators/split_selected_rows_op.h
浏览文件 @
caf9a09d
...
...
@@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
for
(
size_t
i
=
0
;
i
<
outs_rows_idx
.
size
();
++
i
)
{
auto
rows_idx
=
outs_rows_idx
[
i
];
outs
[
i
]
->
set_height
(
height_sections
[
i
]);
if
(
rows_idx
.
size
()
>
0
)
{
auto
dims
=
x
->
GetCompleteDims
();
dims
[
0
]
=
rows_idx
.
size
();
...
...
paddle/fluid/operators/sum_op.h
浏览文件 @
caf9a09d
...
...
@@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel<T> {
int64_t
offset
=
0
;
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
auto
&
sel_row
=
get_selected_row
(
i
);
if
(
!
sel_row
.
value
().
IsInitialized
()
||
sel_row
.
rows
().
size
()
==
0
)
{
continue
;
}
PADDLE_ENFORCE_EQ
(
out
->
height
(),
sel_row
.
height
());
functor
(
context
.
template
device_context
<
DeviceContext
>(),
sel_row
,
offset
,
out
);
...
...
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
caf9a09d
...
...
@@ -191,6 +191,7 @@ class DistributeTranspiler:
for
b
in
param_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_outputs
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist
=
split_method
(
send_inputs
,
pserver_endpoints
)
...
...
@@ -274,6 +275,7 @@ class DistributeTranspiler:
name
=
"%s.block%d"
%
(
varname
,
i
),
psersistable
=
False
,
dtype
=
orig_var
.
dtype
,
type
=
orig_var
.
type
,
shape
=
splited_shape
)
# flattened split var
var_mapping
[
varname
].
append
(
var
)
return
var_mapping
...
...
@@ -335,6 +337,7 @@ class DistributeTranspiler:
name
=
"%s.trainer_%d"
%
(
var
.
name
,
i
),
psersistable
=
var
.
persistable
,
dtype
=
var
.
dtype
,
type
=
var
.
type
,
shape
=
var
.
shape
)
var_list
.
append
(
var_each
)
return
var_list
...
...
@@ -561,6 +564,7 @@ class DistributeTranspiler:
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
# step6
optimize_block
=
pserver_program
.
create_block
(
0
)
# step 6.1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录