提交 aa1085dd 编写于 作者: X Xin Pan

all passes

add doc
上级 e4d7d7ae
......@@ -71,6 +71,44 @@ is a `Graph` and its output is also a `Graph`. For example,
a `Pass` can simply print out the `Graph`. A `Pass`
can also fuse some `Graph`'s `Node`s.
```cpp
class Pass {
public:
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) ;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr);
};
// In my_pass.cc
class MyPass : public Pass {
public:
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
// do something.
return graph;
}
}
REGISTER_PASS(my_pass, MyPass);
// To use the pass.
auto my_pass = ir::PassRegistry::Instance().Get("my_pass");
graph = my_pass->Apply(std::move(graph));
// Note: to force link my_pass.cc, in the code:
USE_PASS(my_pass);
```
#### Optimize
`Optimize` contains a series of `Pass` with defined order.
......
......@@ -99,7 +99,7 @@ else()
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -35,15 +35,15 @@ namespace framework {
namespace details {
void MultiDevSSAGraphBuilder::Init() const {
loss_var_name_ = Get<std::string>("loss_var_name");
places_ = Get<std::vector<platform::Place>>("places");
local_scopes_ = Get<std::vector<Scope *>>("local_scopes");
strategy_ = Get<BuildStrategy>("strategy");
loss_var_name_ = Get<const std::string>("loss_var_name");
places_ = Get<const std::vector<platform::Place>>("places");
local_scopes_ = Get<const std::vector<Scope *>>("local_scopes");
strategy_ = Get<const BuildStrategy>("strategy");
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif
for (auto &p : Get<std::unordered_set<std::string>>("params")) {
for (auto &p : Get<const std::unordered_set<std::string>>("params")) {
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Pass> ParallelExecutorPassManager::Create() {
std::unique_ptr<ir::Pass> res(new MultiDevSSAGraphBuilder);
res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
res->SetNotOwned<std::vector<Scope *>>("local_scopes", &local_scopes_);
res->SetNotOwned<BuildStrategy>("strategy", &strategy_);
#ifdef PADDLE_WITH_CUDA
res->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nccl_ctxs_);
#endif
if (!strategy_.debug_graphviz_path_.empty()) {
ir::Pass *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithPrinter);
res->Set<ir::Pass>("previous_pass", previous_pass);
res->SetNotOwned<std::string>("debug_graphviz_path",
&strategy_.debug_graphviz_path_);
res->Set<GraphvizSSAGraphPrinter>("graph_printer",
new GraphvizSSAGraphPrinter);
}
ir::Pass *previous_pass = res.release();
res.reset(new SSAGraghBuilderWithChecker);
res->Set<ir::Pass>("previous_pass", previous_pass);
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
class Scope;
namespace details {
class ParallelExecutorPassManager {
public:
ParallelExecutorPassManager(
const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes, const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
strategy_(strategy) {
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = nullptr;
#endif
}
#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
nccl_ctxs_ = nccl_ctxs;
}
#endif
std::unique_ptr<ir::Pass> Create();
private:
std::vector<platform::Place> places_;
std::string loss_var_name_;
std::unordered_set<std::string> param_names_;
std::vector<Scope*> local_scopes_;
BuildStrategy strategy_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap* nccl_ctxs_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -26,9 +26,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph;
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph* graph) const;
......
......@@ -39,13 +39,11 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>("debug_graphviz_path")));
new std::ofstream(Get<const std::string>("debug_graphviz_path")));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
return new_graph;
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
}
};
......
......@@ -42,6 +42,7 @@ class Pass {
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
......@@ -49,6 +50,7 @@ class Pass {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
......@@ -59,6 +61,8 @@ class Pass {
};
}
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
......@@ -127,6 +131,7 @@ struct PassRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
......
......@@ -26,7 +26,8 @@ limitations under the License. */
#endif
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -43,16 +44,6 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
#else
const BuildStrategy &strategy) {
#endif
details::ParallelExecutorPassManager builder_factory(
places, loss_var_name, param_names, local_scopes, strategy);
if (use_cuda) {
#ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(nccl_ctxs);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
......@@ -62,8 +53,37 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
graph = viz_pass->Apply(std::move(graph));
}
auto builder = builder_factory.Create();
graph = builder->Apply(std::move(graph));
auto multi_device_pass =
ir::PassRegistry::Instance().Get("multi_device_pass");
multi_device_pass->SetNotOwned<const std::vector<platform::Place>>("places",
&places);
multi_device_pass->SetNotOwned<const std::string>("loss_var_name",
&loss_var_name);
multi_device_pass->SetNotOwned<const std::unordered_set<std::string>>(
"params", &param_names);
multi_device_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes);
multi_device_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
multi_device_pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
graph = multi_device_pass->Apply(std::move(graph));
if (!strategy.debug_graphviz_path_.empty()) {
auto multi_device_print_pass =
ir::PassRegistry::Instance().Get("multi_device_print_pass");
multi_device_print_pass->SetNotOwned<const std::string>(
"debug_graphviz_path", &strategy.debug_graphviz_path_);
multi_device_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
graph = multi_device_print_pass->Apply(std::move(graph));
}
auto multi_device_check_pass =
ir::PassRegistry::Instance().Get("multi_device_check_pass");
graph = multi_device_check_pass->Apply(std::move(graph));
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册