Commit e9abc669 authored by Yancey1989

fix pe

Parent 952fa040
@@ -29,7 +29,9 @@ void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);
   this->RunAndRecordEvent([this] {
+    VLOG(3) << "begin run op type is " << op_->Type();
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
+    VLOG(3) << "end run op type is " << op_->Type();
   });
 }
......
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -79,9 +78,39 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
     CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send_vars" || op->Type() == "send") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "recv" || op->Type() == "send") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
     return false;
   }
@@ -89,21 +118,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
    * Check any of opvars contains `.block` and in sendvars
    */
   auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
       if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
     }
     return false;
   };
-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  if (op.Type() == "split" || op.Type() == "split_byref" ||
+      op.Type() == "split_selected_rows") {
+    return checker(op.OutputArgumentNames(), send_vars);
   } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+    return checker(op.InputArgumentNames(), recv_vars);
   }
   return false;
 }
@@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // find send/recv vars so that we can place the distributed training
+  // related op on place 0
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);
   size_t cur_device_id = 0;
   std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       CreateRPCOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      // CreateComputationalOps(&result, *op, 1);
+      CreateComputationalOp(&result, *op, 0);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
   if (VLOG_IS_ON(10)) {
-    std::string filename = "/tmp/graph";
-    std::ofstream fout(filename);
-    PrintGraphviz(*graph, fout);
+    std::ostringstream sout;
+    PrintGraphviz(*graph, sout);
+    VLOG(10) << sout.str();
   }
   return std::unique_ptr<SSAGraph>(graph);
@@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
   }
   return nullptr;
 }
+
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
@@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
-                                        std::string op_name) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
   for (auto &prev_op : result->ops_) {
-    if (prev_op->Name() == op_name) {
+    if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
       result->dep_vars_.emplace(dep_var);
-      result->ops_.back().get()->AddInput(dep_var);
+      op->AddInput(dep_var);
     }
   }
 }
@@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
   VLOG(3) << "create rpc op: " << op.Type();
   result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   if (op.Type() == "send_barrier") {
-    ConnectOp(result, "send_vars");
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, "send_barrier");
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, "recv");
+    ConnectOp(result, result->ops_.back().get(), "recv");
   } else if (op.Type() == "send" || op.Type() == "send_vars") {
     // do nothing
   } else {
@@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   }
   // FIXME(wuyi): send op always copy from GPU 0
-  // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
......
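The key change above is that `IsDistTrainOp` no longer dereferences a single `send_op`; it asks whether a `split*`/`concat` op touches a `.block`-suffixed variable that also appears in the collected `send_vars`/`recv_vars`. A minimal standalone sketch of that name-matching rule (the function and variable names below are made up for illustration, not taken from the commit):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the checker lambda in IsDistTrainOp: an op belongs to the
// distributed-training path if any of its argument names contains ".block"
// and is also listed among the collected send/recv variable names.
bool TouchesRpcVars(const std::vector<std::string> &op_vars,
                    const std::vector<std::string> &rpc_vars) {
  for (const auto &var : op_vars) {
    if (var.find(".block") != std::string::npos &&
        std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names produced by splitting a parameter "fc_w" row-wise.
  std::vector<std::string> send_vars = {"fc_w.block0", "fc_w.block1"};
  std::vector<std::string> split_outputs = {"fc_w.block0", "fc_w.block1"};
  std::vector<std::string> plain_outputs = {"fc_out"};

  std::cout << TouchesRpcVars(split_outputs, send_vars) << "\n";  // prints 1
  std::cout << TouchesRpcVars(plain_outputs, send_vars) << "\n";  // prints 0
  return 0;
}
```

An op that passes this check is then created on place 0 only, via `CreateComputationalOp(&result, *op, 0)`, so it stays on the same device as the RPC ops.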
@@ -64,17 +64,25 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
   void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
   /**
    * Is this operator as the end-point operator before/after send operator.
    */
-  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsDistTrainOp(const OpDesc &op,
+                     const std::vector<std::string> &send_vars,
+                     const std::vector<std::string> &recv_vars) const;
+  std::vector<std::string> FindDistTrainSendVars(
+      const ProgramDesc &program) const;
+  std::vector<std::string> FindDistTrainRecvVars(
+      const ProgramDesc &program) const;
   bool IsRPCOp(const OpDesc &op) const;
-  void ConnectOp(SSAGraph *result, std::string op_name) const;
+  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+                 const std::string &prev_op_name) const;
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
......
@@ -245,17 +245,11 @@ bool RPCClient::Proceed() {
 }
 std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
                                                      const std::string& key) {
-  VLOG(3) << "this addr: " << this;
   std::unique_lock<std::mutex> lock(mutex_);
   auto it = channels_.find(key);
   if (it != channels_.end()) {
-    VLOG(3) << "find ep: " << ep;
     return it->second;
   }
-  VLOG(3) << "can not find ep: " << ep;
-  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
-    VLOG(3) << "ep: " << it->first;
-  }
   grpc::ChannelArguments args;
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
......
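With the debug VLOGs removed, `RPCClient::GetChannel` reduces to a mutex-guarded, per-key channel cache. A simplified, generic sketch of that pattern (the `Channel` and `ChannelCache` types below are placeholders, not the real gRPC client):

```cpp
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Stand-in for grpc::Channel; this sketch only shows the caching pattern.
struct Channel {
  explicit Channel(std::string endpoint) : ep(std::move(endpoint)) {}
  std::string ep;
};

class ChannelCache {
 public:
  // Return the cached channel for `key`; create one for `ep` on a miss.
  std::shared_ptr<Channel> GetChannel(const std::string &ep,
                                      const std::string &key) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = channels_.find(key);
    if (it != channels_.end()) {
      return it->second;
    }
    auto ch = std::make_shared<Channel>(ep);
    channels_[key] = ch;
    return ch;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<Channel>> channels_;
};

int main() {
  ChannelCache cache;
  auto a = cache.GetChannel("127.0.0.1:6174", "127.0.0.1:6174");
  auto b = cache.GetChannel("127.0.0.1:6174", "127.0.0.1:6174");
  return a == b ? 0 : 1;  // the second call reuses the cached channel
}
```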
@@ -373,6 +373,16 @@ class DistributeTranspiler:
         for i, ep in enumerate(eplist):
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
+        # step4: Concat the parameters splits together after recv.
+        for varname, splited_var in param_var_mapping.iteritems():
+            if len(splited_var) <= 1:
+                continue
+            orig_param = program.global_block().vars[varname]
+            program.global_block().append_op(
+                type="concat",
+                inputs={"X": splited_var},
+                outputs={"Out": [orig_param]},
+                attrs={"axis": 0})
         # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
......
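Step 4 in the transpiler change reassembles each parameter that was split row-wise across parameter servers by concatenating the received pieces along axis 0. A toy sketch of that axis-0 reassembly using plain C++ containers rather than the fluid `concat` op (shapes and names are illustrative):

```cpp
#include <iostream>
#include <vector>

using Rows = std::vector<std::vector<float>>;

// Axis-0 concatenation: append the row-blocks in order, which restores a
// parameter that was split row-wise across parameter servers.
Rows ConcatAxis0(const std::vector<Rows> &blocks) {
  Rows out;
  for (const auto &block : blocks) {
    out.insert(out.end(), block.begin(), block.end());
  }
  return out;
}

int main() {
  // Two received pieces of a hypothetical 4x2 parameter.
  Rows block0 = {{1.f, 2.f}, {3.f, 4.f}};
  Rows block1 = {{5.f, 6.f}, {7.f, 8.f}};
  Rows param = ConcatAxis0({block0, block1});
  std::cout << param.size() << " rows\n";  // 4 rows: original shape restored
  return 0;
}
```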