提交 142e832d 编写于 作者: X Xin Pan

pass registration

上级 5b183557
......@@ -34,30 +34,16 @@ namespace paddle {
namespace framework {
namespace details {
void MultiDevSSAGraphBuilder::Init() const {
loss_var_name_ = Get<std::string>("loss_var_name");
places_ = Get<std::vector<platform::Place>>("places");
local_scopes_ = Get<std::vector<Scope *>>("local_scopes");
strategy_ = Get<BuildStrategy>("strategy");
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs),
strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
strategy_(strategy) {
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif
for (auto &p : params) {
for (auto &p : Get<std::unordered_set<std::string>>("params")) {
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0);
......@@ -241,6 +227,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<ir::Graph> graph) const {
Init();
// Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = graph->ReleaseNodes();
......
......@@ -32,20 +32,6 @@ namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs,
const BuildStrategy &strategy);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
......@@ -53,15 +39,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
void Init() const;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
std::unordered_set<std::string> grad_names_;
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
bool IsScaleLossOp(ir::Node *node) const;
......@@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::string> &var_names) const;
private:
BuildStrategy strategy_;
mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
......
......@@ -22,26 +22,29 @@ namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(
std::unique_ptr<SSAGraphBuilder> res(new MultiDevSSAGraphBuilder);
res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
res->SetNotOwned<std::vector<Scope *>>("local_scopes", &local_scopes_);
res->SetNotOwned<BuildStrategy>("strategy", &strategy_);
#ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
res->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nccl_ctxs_);
#endif
); // NOLINT
if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout(
new std::ofstream(strategy_.debug_graphviz_path_));
PADDLE_ENFORCE(fout->good());
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
new GraphvizSSAGraphPrinter());
res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res)));
SSAGraphBuilder *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithPrinter);
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
res->SetNotOwned<std::string>("debug_graphviz_path",
&strategy_.debug_graphviz_path_);
res->Set<GraphvizSSAGraphPrinter>("graph_printer",
new GraphvizSSAGraphPrinter);
}
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
SSAGraphBuilder *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithChecker);
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
return res;
}
......
......@@ -24,25 +24,19 @@ namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
explicit SSAGraghBuilderWithChecker(
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
auto new_graph =
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
}
bool IsValidGraph(const ir::Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
};
} // namespace details
......
......@@ -14,7 +14,9 @@
#pragma once
#include <fstream>
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
......@@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
SSAGraghBuilderWithPrinter(std::ostream& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ref_(sout) {}
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
auto new_graph =
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>("debug_graphviz_path")));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
}
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
std::unique_ptr<std::ostream> stream_ptr_;
std::ostream& stream_ref_;
};
} // namespace details
......
......@@ -23,7 +23,8 @@ namespace ir {
std::unique_ptr<ir::Graph> GraphVizPass::Apply(
std::unique_ptr<ir::Graph> graph) const {
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path_));
const std::string graph_viz_path = Get<std::string>("graph_viz_path");
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout;
......@@ -61,6 +62,9 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
sout << "}\n";
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass);
......@@ -29,14 +29,8 @@ namespace ir {
class GraphVizPass : public Pass {
public:
explicit GraphVizPass(const std::string& graph_viz_path)
: graph_viz_path_(graph_viz_path) {}
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override;
private:
const std::string graph_viz_path_;
};
} // namespace ir
......
......@@ -15,5 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {} // namespace framework
namespace framework {
namespace ir {
// Returns the single PassRegistry shared by the whole program. The registry is
// a function-local static, so it is constructed the first time this function
// is called and the same object is returned on every call.
PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,9 +14,14 @@ limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
......@@ -25,10 +30,120 @@ namespace ir {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual ~Pass() {
for (auto &attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first]();
}
}
attrs_.clear();
attr_dels_.clear();
}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
}
private:
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
using PassCreator = std::function<std::unique_ptr<Pass>()>;
// Common base class of the registrar helpers used for pass registration.
class Registrar {
public:
// Each kind of pass has its own registry and registrar. Registration happens
// in the constructor of a global registrar variable. Because that variable is
// not referenced by the code that uses the framework, the linker may remove
// it (and with it the registration) from the generated binary. To prevent
// that, every registrar class provides Touch() and the USE_PASS macro calls
// it: as long as the using code invokes USE_PASS, the global registrar
// variable stays referenced and is not removed by the linker.
void Touch() {}
};
class PassRegistry {
public:
static PassRegistry &Instance();
bool Has(const std::string &pass_type) const {
return map_.find(pass_type) != map_.end();
}
void Insert(const std::string &type, const PassCreator &pass_creator) {
PADDLE_ENFORCE(!Has(type), "Pass %s has been registered", type);
map_.insert({type, pass_creator});
}
std::unique_ptr<Pass> Get(const std::string &type) const {
PADDLE_ENFORCE(Has(type), "Pass %s has not been registered", type);
return map_.at(type)();
}
private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> map_;
DISABLE_COPY_AND_ASSIGN(PassRegistry);
};
template <typename PassType>
struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> {
return std::unique_ptr<Pass>(new PassType());
});
}
};
// Guard that a macro is expanded in the global namespace. It declares a marker
// struct whose name is derived from uniq_name in the current scope and then
// refers to that struct both unqualified and through the global-scope
// qualifier `::`. Both names denote the same type only when the expansion
// happens in the global namespace. `msg` is the static_assert message.
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Registers pass_class under the name #pass_type. Must be used in the global
// namespace. The expansion:
//   1. checks the global-namespace requirement,
//   2. defines a file-static PassRegistrar<pass_class>, whose constructor
//      inserts the pass into PassRegistry,
//   3. defines TouchPassRegistrar_<pass_type>(), which calls Touch() on that
//      registrar and returns 0. USE_PASS declares and calls this function.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
}
// Declares a dependency on the pass registered as pass_type. Must be used in
// the global namespace. The expansion declares the
// TouchPassRegistrar_<pass_type>() function defined by REGISTER_PASS and
// initializes a file-static int with its result, so the using translation unit
// references the registrar. The int is marked __attribute__((unused)) because
// it is never read.
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -132,19 +132,27 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA.");
#endif
}
builder_ = builder_factory.Create();
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
if (!build_strategy.debug_graphviz_path_.empty()) {
const std::string origin_graph_path = string::Sprintf(
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph");
graph = ir::GraphVizPass(origin_graph_path).Apply(std::move(graph));
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
builder_ = builder_factory.Create();
graph = builder_->Apply(std::move(graph));
if (!build_strategy.debug_graphviz_path_.empty()) {
const std::string origin_graph_path = string::Sprintf(
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec");
graph = ir::GraphVizPass(origin_graph_path).Apply(std::move(graph));
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
......@@ -297,3 +305,5 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework
} // namespace paddle
USE_PASS(graph_viz_pass);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册