提交 142e832d 编写于 作者: X Xin Pan

pass registration

上级 5b183557
...@@ -34,30 +34,16 @@ namespace paddle { ...@@ -34,30 +34,16 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void MultiDevSSAGraphBuilder::Init() const {
loss_var_name_ = Get<std::string>("loss_var_name");
places_ = Get<std::vector<platform::Place>>("places");
local_scopes_ = Get<std::vector<Scope *>>("local_scopes");
strategy_ = Get<BuildStrategy>("strategy");
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs),
strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
strategy_(strategy) {
#endif #endif
for (auto &p : params) {
for (auto &p : Get<std::unordered_set<std::string>>("params")) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0); balance_vars_.resize(places_.size(), 0);
...@@ -241,6 +227,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { ...@@ -241,6 +227,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
Init();
// Give the topology sort order and rebuild the graph structure. // Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph); std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = graph->ReleaseNodes(); auto nodes = graph->ReleaseNodes();
......
...@@ -32,20 +32,6 @@ namespace details { ...@@ -32,20 +32,6 @@ namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder { class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public: public:
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs,
const BuildStrategy &strategy);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
...@@ -53,15 +39,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -53,15 +39,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private: private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const; size_t device_id) const;
void Init() const;
private: private:
std::string loss_var_name_; mutable std::string loss_var_name_;
const std::vector<platform::Place> &places_; mutable std::vector<platform::Place> places_;
const std::vector<Scope *> &local_scopes_; mutable std::vector<Scope *> local_scopes_;
std::unordered_set<std::string> grad_names_; mutable std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
...@@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::string> &var_names) const; const std::vector<std::string> &var_names) const;
private: private:
BuildStrategy strategy_; mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_; mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_; mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_; mutable std::vector<int64_t> balance_vars_;
......
...@@ -22,26 +22,29 @@ namespace paddle { ...@@ -22,26 +22,29 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() { std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res( std::unique_ptr<SSAGraphBuilder> res(new MultiDevSSAGraphBuilder);
res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
res->SetNotOwned<std::vector<Scope *>>("local_scopes", &local_scopes_);
res->SetNotOwned<BuildStrategy>("strategy", &strategy_);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, res->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nccl_ctxs_);
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
#endif #endif
); // NOLINT
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout( SSAGraphBuilder *previous_pass = res.release();
new std::ofstream(strategy_.debug_graphviz_path_)); res.reset(new SSAGraghBuilderWithPrinter);
PADDLE_ENFORCE(fout->good()); res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer( res->SetNotOwned<std::string>("debug_graphviz_path",
new GraphvizSSAGraphPrinter()); &strategy_.debug_graphviz_path_);
res.reset(new SSAGraghBuilderWithPrinter( res->Set<GraphvizSSAGraphPrinter>("graph_printer",
std::move(fout), std::move(graphviz_printer), std::move(res))); new GraphvizSSAGraphPrinter);
} }
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
SSAGraphBuilder *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithChecker);
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
return res; return res;
} }
......
...@@ -24,25 +24,19 @@ namespace details { ...@@ -24,25 +24,19 @@ namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder { class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public: public:
explicit SSAGraghBuilderWithChecker(
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); auto new_graph =
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get())); PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph; return new_graph;
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name); return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
} }
bool IsValidGraph(const ir::Graph* graph) const; bool IsValidGraph(const ir::Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
}; };
} // namespace details } // namespace details
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <fstream>
#include <iosfwd> #include <iosfwd>
#include <ostream>
#include <string> #include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
...@@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { ...@@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public: public:
SSAGraghBuilderWithPrinter(std::ostream& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ref_(sout) {}
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); auto new_graph =
printer_->Print(*new_graph, stream_ref_); Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>("debug_graphviz_path")));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
return new_graph; return new_graph;
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name); return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
} }
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
std::unique_ptr<std::ostream> stream_ptr_;
std::ostream& stream_ref_;
}; };
} // namespace details } // namespace details
......
...@@ -23,7 +23,8 @@ namespace ir { ...@@ -23,7 +23,8 @@ namespace ir {
std::unique_ptr<ir::Graph> GraphVizPass::Apply( std::unique_ptr<ir::Graph> GraphVizPass::Apply(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path_)); const std::string graph_viz_path = Get<std::string>("graph_viz_path");
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
...@@ -61,6 +62,9 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply( ...@@ -61,6 +62,9 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
sout << "}\n"; sout << "}\n";
return graph; return graph;
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass);
...@@ -29,14 +29,8 @@ namespace ir { ...@@ -29,14 +29,8 @@ namespace ir {
class GraphVizPass : public Pass { class GraphVizPass : public Pass {
public: public:
explicit GraphVizPass(const std::string& graph_viz_path)
: graph_viz_path_(graph_viz_path) {}
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
private:
const std::string graph_viz_path_;
}; };
} // namespace ir } // namespace ir
......
...@@ -15,5 +15,12 @@ limitations under the License. */ ...@@ -15,5 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
namespace ir {
// Returns the single PassRegistry object. It lives in a function-local
// static, so it is constructed the first time this accessor is called.
PassRegistry& PassRegistry::Instance() {
  static PassRegistry registry;
  return registry;
}
} // namespace ir
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,9 +14,14 @@ limitations under the License. */ ...@@ -14,9 +14,14 @@ limitations under the License. */
#pragma once #pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -25,10 +30,120 @@ namespace ir { ...@@ -25,10 +30,120 @@ namespace ir {
class Pass { class Pass {
public: public:
Pass() = default; Pass() = default;
virtual ~Pass() {} virtual ~Pass() {
for (auto &attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first]();
}
}
attrs_.clear();
attr_dels_.clear();
}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
// Returns a mutable reference to the attribute stored under `attr_name`.
//
// Attributes are kept as `AttrType *` inside a boost::any, so AttrType must
// be the same type that was used when the attribute was stored with
// Set/SetNotOwned; boost::any_cast throws on a mismatch.
// Requesting an attribute that was never set fails a PADDLE_ENFORCE that
// names the attribute instead of surfacing a bare std::out_of_range from
// std::map::at.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
  auto found = attrs_.find(attr_name);
  PADDLE_ENFORCE(found != attrs_.end(), "%s attr not registered for pass.",
                 attr_name);
  return *boost::any_cast<AttrType *>(found->second);
}

// Stores `attr` under `attr_name` and takes ownership of it: a deleter that
// calls `delete attr` is recorded in attr_dels_, which the destructor runs.
// An attribute name may only be set once.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
  PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
                 attr_name);
  attrs_[attr_name] = attr;
  // Capture the name by value so the log line is still valid at destruction.
  attr_dels_[attr_name] = [attr, attr_name]() {
    VLOG(3) << "deleting " << attr_name;
    delete attr;
  };
}

