Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9508c726
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9508c726
编写于
12月 13, 2017
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
wip: should fix variable recreate
上级
b4cd7f3d
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
103 addition
and
78 deletion
+103
-78
paddle/framework/executor.cc
paddle/framework/executor.cc
+26
-24
paddle/framework/executor.h
paddle/framework/executor.h
+2
-1
paddle/operators/detail/recv_impl.cc
paddle/operators/detail/recv_impl.cc
+8
-3
paddle/operators/detail/send_impl.cc
paddle/operators/detail/send_impl.cc
+19
-4
paddle/operators/detail/send_recv.proto
paddle/operators/detail/send_recv.proto
+3
-1
paddle/operators/detail/send_recv_impl.h
paddle/operators/detail/send_recv_impl.h
+5
-3
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+30
-39
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+8
-1
python/paddle/v2/fluid/executor.py
python/paddle/v2/fluid/executor.py
+1
-2
python/paddle/v2/fluid/tests/book/test_recognize_digits_conv_dist.py
...le/v2/fluid/tests/book/test_recognize_digits_conv_dist.py
+1
-0
未找到文件。
paddle/framework/executor.cc
浏览文件 @
9508c726
...
...
@@ -85,7 +85,7 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
}
void
Executor
::
Run
(
const
ProgramDescBind
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
)
{
bool
create_local_scope
,
bool
create_vars
)
{
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
...
...
@@ -94,33 +94,35 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
auto
&
device
=
device_contexts_
[
0
];
Scope
*
local_scope
=
scope
;
if
(
create_local_scope
)
{
local_scope
=
&
scope
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Name
()
==
framework
::
kEmptyVarName
)
{
continue
;
if
(
create_vars
)
{
if
(
create_local_scope
)
{
local_scope
=
&
scope
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Name
()
==
framework
::
kEmptyVarName
)
{
continue
;
}
if
(
var
->
Persistable
())
{
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" global, which pointer is "
<<
ptr
;
}
else
{
auto
*
ptr
=
local_scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" locally, which pointer is "
<<
ptr
;
}
}
if
(
var
->
Persistable
())
{
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" global, which pointer is "
<<
ptr
;
}
else
{
}
else
{
for
(
auto
&
var
:
block
.
AllVars
())
{
auto
*
ptr
=
local_scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create
Variable "
<<
var
->
Name
()
<<
" locally, which pointer is "
<<
ptr
;
VLOG
(
3
)
<<
"Create
variable "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
}
}
else
{
for
(
auto
&
var
:
block
.
AllVars
())
{
auto
*
ptr
=
local_scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create variable "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
}
}
// if (create_local_scope)
}
// if (create_vars)
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
*
op_desc
);
...
...
paddle/framework/executor.h
浏览文件 @
9508c726
...
...
@@ -35,7 +35,8 @@ class Executor {
* ProgramDesc
* Scope
*/
void
Run
(
const
ProgramDescBind
&
,
Scope
*
,
int
,
bool
create_local_scope
=
true
);
void
Run
(
const
ProgramDescBind
&
,
Scope
*
,
int
,
bool
create_local_scope
=
true
,
bool
create_vars
=
true
);
private:
std
::
vector
<
const
platform
::
DeviceContext
*>
device_contexts_
;
...
...
paddle/operators/detail/recv_impl.cc
浏览文件 @
9508c726
...
...
@@ -20,7 +20,7 @@ namespace detail {
Status
SendRecvServerImpl
::
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
V
ariable
Message
*
out_var
)
{
V
oid
Message
*
out_var
)
{
// TODO(typhoonzero): support different variable types.
std
::
istringstream
iss
(
in_var
->
serialized
());
framework
::
LoDTensor
t
;
...
...
@@ -29,6 +29,12 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
std
::
make_pair
(
in_var
->
varname
(),
std
::
move
(
t
));
var_recv_queue_
.
Push
(
std
::
move
(
tensor_with_name
));
return
Status
::
OK
;
}
Status
SendRecvServerImpl
::
GetVariable
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VariableMessage
*
out_var
)
{
// Block util the sub graph is done.
auto
out_tensor_with_name
=
var_return_queue_
.
Pop
();
std
::
ostringstream
oss
;
...
...
@@ -36,10 +42,9 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
platform
::
CPUDeviceContext
());
std
::
string
*
varname
=
out_var
->
mutable_varname
();
*
varname
=
in_var
->
varname
()
;
*
varname
=
out_tensor_with_name
.
first
;
std
::
string
*
serialized
=
out_var
->
mutable_serialized
();
*
serialized
=
oss
.
str
();
return
Status
::
OK
;
}
...
...
paddle/operators/detail/send_impl.cc
浏览文件 @
9508c726
...
...
@@ -19,10 +19,10 @@ namespace operators {
namespace
detail
{
bool
RPCClient
::
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
,
const
std
::
string
&
outname
)
{
const
std
::
string
&
inname
)
{
ClientContext
context
;
VariableMessage
msg
,
out_msg
;
VariableMessage
msg
;
VoidMessage
out_msg
;
// FIXME(typhoonzero): pass device context to here.
auto
ctx
=
platform
::
CPUDeviceContext
();
auto
*
var
=
scope
.
FindVar
(
inname
);
...
...
@@ -40,7 +40,22 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
return
false
;
}
std
::
istringstream
iss
(
out_msg
.
serialized
());
return
true
;
}
bool
RPCClient
::
GetVariable
(
const
framework
::
Scope
&
scope
)
{
ClientContext
context
;
VariableMessage
msg
;
VoidMessage
void_msg
;
auto
ctx
=
platform
::
CPUDeviceContext
();
Status
status
=
stub_
->
GetVariable
(
&
context
,
void_msg
,
&
msg
);
if
(
!
status
.
ok
())
{
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
return
false
;
}
std
::
istringstream
iss
(
msg
.
serialized
());
auto
outname
=
msg
.
varname
();
framework
::
LoDTensor
ret_tensor
;
framework
::
DeserializeFromStream
(
iss
,
&
ret_tensor
);
auto
*
outvar
=
scope
.
FindVar
(
outname
);
...
...
paddle/operators/detail/send_recv.proto
浏览文件 @
9508c726
...
...
@@ -20,7 +20,9 @@ service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc
SendVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
// Argument VariableMessage for GetVariable should only contain varname.
rpc
GetVariable
(
VoidMessage
)
returns
(
VariableMessage
)
{}
}
// VariableMessage is serialized paddle variable message.
...
...
paddle/operators/detail/send_recv_impl.h
浏览文件 @
9508c726
...
...
@@ -55,7 +55,9 @@ class SendRecvServerImpl final : public SendRecvService::Service {
explicit
SendRecvServerImpl
()
{}
Status
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VariableMessage
*
out_var
)
override
;
VoidMessage
*
out_var
)
override
;
Status
GetVariable
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VariableMessage
*
out_var
)
override
;
const
TensorWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
...
...
@@ -75,8 +77,8 @@ class RPCClient {
RPCClient
(
std
::
shared_ptr
<
Channel
>
channel
)
:
stub_
(
SendRecvService
::
NewStub
(
channel
))
{}
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
,
const
std
::
string
&
outnam
e
);
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
);
bool
GetVariable
(
const
framework
::
Scope
&
scop
e
);
private:
std
::
unique_ptr
<
SendRecvService
::
Stub
>
stub_
;
...
...
paddle/operators/recv_op.cc
浏览文件 @
9508c726
...
...
@@ -66,37 +66,25 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
// FIXME(typhoonzero): no new scopes for every run.
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
// blocking get one var from client.
const
detail
::
TensorWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
it
=
std
::
find
(
grad_list
.
begin
(),
grad_list
.
end
(),
grad_var_name
);
std
::
string
param_var_name
;
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
size_t
param_count
=
param_list
.
size
();
for
(
size_t
i
=
0
;
i
<
param_count
;
++
i
)
{
// blocking get one var from client.
const
detail
::
TensorWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
auto
it
=
std
::
find
(
grad_list
.
begin
(),
grad_list
.
end
(),
grad_var_name
);
std
::
string
param_var_name
;
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
VLOG
(
10
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
auto
*
var
=
recv_scope
.
Var
(
grad_var_name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
// FIXME(typhoonzero): do not copy
framework
::
CopyFrom
(
v
.
second
,
dev_ctx
.
GetPlace
(),
dev_ctx
,
tensor
);
}
// find input by "grad_var_name"
// auto inputs = Inputs("RX");
// FIXME(typhoonzero): Find the parameter name from input grad name
// rename X -> Param
// rename RX -> Grad
LOG
(
ERROR
)
<<
"recved grad: "
<<
grad_var_name
<<
" param: "
<<
param_var_name
;
auto
*
var
=
recv_scope
.
Var
(
grad_var_name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
// Param is in parent scope, put it in current scope.
auto
*
param_var
=
recv_scope
.
FindVar
(
param_var_name
);
auto
param_scope
=
recv_scope
.
FindScope
(
param_var
);
param_scope
->
Rename
(
param_var_name
,
"Param"
);
recv_scope
.
Rename
(
grad_var_name
,
"Grad"
);
// FIXME(typhoonzero): do not copy
framework
::
CopyFrom
(
v
.
second
,
dev_ctx
.
GetPlace
(),
dev_ctx
,
tensor
);
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
framework
::
ProgramDesc
program_desc
;
...
...
@@ -104,17 +92,20 @@ class RecvOp : public framework::OperatorBase {
framework
::
ProgramDescBind
program
(
program_desc
);
framework
::
Executor
executor
(
dev_ctx
);
// Run sub graph to get optimized tensor
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
false
/*create_local_scope*/
);
auto
*
out_var
=
recv_scope
.
FindVar
(
"ParamOut"
);
detail
::
TensorWithName
out
;
out
.
first
=
param_var_name
;
out
.
second
=
out_var
->
Get
<
framework
::
LoDTensor
>
();
rpc_service_
->
Push
(
out
);
// rename back the params
param_scope
.
Rename
(
"Param"
,
param_var_name
);
recv_scope
.
Rename
(
"Grad"
,
grad_var_name
);
try
{
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
for
(
size_t
i
=
0
;
i
<
param_count
;
++
i
)
{
auto
*
out_var
=
recv_scope
.
FindVar
(
param_list
[
i
]);
detail
::
TensorWithName
out
;
out
.
first
=
param_list
[
i
];
out
.
second
=
out_var
->
Get
<
framework
::
LoDTensor
>
();
rpc_service_
->
Push
(
out
);
}
}
protected:
...
...
paddle/operators/send_op.cc
浏览文件 @
9508c726
...
...
@@ -48,11 +48,18 @@ class SendOp : public framework::OperatorBase {
// should block until server responds.
for
(
auto
in
:
ins
)
{
LOG
(
ERROR
)
<<
"sending grad: "
<<
in
;
bool
ret
=
client_
->
SendVariable
(
scope
,
in
,
in
);
bool
ret
=
client_
->
SendVariable
(
scope
,
in
);
if
(
!
ret
)
{
LOG
(
ERROR
)
<<
"send variable error"
;
}
}
for
(
auto
in
:
ins
)
{
LOG
(
ERROR
)
<<
"updating from server..."
;
bool
ret
=
client_
->
GetVariable
(
scope
);
if
(
!
ret
)
{
LOG
(
ERROR
)
<<
"GetVariable error"
;
}
}
}
protected:
...
...
python/paddle/v2/fluid/executor.py
浏览文件 @
9508c726
...
...
@@ -138,7 +138,6 @@ class Executor(object):
inputs
=
opt_op
.
inputs
,
outputs
=
opt_op
.
outputs
,
attrs
=
opt_op
.
attrs
)
print
(
"optimize program: "
,
optimize_sub_program
)
pserver_program
.
global_block
().
append_op
(
type
=
"recv"
,
...
...
@@ -248,7 +247,7 @@ class Executor(object):
outputs
=
{
'Out'
:
[
fetch_var
]},
attrs
=
{
'col'
:
i
})
self
.
executor
.
run
(
program
.
desc
,
scope
,
0
,
True
)
self
.
executor
.
run
(
program
.
desc
,
scope
,
0
,
True
,
True
)
outs
=
[
core
.
get_fetch_variable
(
scope
,
fetch_var_name
,
i
)
for
i
in
xrange
(
len
(
fetch_list
))
...
...
python/paddle/v2/fluid/tests/book/test_recognize_digits_conv_dist.py
浏览文件 @
9508c726
...
...
@@ -44,6 +44,7 @@ exe.optimize(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=1)
pserver_endpoint
=
os
.
getenv
(
"PSERVER"
)
if
pserver_endpoint
:
pserver_prog
=
exe
.
get_pserver_program
(
pserver_endpoint
,
optimize_ops
)
print
(
"pserver startup: "
,
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
while
True
:
exe
.
run
(
pserver_prog
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录