提交 6debbcd9 编写于 作者: Y Yancey1989

connect fetch barrier and concat op

上级 147d54ba
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
......@@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// always use the first device
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
// CreateComputationalOps(&result, *op, 1);
CreateComputationalOp(&result, *op, 0);
CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
......@@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
PrintGraphviz(*graph, sout);
VLOG(10) << sout.str();
std::ofstream fout("/tmp/graph.dot");
PrintGraphviz(*graph, fout);
}
return std::unique_ptr<SSAGraph>(graph);
......@@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}
}
// Builds the SSA-graph node for a distributed-training-specific operator.
// The op is placed on device 0 only (dist-train ops are not replicated
// across devices). For a "concat" op — which merges parameter splits
// received from parameter servers — an explicit dependency edge is added
// so that concat runs only after the "fetch_barrier" op has completed.
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
                                                const OpDesc &op) const {
  CreateComputationalOp(result, op, 0);
  if (op.Type() != "concat") {
    return;
  }
  // The op handle just appended above is the concat; chain it after
  // fetch_barrier so the splits are guaranteed to have arrived.
  ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0];
......
......@@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
/**
* Is this operator as the end-point operator before/after send operator.
......
......@@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
int sync_recv = Attr<int>("sync_recv");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......@@ -54,8 +55,10 @@ class RecvOp : public framework::OperatorBase {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
if (sync_recv) {
PADDLE_ENFORCE(rpc_client->Wait());
}
}
};
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -75,6 +78,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("sync_recv",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
}
};
......
......@@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase {
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
VLOG(3) << "client var addr: " << client_var;
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
VLOG(3) << "rpc_client addr: " << rpc_client;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......
......@@ -357,12 +357,35 @@ class DistributeTranspiler:
ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars)
#program.global_block().append_op(
# type="recv",
# inputs={},
# outputs={"Out": recv_vars,
# "RPCClient": rpc_client_var},
# attrs={"epmap": eplist})
#program.global_block().append_op(
# type="fetch_barrier",
# inputs={},
# outputs={"RPCClient": rpc_client_var},
# attrs={"endpoints": pserver_endpoints})
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index])
program.global_block().append_op(
type="recv",
inputs={},
outputs={"Out": recv_vars,
outputs={"Out": splited_var,
"RPCClient": rpc_client_var},
attrs={"epmap": eplist})
attrs={"epmap": eps})
program.global_block().append_op(
type="fetch_barrier",
......@@ -370,10 +393,6 @@ class DistributeTranspiler:
outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints})
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册