Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e805cfcb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e805cfcb
编写于
12月 20, 2017
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix unit test failed
上级
b1b7af40
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
30 addition
and
19 deletion
+30
-19
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+4
-2
paddle/operators/send_recv_op_test.cc
paddle/operators/send_recv_op_test.cc
+17
-12
python/paddle/v2/fluid/tests/book/notest_recognize_digits_conv_dist.py
.../v2/fluid/tests/book/notest_recognize_digits_conv_dist.py
+9
-5
未找到文件。
paddle/operators/recv_op.cc
浏览文件 @
e805cfcb
...
...
@@ -160,10 +160,12 @@ This operator will recv tensor from send_op
"Serialized ProgramDesc string for recv to run."
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
,
"type list of string"
,
"grad->param name mapping to find which param to optimize."
);
"grad->param name mapping to find which param to optimize."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which param to optimize."
);
"grad->param name mapping to find which param to optimize."
)
.
SetDefault
({});
AddAttr
<
int
>
(
"Trainers"
,
"type int"
,
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
...
...
paddle/operators/send_recv_op_test.cc
浏览文件 @
e805cfcb
...
...
@@ -16,12 +16,14 @@
// a RemoteOptimizer.
#include <unistd.h>
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
#include "paddle/string/printf.h"
USE_NO_KERNEL_OP
(
send
);
USE_NO_KERNEL_OP
(
recv
);
...
...
@@ -33,18 +35,21 @@ std::unique_ptr<paddle::framework::OperatorBase> recv_op;
void
InitTensorsInScope
(
paddle
::
framework
::
Scope
&
scope
,
paddle
::
platform
::
CPUPlace
&
place
)
{
paddle
::
platform
::
CPUDeviceContext
ctx
(
place
);
auto
var
=
scope
.
Var
(
"X"
);
auto
tensor
=
var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
tensor
->
Resize
({
10
,
10
});
float
*
expect
=
tensor
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
tensor
->
numel
();
++
i
)
{
expect
[
i
]
=
static_cast
<
float
>
(
i
);
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
auto
var_name
=
paddle
::
string
::
Sprintf
(
"x%d"
,
i
);
auto
var
=
scope
.
Var
(
var_name
);
auto
tensor
=
var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
tensor
->
Resize
({
10
,
10
});
float
*
expect
=
tensor
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
tensor
->
numel
();
++
i
)
{
expect
[
i
]
=
static_cast
<
float
>
(
i
);
}
}
auto
out_var
=
scope
.
Var
(
"Out"
);
auto
out_tensor
=
out_var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
out_tensor
->
Resize
({
10
,
10
});
tensor
->
mutable_data
<
float
>
(
place
);
// allocate
out_
tensor
->
mutable_data
<
float
>
(
place
);
// allocate
}
void
AddOp
(
const
std
::
string
&
type
,
...
...
@@ -81,7 +86,7 @@ void StartServerNet() {
paddle
::
framework
::
ProgramDescBind
program
;
paddle
::
framework
::
BlockDescBind
*
block
=
program
.
MutableBlock
(
0
);
// X for server side tensors, RX for received tensers, must be of same shape.
AddOp
(
"sum"
,
{{
"X"
,
{
"
X"
,
"RX
"
}}},
{{
"Out"
,
{
"Out"
}}},
{},
block
);
AddOp
(
"sum"
,
{{
"X"
,
{
"
x0"
,
"x1
"
}}},
{{
"Out"
,
{
"Out"
}}},
{},
block
);
paddle
::
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"endpoint"
,
std
::
string
(
"127.0.0.1:6174"
)});
...
...
@@ -89,8 +94,8 @@ void StartServerNet() {
PADDLE_ENFORCE
(
program
.
Proto
()
->
SerializeToString
(
&
program_proto
));
attrs
.
insert
({
"OptimizeProgram"
,
program_proto
});
recv_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"recv"
,
{{
"RX"
,
{
"RX"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
recv_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"recv"
,
{{
"RX"
,
{
"x0"
,
"x1"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
paddle
::
platform
::
CPUDeviceContext
ctx
(
place
);
recv_op
->
Run
(
scope
,
ctx
);
}
...
...
@@ -107,11 +112,11 @@ TEST(SendRecvOp, CPU) {
attrs
.
insert
({
"endpoint"
,
std
::
string
(
"127.0.0.1:6174"
)});
auto
send_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"send"
,
{{
"X"
,
{
"
X
"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
"send"
,
{{
"X"
,
{
"
x0"
,
"x1
"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
paddle
::
platform
::
CPUDeviceContext
ctx
(
place
);
send_op
->
Run
(
scope
,
ctx
);
auto
in_var
=
scope
.
Var
(
"
X
"
);
auto
in_var
=
scope
.
Var
(
"
x0
"
);
auto
tensor
=
in_var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
float
*
expected
=
tensor
->
data
<
float
>
();
...
...
python/paddle/v2/fluid/tests/book/notest_recognize_digits_conv_dist.py
浏览文件 @
e805cfcb
...
...
@@ -39,14 +39,16 @@ train_reader = paddle.batch(
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
)
pserver_endpoints
=
os
.
getenv
(
"PSERVERS"
)
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
# get the training role: trainer/pserver
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
1
)
pserver_endpoint
=
os
.
getenv
(
"PSERVER"
)
if
pserver_endpoint
:
pserver_prog
=
t
.
get_pserver_program
(
pserver_endpoint
,
optimize_ops
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
pserver_endpoints
,
optimize_ops
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
pserver_prog
)
el
se
:
el
if
training_role
==
"TRAINER"
:
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
images
,
label
],
place
=
place
)
exe
.
run
(
fluid
.
default_startup_program
())
...
...
@@ -64,5 +66,7 @@ else:
pass_acc
=
accuracy
.
eval
(
exe
)
print
(
"pass_id="
+
str
(
pass_id
)
+
" pass_acc="
+
str
(
pass_acc
))
else
:
print
(
"environment var TRAINER_ROLE should be TRAINER os PSERVER"
)
exit
(
1
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录