Commit 627d7a64 (unverified)
Clean `sendop` `recv` operator. (#11309)

Authored June 11, 2018 by gongweibao; committed via GitHub on June 11, 2018.
Parent: fa29ef0b

Showing 8 changed files with 39 additions and 162 deletions (+39, -162):
paddle/fluid/framework/details/multi_devices_graph_builder.cc        +4   -4
paddle/fluid/operators/CMakeLists.txt                                +4   -7
paddle/fluid/operators/recv_op.cc                                    +7   -1
paddle/fluid/operators/send_op.cc                                    +12  -38
paddle/fluid/operators/send_vars_op.cc                               +0   -101
python/paddle/fluid/tests/unittests/test_dist_transpiler.py          +3   -2
python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py   +2   -2
python/paddle/fluid/transpiler/distribute_transpiler.py              +7   -7
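At a glance: the commit deletes the old send_op (which both pushed gradients and pulled updated parameters inside one op), renames the leaner send_vars op to send, and updates every hard-coded reference to the old names. The trainer-side op sequence asserted by the transpiler tests below changes only by the rename; a small Python rendering of that before/after, taken from test_dist_transpiler.py:

ops_before = ["split_byref", "send_vars", "send_barrier", "recv", "recv",
              "fetch_barrier", "concat"]
ops_after  = ["split_byref", "send", "send_barrier", "recv", "recv",
              "fetch_barrier", "concat"]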
paddle/fluid/framework/details/multi_devices_graph_builder.cc

@@ -89,7 +89,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
   for (auto *op : program.Block(0).AllOps()) {
     // TODO(Yancey1989): use a graceful method to find send op,
     // instead of the the hard code string
-    if (op->Type() == "send_vars") {
+    if (op->Type() == "send") {
       auto op_vars = op->InputArgumentNames();
       send_vars.reserve(send_vars.size() +
                         std::distance(op_vars.begin(), op_vars.end()));
@@ -468,17 +468,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
       new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));

   if (op.Type() == "send_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "send_vars");
+    ConnectOp(result, result->ops_.back().get(), "send");
   } else if (op.Type() == "recv") {
     ConnectOp(result, result->ops_.back().get(), "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
     ConnectOp(result, result->ops_.back().get(), "recv");
-  } else if (op.Type() == "send_vars") {
+  } else if (op.Type() == "send") {
     // do nothing
   } else {
     PADDLE_THROW(
         "rpc op should be in ["
-        "send_vars, send_barrier. recv, fetch_barrier]");
+        "send, send_barrier. recv, fetch_barrier]");
   }

   // TODO(Yancey1989): schedule rpc op on different place may
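CreateRPCOp encodes a fixed happens-before chain among the trainer's RPC ops: every send completes before the send_barrier, the barrier before any recv, and every recv before the fetch_barrier. A small illustrative Python sketch of that wiring rule (ConnectOp itself is C++ and operates on SSA-graph op handles; the names below are stand-ins, not Paddle API):

# Illustrative sketch of the dependency chain CreateRPCOp enforces.
# "send" (renamed from "send_vars" in this commit) heads the chain.
PREDECESSOR = {
    "send": None,
    "send_barrier": "send",
    "recv": "send_barrier",
    "fetch_barrier": "recv",
}

def connect_rpc_op(ops, op_type):
    """Append op_type, depending on all prior ops of its predecessor type."""
    if op_type not in PREDECESSOR:
        raise ValueError(
            "rpc op should be in [send, send_barrier, recv, fetch_barrier]")
    prev = PREDECESSOR[op_type]
    deps = [o for o in ops if o["type"] == prev] if prev is not None else []
    ops.append({"type": op_type, "deps": deps})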
paddle/fluid/operators/CMakeLists.txt

@@ -189,16 +189,14 @@ if(WITH_DISTRIBUTE)
     set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
     set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
-    op_library(send_op DEPS ${DISTRIBUTE_DEPS})
-    set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(recv_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-    op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
-    set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    op_library(send_op DEPS ${DISTRIBUTE_DEPS})
+    set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
     op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -208,15 +206,14 @@ if(WITH_DISTRIBUTE)
     #             listen_and_serv_op sum_op executor SERIAL)
     if(WITH_GPU)
         set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-        cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op
-                listen_and_serv_op executor SERIAL)
+        cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL)
         op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc)
         set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     else()
         set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
     endif()
 else()
-    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
+    set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
 endif()
 op_library(cross_entropy_op DEPS cross_entropy)
paddle/fluid/operators/recv_op.cc

@@ -78,9 +78,15 @@ This operator can get variables from server side.
   }
 };

+class RecvOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override {}
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;

-REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker);
+REGISTER_OPERATOR(recv, ops::RecvOp, paddle::framework::EmptyGradOpMaker,
+                  ops::RecvOpMaker, ops::RecvOpShapeInference);
paddle/fluid/operators/send_op.cc

@@ -16,7 +16,6 @@ limitations under the License. */
 #include <ostream>

 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
@@ -36,12 +35,9 @@ class SendOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
     auto ins = Inputs("X");
-    auto outs = Outputs("Out");
-    std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
-    std::vector<std::string> endpoints =
-        Attr<std::vector<std::string>>("endpoints");
-    bool sync_mode = Attr<bool>("sync_mode");
+    std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    int sync_send = Attr<int>("sync_mode");

     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
@@ -55,32 +51,14 @@ class SendOp : public framework::OperatorBase {
     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {
         VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
+        // TODO(Yancey1989): we need to use an IO threadpool which has
+        // a larger number of threads than the computing threadpool.
         rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
       } else {
         VLOG(3) << "don't send no-initialied variable: " << ins[i];
       }
     }
-    rpc_client->Wait();
-
-    if (sync_mode) {
-      for (auto& ep : endpoints) {
-        VLOG(3) << "batch barrier, ep: " << ep;
-        rpc_client->AsyncSendBatchBarrier(ep);
-      }
-      rpc_client->Wait();
-    }
-
-    if (outs.size() > 0) {
-      for (size_t i = 0; i < outs.size(); i++) {
-        VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
-        rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
-      }
-      rpc_client->Wait();
-      // tell pservers that current trainer have called fetch
-      for (auto& ep : endpoints) {
-        VLOG(2) << "send fetch barrier, ep: " << ep;
-        rpc_client->AsyncSendFetchBarrier(ep);
-      }
-      rpc_client->Wait();
-    }
+    if (sync_send) {
+      rpc_client->Wait();
+    }
   }
@@ -89,26 +67,22 @@ class SendOp : public framework::OperatorBase {
 class SendOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
-    AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
-    AddOutput("Out", "(Tensor) Output tensor to be received from server")
-        .AsDuplicable();
+    AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
+        .AsDuplicable();
     AddComment(R"DOC(
 Send operator

-This operator will send tensor to recv_op at the parameter server.
+This operator will send variables to listen_and_serve op at the parameter server.
 )DOC");
-    // TODO(typhoonzero): remove this attr generate de-duplicated vector from
-    // epmap when initializing.
-    AddAttr<std::vector<std::string>>("endpoints",
-                                      "(string vector, default 127.0.0.1:6164)"
-                                      "Server endpoints to send variables to.")
-        .SetDefault({});
+    AddAttr<int>("sync_mode",
+                 "(int, default 0)"
+                 "sync send or async send.")
+        .SetDefault(0);
     AddAttr<std::vector<std::string>>("epmap",
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints in the order of input "
                                       "variables for mapping")
-        .SetDefault({});
-    AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
+        .SetDefault({"127.0.0.1:6164"});
   }
 };
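With the barrier and fetch logic gone, the rewritten SendOp (essentially the body of the old send_vars op) is fire-and-forget plus an optional wait; batch barriers, parameter fetching, and fetch barriers now belong exclusively to the separate send_barrier, recv, and fetch_barrier ops. A rough Python rendering of the new RunImpl (the names are illustrative stand-ins for the C++ calls, not a real API):

def run_send(rpc_client, initialized_vars, ins, epmap, sync_send):
    """Sketch of SendOp::RunImpl after this commit."""
    for var_name, endpoint in zip(ins, epmap):
        if var_name in initialized_vars:   # NeedSend(): skip uninitialized vars
            rpc_client.async_send_var(endpoint, var_name)
    if sync_send:                          # int "sync_mode" attribute, default 0
        rpc_client.wait()                  # block until all async sends finish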
paddle/fluid/operators/send_vars_op.cc (deleted; old mode 100644, previously at fa29ef0b)

The entire file is removed; its contents, now folded into send_op.cc, were:

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future>  // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

class SendVarsOp : public framework::OperatorBase {
 public:
  SendVarsOp(const std::string& type, const framework::VariableNameMap& inputs,
             const framework::VariableNameMap& outputs,
             const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    auto ins = Inputs("X");

    std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
    int sync_send = Attr<int>("sync_send");

    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto& ctx = *pool.Get(place);

    // For profiling
    platform::RecordEvent record_event(Type(), &ctx);

    detail::RPCClient* rpc_client =
        detail::RPCClient::GetInstance<detail::GRPCClient>();

    for (size_t i = 0; i < ins.size(); i++) {
      if (NeedSend(scope, ins[i])) {
        VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
        // TODO(Yancey1989): we need to use an IO threadpool which has
        // a larger number of threads than the computing threadpool.
        rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
      } else {
        VLOG(3) << "don't send no-initialied variable: " << ins[i];
      }
    }
    if (sync_send) {
      rpc_client->Wait();
    }
  }
};

class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
        .AsDuplicable();
    AddComment(R"DOC(
Send operator

This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
    AddAttr<int>("sync_send",
                 "(int, default 0)"
                 "sync send or async send.")
        .SetDefault(0);
    AddAttr<std::vector<std::string>>("epmap",
                                      "(string vector, default 127.0.0.1:6164)"
                                      "Server endpoints in the order of input "
                                      "variables for mapping")
        .SetDefault({"127.0.0.1:6164"});
  }
};

class SendVarsOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
                  paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
                  ops::SendVarsOpShapeInference);
python/paddle/fluid/tests/unittests/test_dist_transpiler.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import unittest
 import paddle.fluid as fluid
 from paddle.fluid.transpiler.distribute_transpiler import delete_ops
@@ -54,10 +55,10 @@ class TestDistTranspiler(TranspilerTest):
         delete_ops(trainer.global_block(), optimize_ops)
         ops = [op.type for op in trainer.global_block().ops] + [
-            "split_byref", "send_vars", "send_barrier", "recv", "recv",
+            "split_byref", "send", "send_barrier", "recv", "recv",
             "fetch_barrier", "concat"
         ]
-        ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
+        ops.insert(ops.index("elementwise_add_grad") + 1, "send")
         return ops
python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py

@@ -59,9 +59,9 @@ class TestSimpleDistTranspiler(TranspilerTest):
         delete_ops(trainer.global_block(), optimize_ops)
         ops = [op.type for op in trainer.global_block().ops] + [
-            "send_vars", "send_barrier", "recv", "recv", "fetch_barrier"
+            "send", "send_barrier", "recv", "recv", "fetch_barrier"
         ]
-        ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
+        ops.insert(ops.index("elementwise_add_grad") + 1, "send")
         return ops

     def _transpiler_instance(self):
python/paddle/fluid/transpiler/distribute_transpiler.py

@@ -24,9 +24,9 @@ Steps to transpile trainer:
 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
 3. modify trainer program add split_op to each grad variable.
-4. append send_op to send splited variables to server and fetch
-   params(splited blocks or origin param) from server.
-5. append concat_op to merge splited blocks to update local weights.
+4. append send_op to send splited variables to server and
+5. add recv_op to fetch params(splited blocks or origin param) from server.
+6. append concat_op to merge splited blocks to update local weights.

 Steps to transpile pserver:
 1. create new program for parameter server.
@@ -317,7 +317,7 @@ class DistributeTranspiler:
             program.global_block().insert_op(
                 index=index + 1,
-                type="send_vars",
+                type="send",
                 inputs={"X": splited_vars},
                 outputs={},
                 attrs={
@@ -678,7 +678,7 @@ class DistributeTranspiler:
                 break

     def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
-        # 2. add split_ids_op and send_vars_op to send gradient to pservers
+        # 2. add split_ids_op and send_op to send gradient to pservers
         # there should only be one table_name
         all_ops = program.global_block().ops
         table_grad_name = grad_var_name(self.table_name)
@@ -695,11 +695,11 @@ class DistributeTranspiler:
             outputs={"Out": self.trainer_side_table_grad_list})
         program.global_block().insert_op(
             index=op_index + 2,
-            type="send_vars",
+            type="send",
             inputs={'X': self.trainer_side_table_grad_list},
             outputs={},
             attrs={
-                "sync_send": True,
+                "sync_mode": True,
                 "epmap": pserver_endpoints,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
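For context, a trainer-side transpile in the fluid API of this era looks roughly like the following (a hedged sketch: it assumes a network and optimizer have already been built in the default main program, and the endpoint strings are placeholders); afterwards the trainer program contains the send/send_barrier/recv/fetch_barrier sequence asserted by the tests above:

import paddle.fluid as fluid

# Hedged usage sketch; assumes a model and optimizer already exist in the
# default main program.
t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0,
            pservers="127.0.0.1:6174,127.0.0.1:6175",
            trainers=2)
trainer_prog = t.get_trainer_program()
print([op.type for op in trainer_prog.global_block().ops])  # includes "send"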