Commit e4d7d7ae authored by Xin Pan

pass refactoring

Parent 142e832d
......@@ -244,6 +244,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
result.Set("vars", new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars);
result.Set("ops", new GraphOps);
result.Set("sharded_var_device", new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training
// related op on place 0
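
The core change in this hunk is that the variable-to-device mapping now lives on the graph as a named attribute instead of in the mutable builder member var_name_on_devices_. A minimal sketch of the write/read pattern, assuming the ShardedVarDevice typedef introduced later in this diff (the helper function names are illustrative, not part of the commit):

// Assumes: typedef std::unordered_map<std::string, int> ShardedVarDevice;
// The attribute must be registered once with
// result.Set("sharded_var_device", new ShardedVarDevice) before any Get<>.
void RecordVarDevice(ir::Graph *graph, const std::string &var, int dev_id) {
  graph->Get<ShardedVarDevice>("sharded_var_device").emplace(var, dev_id);
}
int LookUpVarDevice(const ir::Graph &graph, const std::string &var) {
  auto &sharded = graph.Get<ShardedVarDevice>("sharded_var_device");
  auto it = sharded.find(var);
  return it == sharded.end() ? -1 : it->second;  // -1: not pinned to one device
}
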
......@@ -276,11 +277,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// the block.
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(node);
int op_dev_id = GetOpDeviceID(result, node);
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
var_name_on_devices_.emplace(n->Name(), op_dev_id);
graph->Get<ShardedVarDevice>("sharded_var_device")
.emplace(n->Name(), op_dev_id);
}
} else {
// This op runs on all devices, and its output may have parameter's
......@@ -317,7 +319,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices_.emplace(g_name, cur_device_id);
graph->Get<ShardedVarDevice>("sharded_var_device")
.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
......@@ -499,7 +502,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once;
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
......@@ -512,15 +516,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]);
int dev_id = GetVarDeviceID(graph, param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
const std::string &varname) const {
auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
......@@ -625,20 +631,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0]);
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
result->Get<ShardedVarDevice>("sharded_var_device")
.emplace(varname, op_dev_id);
}
}
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
result->Get<ShardedVarDevice>("sharded_var_device")
.emplace(varname, op_dev_id);
}
} else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
result->Get<ShardedVarDevice>("sharded_var_device")
.emplace(varname, op_dev_id);
}
} else {
PADDLE_ENFORCE(
......@@ -663,7 +672,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
int op_dev_id = -1;
if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix.");
// the variable name which contains .block means it was split by
......@@ -678,7 +687,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
result->Get<ShardedVarDevice>("sharded_var_device")
.emplace(varname, op_dev_id);
}
}
} else if (node->Op()->Type() == "recv") {
......@@ -688,7 +698,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id);
result->Get<ShardedVarDevice>("sharded_var_device")
.emplace(varname, op_dev_id);
}
} else {
// send_barrier and fetch_barrier op can be scheduled on device 0
......@@ -730,3 +741,6 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_device_pass,
paddle::framework::details::MultiDevSSAGraphBuilder);
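
REGISTER_PASS makes the builder constructible by name from the pass registry; the matching USE_PASS declarations appear near the end of this diff in parallel_executor.cc. A hedged sketch of how a registered pass is obtained and applied, mirroring the ApplyParallelExecutorPass code further down:

// Look the pass up by the name given to REGISTER_PASS and run it on a graph.
auto pass = ir::PassRegistry::Instance().Get("multi_device_pass");
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
// Attributes such as "places" and "loss_var_name" must be set on the pass
// before Apply(), otherwise Get<> inside the pass will throw.
graph = pass->Apply(std::move(graph));
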
......@@ -34,7 +34,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
......@@ -51,6 +50,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
......@@ -84,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(ir::Node *node) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......@@ -102,7 +103,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle,
......
......@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
const ir::Graph& Graph() const { return underlying_executor_->Graph(); }
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
......
......@@ -47,13 +47,13 @@ typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
typedef std::unordered_map<std::string, int> ShardedVarDevice;
class SSAGraphBuilder : public ir::Pass {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected:
......
......@@ -21,8 +21,8 @@
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(new MultiDevSSAGraphBuilder);
std::unique_ptr<ir::Pass> ParallelExecutorPassManager::Create() {
std::unique_ptr<ir::Pass> res(new MultiDevSSAGraphBuilder);
res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
......@@ -33,18 +33,18 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
#endif
if (!strategy_.debug_graphviz_path_.empty()) {
SSAGraphBuilder *previous_pass = res.release();
ir::Pass *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithPrinter);
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
res->Set<ir::Pass>("previous_pass", previous_pass);
res->SetNotOwned<std::string>("debug_graphviz_path",
&strategy_.debug_graphviz_path_);
res->Set<GraphvizSSAGraphPrinter>("graph_printer",
new GraphvizSSAGraphPrinter);
}
SSAGraphBuilder *previous_pass = res.release();
ir::Pass *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithChecker);
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
res->Set<ir::Pass>("previous_pass", previous_pass);
return res;
}
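
Create() now returns a generic ir::Pass and builds a small decorator chain: the checker (and optionally the printer) wraps the multi-device builder through the owned "previous_pass" attribute. A condensed sketch of that chaining, omitting the places/loss_var_name/params wiring the real code also performs:

std::unique_ptr<ir::Pass> res(new MultiDevSSAGraphBuilder);
ir::Pass *previous_pass = res.release();             // ownership is handed over...
res.reset(new SSAGraghBuilderWithChecker);
res->Set<ir::Pass>("previous_pass", previous_pass);  // ...to the wrapping pass
// res->Apply(graph) first delegates to the builder, then validates the result
// (see SSAGraghBuilderWithChecker::Apply below).
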
......
......@@ -29,13 +29,13 @@ namespace framework {
class Scope;
namespace details {
class SSAGraphBuilderFactory {
class ParallelExecutorPassManager {
public:
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
ParallelExecutorPassManager(
const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes,
const BuildStrategy& strategy)
const std::vector<Scope*>& local_scopes, const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
......@@ -52,7 +52,7 @@ class SSAGraphBuilderFactory {
}
#endif
std::unique_ptr<SSAGraphBuilder> Create();
std::unique_ptr<ir::Pass> Create();
private:
std::vector<platform::Place> places_;
......
......@@ -85,3 +85,6 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker);
......@@ -26,16 +26,11 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph =
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
}
bool IsValidGraph(const ir::Graph* graph) const;
};
......
......@@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual ~SSAGraphExecutor();
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
virtual const ir::Graph& Graph() const = 0;
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
};
} // namespace details
} // namespace framework
......
......@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_device_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter);
......@@ -39,8 +39,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph =
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>("debug_graphviz_path")));
......@@ -48,10 +47,6 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
}
};
} // namespace details
......
......@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph);
const ir::Graph &Graph() const { return *graph_; }
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
......
......@@ -42,6 +42,8 @@ class Graph {
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for graph.", attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
......
......@@ -44,6 +44,8 @@ class Pass {
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for pass.", attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
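
Both ir::Graph and ir::Pass implement the same attribute pattern: a string-keyed map of type-erased pointers, recovered with a checked boost::any_cast. A self-contained sketch of that mechanism (the class and member names here are hypothetical, used only for illustration):

#include <boost/any.hpp>
#include <map>
#include <stdexcept>
#include <string>

class AttrHolder {   // illustrative only; not a Paddle class
 public:
  template <typename T>
  void Set(const std::string &name, T *attr) {
    attrs_[name] = attr;                       // stored type-erased, as T*
  }
  template <typename T>
  T &Get(const std::string &name) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end())
      throw std::runtime_error(name + " attr not registered.");
    return *boost::any_cast<T *>(it->second);  // throws bad_any_cast on mismatch
  }
 private:
  std::map<std::string, boost::any> attrs_;
};
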
......
......@@ -33,6 +33,48 @@ limitations under the License. */
namespace paddle {
namespace framework {
std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes, const bool use_cuda,
#ifdef PADDLE_WITH_CUDA
const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
#else
const BuildStrategy &strategy) {
#endif
details::ParallelExecutorPassManager builder_factory(
places, loss_var_name, param_names, local_scopes, strategy);
if (use_cuda) {
#ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(nccl_ctxs);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
auto builder = builder_factory.Create();
graph = builder->Apply(std::move(graph));
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_before_exec");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
return graph;
}
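
Note that the signature of ApplyParallelExecutorPass changes with the build: under PADDLE_WITH_CUDA it takes a trailing platform::NCCLContextMap* argument. A hypothetical CPU-only call site, matching the constructor change shown further down:

std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
    main_program, places, loss_var_name, params, local_scopes,
    /*use_cuda=*/false, build_strategy);
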
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
......@@ -120,38 +162,18 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = var->Persistable();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy);
if (member_->use_cuda_) {
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy,
member_->nccl_ctxs_.get());
#else
PADDLE_THROW("Not compiled with CUDA.");
std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy);
#endif
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
if (!build_strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
builder_ = builder_factory.Create();
graph = builder_->Apply(std::move(graph));
if (!build_strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
......@@ -165,11 +187,18 @@ void ParallelExecutor::BCastParamsToDevices(
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool initializing = builder_.get() == nullptr ? true : false;
bool initializing = member_->executor_ ? false : true;
for (auto &var : vars) {
int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
int var_dev_id = -1;
if (member_->executor_) {
auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>(
"sharded_var_device");
if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var);
}
}
if (!initializing && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr;
......@@ -307,3 +336,6 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace paddle
USE_PASS(graph_viz_pass);
USE_PASS(multi_device_pass);
USE_PASS(multi_device_check_pass);
USE_PASS(multi_device_print_pass);
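
The USE_PASS lines matter because pass registration happens via static initializers in each pass's .cc file; without a reference from this translation unit the linker may drop the registration objects. A rough sketch of the usual "touch symbol" idiom such macro pairs expand to (the exact macro bodies live in pass.h; the names below are illustrative):

// Emitted by REGISTER_PASS(multi_device_pass, ...) in the pass .cc file:
int TouchPassRegistrar_multi_device_pass() { return 0; }
// Emitted by USE_PASS(multi_device_pass) in the using .cc file:
extern int TouchPassRegistrar_multi_device_pass();
static int use_pass_multi_device_pass =
    TouchPassRegistrar_multi_device_pass();  // forces the registration to be linked in
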
......@@ -70,7 +70,6 @@ class ParallelExecutor {
private:
ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
};
} // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto3";
package sendrecv;
option cc_generic_services = false;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.
// It can be:
// LoDTensor
// SelectedRows
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// NOTICE(gongwb): don't modify this proto if you are not
// familiar with how we serialize it in sendrecvop_utils.h
// and deserialize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from 1 to 2.
int64 profile = 11;
}
message VoidMessage {}
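
For reference, a hedged sketch of filling a VariableMessage with the C++ classes protoc generates from this file (the include path, variable name, and dims are assumptions for illustration; the real payload is produced by sendrecvop_utils.h):

#include <string>
#include "send_recv.pb.h"  // assumed protoc output path

sendrecv::VariableMessage MakeLoDTensorMessage(const std::string &payload) {
  sendrecv::VariableMessage msg;
  msg.set_varname("fc_0.w_0");                         // hypothetical var name
  msg.set_type(sendrecv::LOD_TENSOR);
  msg.set_data_type(sendrecv::VariableMessage::FP32);
  msg.add_dims(784);
  msg.add_dims(10);
  msg.set_lod_level(0);
  msg.set_serialized(payload);                         // raw tensor bytes
  return msg;
}
// msg.SerializeToString(&wire) then yields the bytes sent over the RPC channel.
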