未验证 提交 ff9b1a0f 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #11234 from reyoung/feature/refine_code

SSA Graph Builder Factory
...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope ...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method) framework_proto glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place ...@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
...@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d ...@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
void RunImpl() override; void RunImpl() override;
private: private:
const std::vector<Scope *> &local_scopes_; std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> &places_; std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_; const platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -29,6 +31,8 @@ struct BuildStrategy { ...@@ -29,6 +31,8 @@ struct BuildStrategy {
ReduceStrategy reduce_{ReduceStrategy::kAllReduce}; ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
std::string debug_graphviz_path_{""};
}; };
} // namespace details } // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(
#ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
#endif
); // NOLINT
if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout(
new std::ofstream(strategy_.debug_graphviz_path_));
PADDLE_ENFORCE(fout->good());
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
new GraphvizSSAGraphPrinter());
res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res)));
}
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
class Scope;
namespace details {
class SSAGraphBuilderFactory {
public:
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes,
const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
strategy_(strategy) {}
#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
nccl_ctxs_ = nccl_ctxs;
}
#endif
std::unique_ptr<SSAGraphBuilder> Create();
private:
std::vector<platform::Place> places_;
std::string loss_var_name_;
std::unordered_set<std::string> param_names_;
std::vector<Scope*> local_scopes_;
BuildStrategy strategy_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap* nccl_ctxs_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -30,10 +30,6 @@ ...@@ -30,10 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif #endif
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) {
std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, fout);
}
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
......
...@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
void RunImpl() override; void RunImpl() override;
private: private:
const std::vector<Scope *> &local_scopes_; std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> &places_; std::vector<platform::Place> places_;
const platform::NCCLContextMap &nccl_ctxs_; const platform::NCCLContextMap &nccl_ctxs_;
}; };
......
...@@ -32,8 +32,8 @@ namespace framework { ...@@ -32,8 +32,8 @@ namespace framework {
namespace details { namespace details {
struct ReduceOpHandle : public OpHandleBase { struct ReduceOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_; std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> &places_; std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_; const platform::NCCLContextMap *nccl_ctxs_;
......
...@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, ...@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
template <typename Callback>
void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
}
}
}
for (auto &var : graph.dep_vars_) {
callback(*var);
}
}
void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->Inputs()) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}
for (auto out : op->Outputs()) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}
sout << "}\n";
}
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) { for (auto &op : graph->ops_) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
......
...@@ -55,8 +55,6 @@ class SSAGraphBuilder { ...@@ -55,8 +55,6 @@ class SSAGraphBuilder {
const platform::Place &place, size_t place_offset); const platform::Place &place, size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph); static void AddOutputToLeafOps(SSAGraph *graph);
static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
namespace paddle {
namespace framework {
namespace details {
template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
}
}
}
for (auto &var : graph.dep_vars_) {
callback(*var);
}
}
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->Inputs()) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}
for (auto out : op->Outputs()) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}
sout << "}\n";
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iosfwd>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
class SSAGraph;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public:
void Print(const SSAGraph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
SSAGraghBuilderWithPrinter(std::ostream& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ref_(sout) {}
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_);
return graph;
}
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
std::unique_ptr<std::ostream> stream_ptr_;
std::ostream& stream_ref_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/graph_builder_factory.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -102,22 +102,19 @@ ParallelExecutor::ParallelExecutor( ...@@ -102,22 +102,19 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = var->Persistable(); var_infos.back().persistable_ = var->Persistable();
} }
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), build_strategy); build_strategy);
#else #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
params, member_->local_scopes_,
build_strategy);
#endif #endif
auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
......
...@@ -553,6 +553,12 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -553,6 +553,12 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, [](BuildStrategy &self,
BuildStrategy::GradientScaleStrategy strategy) { BuildStrategy::GradientScaleStrategy strategy) {
self.gradient_scale_ = strategy; self.gradient_scale_ = strategy;
})
.def_property(
"debug_graphviz_path",
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
self.debug_graphviz_path_ = path;
}); });
pe.def(py::init<const std::vector<platform::Place> &, pe.def(py::init<const std::vector<platform::Place> &,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册