提交 6debbcd9 编写于 作者: Y Yancey1989

connect fetch barrier and concat op

上级 147d54ba
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility> #include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
...@@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// always use the first device // always use the first device
CreateRPCOp(&result, *op); CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
// CreateComputationalOps(&result, *op, 1); CreateDistTrainOp(&result, *op);
CreateComputationalOp(&result, *op, 0);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
...@@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
std::ostringstream sout; std::ofstream fout("/tmp/graph.dot");
PrintGraphviz(*graph, sout); PrintGraphviz(*graph, fout);
VLOG(10) << sout.str();
} }
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
...@@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, ...@@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
} }
} }
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
auto &p = places_[0]; auto &p = places_[0];
......
...@@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
......
...@@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase {
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient"); auto client_var_name = Output("RPCClient");
int sync_recv = Attr<int>("sync_recv");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
...@@ -54,8 +55,10 @@ class RecvOp : public framework::OperatorBase { ...@@ -54,8 +55,10 @@ class RecvOp : public framework::OperatorBase {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
if (sync_recv) {
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
} }
}
}; };
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -75,6 +78,10 @@ This operator can get variables from server side. ...@@ -75,6 +78,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("sync_recv",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
} }
}; };
......
...@@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase { ...@@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase {
"Can not find variable '%s' in the scope.", "Can not find variable '%s' in the scope.",
client_var_name); client_var_name);
auto* client_var = scope.FindVar(client_var_name); auto* client_var = scope.FindVar(client_var_name);
VLOG(3) << "client var addr: " << client_var;
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>(); detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
VLOG(3) << "rpc_client addr: " << rpc_client;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -357,12 +357,35 @@ class DistributeTranspiler: ...@@ -357,12 +357,35 @@ class DistributeTranspiler:
ps_dispatcher.reset() ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars) eplist = ps_dispatcher.dispatch(recv_vars)
#program.global_block().append_op(
# type="recv",
# inputs={},
# outputs={"Out": recv_vars,
# "RPCClient": rpc_client_var},
# attrs={"epmap": eplist})
#program.global_block().append_op(
# type="fetch_barrier",
# inputs={},
# outputs={"RPCClient": rpc_client_var},
# attrs={"endpoints": pserver_endpoints})
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index])
program.global_block().append_op( program.global_block().append_op(
type="recv", type="recv",
inputs={}, inputs={},
outputs={"Out": recv_vars, outputs={"Out": splited_var,
"RPCClient": rpc_client_var}, "RPCClient": rpc_client_var},
attrs={"epmap": eplist}) attrs={"epmap": eps})
program.global_block().append_op( program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
...@@ -370,10 +393,6 @@ class DistributeTranspiler: ...@@ -370,10 +393,6 @@ class DistributeTranspiler:
outputs={"RPCClient": rpc_client_var}, outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints}) attrs={"endpoints": pserver_endpoints})
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册