// Stores `attr` under `attr_name` without taking ownership: no deleter is
// recorded, so the caller must keep `*attr` alive for as long as the pass
// may read it. An attribute name may only be set once.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
  PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
                 attr_name);
  attrs_[attr_name] = attr;
}
private:
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
// Factory callable that takes no arguments and returns a newly created Pass.
// PassRegistry stores one of these per registered pass type name.
using PassCreator = std::function<std::unique_ptr<Pass>()>;
class Registrar {
public:
// In our design, each kind of pass has its corresponding registry and
// registrar. Registration happens in the constructor of a global registrar
// variable. Because that variable is not referenced by the code that calls
// into package framework, the linker would otherwise remove it from the
// generated binary. To avoid such removal, we add Touch to all registrar
// classes and make the USE_PASS macro call this method. So, as long as the
// calling code invokes USE_PASS, the global registrar variable won't be
// removed by the linker.
void Touch() {}
}; };
class PassRegistry {
public:
static PassRegistry &Instance();
bool Has(const std::string &pass_type) const {
return map_.find(pass_type) != map_.end();
}
void Insert(const std::string &type, const PassCreator &pass_creator) {
PADDLE_ENFORCE(!Has(type), "Pass %s has been registered", type);
map_.insert({type, pass_creator});
}
std::unique_ptr<Pass> Get(const std::string &type) const {
PADDLE_ENFORCE(Has(type), "Pass %s has not been registered", type);
return map_.at(type)();
}
private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> map_;
DISABLE_COPY_AND_ASSIGN(PassRegistry);
};
template <typename PassType>
struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> {
return std::unique_ptr<Pass>(new PassType());
});
}
};
// Guard that makes compilation fail unless the macro is expanded in the
// global namespace. It declares a marker struct named
// __test_global_namespace_<uniq_name>__ in the current scope and then
// statically asserts that the globally qualified name (`::...`) refers to
// that same type, which can only hold when the current scope is the global
// namespace. `msg` is the static_assert message.
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Registers `pass_class` under the name `pass_type` (stringized with #).
// Expands to:
//   1. a check that the macro is used in the global namespace;
//   2. a file-static PassRegistrar<pass_class> named
//      __pass_registrar_<pass_type>__, whose constructor performs the
//      registration;
//   3. a function TouchPassRegistrar_<pass_type>() that calls Touch() on
//      that registrar and returns 0.
// `pass_type` must be a valid identifier fragment because it is pasted into
// the generated names.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
}
// Creates a reference to the registrar of `pass_type` from the including
// translation unit. Expands to:
//   1. a check that the macro is used in the global namespace;
//   2. an extern declaration of TouchPassRegistrar_<pass_type>(), the
//      function that REGISTER_PASS defines for the same `pass_type`;
//   3. a file-static int initialized by calling that function, marked
//      __attribute__((unused)) since its value is never read.
// The expansion has no trailing semicolon; the user supplies it.
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -132,19 +132,27 @@ ParallelExecutor::ParallelExecutor( ...@@ -132,19 +132,27 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA."); PADDLE_THROW("Not compiled with CUDA.");
#endif #endif
} }
builder_ = builder_factory.Create();
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
if (!build_strategy.debug_graphviz_path_.empty()) { if (!build_strategy.debug_graphviz_path_.empty()) {
const std::string origin_graph_path = string::Sprintf( auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph"); "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph");
graph = ir::GraphVizPass(origin_graph_path).Apply(std::move(graph)); viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
} }
builder_ = builder_factory.Create();
graph = builder_->Apply(std::move(graph)); graph = builder_->Apply(std::move(graph));
if (!build_strategy.debug_graphviz_path_.empty()) { if (!build_strategy.debug_graphviz_path_.empty()) {
const std::string origin_graph_path = string::Sprintf( auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec"); "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec");
graph = ir::GraphVizPass(origin_graph_path).Apply(std::move(graph)); viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
} }
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
...@@ -297,3 +305,5 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -297,3 +305,5 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(graph_viz_pass);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册