Commit e9abc669 authored by Yancey1989

fix pe

Parent 952fa040
@@ -29,7 +29,9 @@ void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);
 
   this->RunAndRecordEvent([this] {
+    VLOG(3) << "begin run op type is " << op_->Type();
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
+    VLOG(3) << "end run op type is " << op_->Type();
   });
 }
...
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -79,9 +78,39 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
     CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
 
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send_vars" || op->Type() == "send") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "recv" || op->Type() == "send") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
     return false;
   }
@@ -89,21 +118,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
    * Check any of opvars contains `.block` and in sendvars
    */
   auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
       if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
     }
     return false;
   };
 
-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  if (op.Type() == "split" || op.Type() == "split_byref" ||
+      op.Type() == "split_selected_rows") {
+    return checker(op.OutputArgumentNames(), send_vars);
   } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+    return checker(op.InputArgumentNames(), recv_vars);
   }
   return false;
 }
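The checker above hinges on a naming convention: variables produced by splitting a parameter or gradient carry a ".block" suffix, and an op counts as a distributed-training op only when one of those names also appears in the send/recv variable lists. Below is a minimal standalone C++ sketch of that matching rule; it uses plain STL rather than the Paddle classes, and the variable names are made up for illustration.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Returns true if any name in op_vars both contains ".block" and appears in
// rpc_vars, mirroring the checker lambda in the hunk above.
bool MatchesRpcVars(const std::vector<std::string> &op_vars,
                    const std::vector<std::string> &rpc_vars) {
  for (const auto &var : op_vars) {
    if (var.find(".block") != std::string::npos &&
        std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names: a gradient split into blocks consumed by a send op.
  std::vector<std::string> split_outputs = {"w@GRAD.block0", "w@GRAD.block1"};
  std::vector<std::string> send_vars = {"w@GRAD.block0", "w@GRAD.block1"};
  std::cout << MatchesRpcVars(split_outputs, send_vars) << std::endl;  // 1
  return 0;
}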
@@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
           std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
           places_.size());
 
-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // find send/recv vars so that we can place the distributed training
+  // related op in the place 0
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);
 
   size_t cur_device_id = 0;
   std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       CreateRPCOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      // CreateComputationalOps(&result, *op, 1);
+      CreateComputationalOp(&result, *op, 0);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
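The effect of this branch is that RPC ops and the split/concat ops feeding them always land on place 0, while ordinary compute ops keep being replicated across every device. A small sketch of that placement decision follows, using assumed stand-in types rather than the real SSAGraph builder API.

#include <iostream>
#include <string>
#include <vector>

// Stand-in for an operator description; is_rpc covers send/recv/*_barrier,
// is_dist_train covers the split/concat ops detected by IsDistTrainOp.
struct Op {
  std::string type;
  bool is_rpc;
  bool is_dist_train;
};

int main() {
  const int num_devices = 2;
  std::vector<Op> program = {{"split_byref", false, true},
                             {"send", true, false},
                             {"elementwise_add", false, false}};
  for (const auto &op : program) {
    if (op.is_rpc || op.is_dist_train) {
      // distributed-training related ops are pinned to the first device
      std::cout << op.type << " -> place 0 only\n";
    } else {
      // everything else is replicated on every device
      std::cout << op.type << " -> replicated on " << num_devices
                << " places\n";
    }
  }
  return 0;
}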
@@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
 
   if (VLOG_IS_ON(10)) {
-    std::string filename = "/tmp/graph";
-    std::ofstream fout(filename);
-    PrintGraphviz(*graph, fout);
+    std::ostringstream sout;
+    PrintGraphviz(*graph, sout);
+    VLOG(10) << sout.str();
   }
 
   return std::unique_ptr<SSAGraph>(graph);
@@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
   }
   return nullptr;
 }
+
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
@@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
-                                        std::string op_name) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
   for (auto &prev_op : result->ops_) {
-    if (prev_op->Name() == op_name) {
+    if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
       result->dep_vars_.emplace(dep_var);
-      result->ops_.back().get()->AddInput(dep_var);
+      op->AddInput(dep_var);
     }
   }
 }
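ConnectOp now receives the dependent op explicitly: a DummyVarHandle becomes an output of every existing op named prev_op_name and an input of op, which enforces execution order without moving any real data. A simplified, self-contained model of that pattern is sketched below; the Op and Var types are stand-ins, not the actual OpHandleBase/DummyVarHandle classes.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Var {};  // stands in for DummyVarHandle

struct Op {
  std::string name;
  std::vector<Var *> inputs;
  std::vector<Var *> outputs;
};

// Make `later` depend on every op named `prev_name`: the dummy variable is an
// output of the earlier op and an input of the later one, so a scheduler must
// run them in that order even though no real data is exchanged.
void Connect(std::vector<std::unique_ptr<Op>> *ops, Op *later,
             const std::string &prev_name,
             std::vector<std::unique_ptr<Var>> *dep_vars) {
  for (auto &prev : *ops) {
    if (prev->name == prev_name) {
      dep_vars->emplace_back(new Var());
      prev->outputs.push_back(dep_vars->back().get());
      later->inputs.push_back(dep_vars->back().get());
    }
  }
}

int main() {
  std::vector<std::unique_ptr<Op>> ops;
  std::vector<std::unique_ptr<Var>> dep_vars;
  ops.emplace_back(new Op{"send_vars"});
  ops.emplace_back(new Op{"send_barrier"});
  Connect(&ops, ops.back().get(), "send_vars", &dep_vars);
  std::cout << "send_barrier control inputs: " << ops.back()->inputs.size()
            << std::endl;  // prints 1
  return 0;
}

CreateRPCOp in the next hunk uses this helper to chain the RPC ops: send_barrier waits on send_vars, recv waits on send_barrier, and fetch_barrier waits on recv.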
@@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
+  VLOG(3) << "create rpc op: " << op.Type();
   result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
 
   if (op.Type() == "send_barrier") {
-    ConnectOp(result, "send_vars");
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, "send_barrier");
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, "recv");
+    ConnectOp(result, result->ops_.back().get(), "recv");
   } else if (op.Type() == "send" || op.Type() == "send_vars") {
     // do nothing
   } else {
@@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   }
 
   // FIXME(wuyi): send op always copy from GPU 0
-  // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
...
@@ -64,17 +64,25 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
 
-  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
   void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
    */
-  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsDistTrainOp(const OpDesc &op,
+                     const std::vector<std::string> &send_vars,
+                     const std::vector<std::string> &recv_vars) const;
+
+  std::vector<std::string> FindDistTrainSendVars(
+      const ProgramDesc &program) const;
+
+  std::vector<std::string> FindDistTrainRecvVars(
+      const ProgramDesc &program) const;
 
   bool IsRPCOp(const OpDesc &op) const;
 
-  void ConnectOp(SSAGraph *result, std::string op_name) const;
+  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+                 const std::string &prev_op_name) const;
 
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
...
@@ -245,17 +245,11 @@ bool RPCClient::Proceed() {
 }
 
 std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
                                                      const std::string& key) {
-  VLOG(3) << "this addr: " << this;
   std::unique_lock<std::mutex> lock(mutex_);
   auto it = channels_.find(key);
   if (it != channels_.end()) {
-    VLOG(3) << "find ep: " << ep;
     return it->second;
   }
-  VLOG(3) << "can not find ep: " << ep;
-  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
-    VLOG(3) << "ep: " << it->first;
-  }
 
   grpc::ChannelArguments args;
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
...
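What remains of GetChannel after dropping the debug logs is a mutex-guarded channel cache: reuse the cached grpc::Channel when the key is already present, otherwise create a channel for the endpoint with compression disabled and remember it. A sketch of that pattern follows; the class and member names are illustrative, not the Paddle RPCClient.

#include <grpcpp/grpcpp.h>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

class ChannelCache {
 public:
  std::shared_ptr<grpc::Channel> Get(const std::string& ep) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = channels_.find(ep);
    if (it != channels_.end()) {
      return it->second;  // reuse the cached channel for this endpoint
    }
    // cache miss: build a channel with compression disabled, as above
    grpc::ChannelArguments args;
    args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
    auto ch = grpc::CreateCustomChannel(
        ep, grpc::InsecureChannelCredentials(), args);
    channels_[ep] = ch;
    return ch;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
};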
@@ -373,6 +373,16 @@ class DistributeTranspiler:
         for i, ep in enumerate(eplist):
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
 
+        # step4: Concat the parameters splits together after recv.
+        for varname, splited_var in param_var_mapping.iteritems():
+            if len(splited_var) <= 1:
+                continue
+            orig_param = program.global_block().vars[varname]
+            program.global_block().append_op(
+                type="concat",
+                inputs={"X": splited_var},
+                outputs={"Out": [orig_param]},
+                attrs={"axis": 0})
         # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
...