Commit 2fa8df1c authored by Xin Pan

separate graph building pass and graph-based pe builder

Parent 37e51443
......@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle {
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation));
std::unique_ptr<ir::Node> n(new ir::Node());
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
......@@ -114,7 +114,7 @@ struct TestBroadcastOpHandle {
#endif
}
std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v(new ir::Node());
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
......@@ -122,7 +122,7 @@ struct TestBroadcastOpHandle {
// add dummy var
std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v2(new ir::Node());
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......@@ -133,7 +133,7 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v3(new ir::Node());
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
......@@ -141,7 +141,7 @@ struct TestBroadcastOpHandle {
}
// add dummy var
std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable));
std::unique_ptr<ir::Node> v4(new ir::Node());
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......
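Note: the three op-handle tests (broadcast above, gather and reduce below) all follow the same wiring pattern: allocate a bare ir::Node to back each handle, then wrap it in a VarHandle, DummyVarHandle, or op handle. A minimal sketch of that pattern follows; op_handle and place stand in for the test fixture's members, and only constructors visible in this diff are used.

// Sketch of the handle-wiring pattern used by these tests. The diff shows
// nodes created both as `new ir::Node()` and as
// `new ir::Node(ir::Node::Type::kVariable)`; use whichever the header keeps.
std::vector<std::unique_ptr<ir::Node>> nodes;
platform::Place place = platform::CPUPlace();

nodes.emplace_back(new ir::Node());
auto* in_var = new VarHandle(nodes.back().get(), /*version=*/1,
                             /*scope_idx=*/0, "input", place);
op_handle->AddInput(in_var);  // op_handle: fixture member owning the op

nodes.emplace_back(new ir::Node());
auto* dummy = new DummyVarHandle(nodes.back().get());  // dependency-only var
op_handle->AddInput(dummy);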
......@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
}
param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
nodes.emplace_back(new ir::Node());
op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
nodes.emplace_back(new ir::Node());
auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
......@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
}
// add dummy var
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
nodes.emplace_back(new ir::Node());
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle);
// add output
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
nodes.emplace_back(new ir::Node());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
nodes.emplace_back(new ir::Node());
vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
......
......@@ -46,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
void CreateOpHandleIOs(Graph *result, const OpDesc &op,
size_t device_id) const;
void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
private:
std::string loss_var_name_;
......@@ -64,40 +62,39 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
platform::NCCLContextMap *nccl_ctxs_;
#endif
bool IsScaleLossOp(const OpDesc &op) const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(Graph *result, const OpDesc &op) const;
void CreateDistTrainOp(Graph *result, const OpDesc &op) const;
void CreateRPCOp(Graph *result, ir::Node *node) const;
void CreateDistTrainOp(Graph *result, ir::Node *node) const;
/**
* Is this operator the end-point operator before/after the send operator?
*/
bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(Graph *result, const OpDesc &op,
void CreateComputationalOps(Graph *result, ir::Node *node,
size_t num_places) const;
void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;
void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
bool IsParameterGradientOnce(
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(const OpDesc &op) const;
int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const;
......
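Note: throughout this header the OpDesc-based signatures become ir::Node*-based; the OpDesc is still reachable through the node. A hypothetical helper (not taken from the tree) showing the pattern:

// Hypothetical helper, not the actual implementation: with the new
// signatures an op node carries its OpDesc, so checks that used to take
// `const OpDesc &op` can read the same data through node->Op().
bool IsOpOfType(ir::Node *node, const std::string &type) {
  return node->Op() != nullptr && node->Op()->Type() == type;
}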
......@@ -97,7 +97,7 @@ struct TestReduceOpHandle {
}
param_scopes_[out_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
nodes.emplace_back(new ir::Node());
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
......@@ -121,7 +121,7 @@ struct TestReduceOpHandle {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
nodes.emplace_back(new ir::Node());
auto *in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
in_var_handle->ClearGeneratedOp();
......@@ -137,7 +137,7 @@ struct TestReduceOpHandle {
op_handle_->AddInput(in_dummy_var_handle);
// add output
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
nodes.emplace_back(new ir::Node());
auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
"out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle);
......
......@@ -37,8 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
continue;
}
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *dep_var = new DummyVarHandle(graph->nodes.back().get());
auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy"));
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
......@@ -49,15 +48,14 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
Graph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[each_var_name];
auto &var_holder = var_holders[node->Var()->Name()];
VarHandle *var = nullptr;
if (var_holder.empty()) {
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
var = new VarHandle(graph->nodes.back().get(), 0, place_offset,
each_var_name, place);
var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
node->Var()->Name(), place);
var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
......@@ -66,14 +64,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
}
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
ir::Node *node,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name];
auto &vars = graph->Get<GraphVars>("vars")[place_offset][node->Var()->Name()];
size_t version = vars.size();
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto var = new VarHandle(graph->nodes.back().get(), version, place_offset,
each_var_name, place);
auto var = new VarHandle(graph->CreateVarNode(node->Var()), version,
place_offset, node->Var()->Name(), place);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
......@@ -85,8 +82,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
if (!op->Outputs().empty()) {
continue;
}
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get());
auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy"));
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
......
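Note: a hedged sketch of how a builder subclass could drive the two node-based helpers above when wiring an op handle; the actual MultiDevSSAGraphBuilder call sites are not shown in this diff.

// Sketch only: assumes access to SSAGraphBuilder's static helpers (e.g. as a
// member of a builder subclass) and that op_node's inputs/outputs are var nodes.
void WireOpHandle(Graph *graph, OpHandleBase *op_handle, ir::Node *op_node,
                  const platform::Place &place, size_t place_offset) {
  for (ir::Node *in : op_node->inputs) {
    VarHandle *var =
        CreateOrGetLatestVarHandle(graph, in, place, place_offset);
    op_handle->AddInput(var);
  }
  for (ir::Node *out : op_node->outputs) {
    CreateOpOutput(graph, op_handle, out, place, place_offset);
  }
}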
......@@ -23,6 +23,7 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
......@@ -34,11 +35,11 @@ typedef std::vector<
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder {
class SSAGraphBuilder : public ir::Pass {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......@@ -53,16 +54,15 @@ class SSAGraphBuilder {
*/
static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph,
const std::string &each_var_name,
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
const platform::Place &place,
size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
ir::Node *node, const platform::Place &place,
size_t place_offset);
static void AddOutputToLeafOps(Graph *graph);
};
......
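Note: since SSAGraphBuilder now derives from ir::Pass and Build is removed, a concrete builder overrides Apply and GetVarDeviceID. A minimal illustrative skeleton (MyGraphBuilder is a made-up name):

// Illustrative skeleton only; real builders also populate the "vars" and
// "dep_vars" graph attributes before returning.
class MyGraphBuilder : public SSAGraphBuilder {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    // ... create op handles and var handles on `graph` ...
    AddOutputToLeafOps(graph.get());
    PolishGraphToSupportDataHazards(graph.get());
    return graph;
  }
  int GetVarDeviceID(const std::string &var_name) const override {
    return -1;  // -1 as "unknown device" for a builder without placement info
  }
};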
......@@ -28,10 +28,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Build(std::move(graph));
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph;
return std::move(new_graph);
}
int GetVarDeviceID(const std::string& var_name) const override {
......
......@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Build(std::move(graph));
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
return new_graph;
return std::move(new_graph);
}
int GetVarDeviceID(const std::string& var_name) const override {
......
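Note: the checker and printer are decorators over another SSAGraphBuilder; each forwards to the wrapped builder's Apply and adds validation or printing. A hedged sketch of chaining one, where MakeCoreBuilder is a placeholder for constructing a concrete builder (e.g. MultiDevSSAGraphBuilder) and main_program is assumed to be a ProgramDesc:

// Sketch only: `MakeCoreBuilder` is a stand-in for however the concrete
// builder is constructed; the wrapper names are the ones in this diff.
std::unique_ptr<SSAGraphBuilder> core = MakeCoreBuilder();
std::unique_ptr<SSAGraphBuilder> checked(
    new SSAGraghBuilderWithChecker(std::move(core)));
std::unique_ptr<Graph> graph = checked->Apply(ProgramToGraph(main_program));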
......@@ -13,12 +13,45 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
std::unique_ptr<Graph> graph(new Graph(program));
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = graph->CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = graph->CreateVarNode(all_vars.at(each_var_name));
} else {
var = graph->CreateVarNode(each_var_name);
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = graph->CreateVarNode(all_vars.at(each_var_name));
} else {
var = graph->CreateVarNode(each_var_name);
}
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}
return std::move(graph);
}
......
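Note: ProgramToGraph walks block 0, creates an op node per OpDesc and a var node per referenced variable (backed by the block's VarDesc when one exists, otherwise by name only), and links inputs and outputs in both directions. A hedged usage sketch:

// Usage sketch, assuming only the APIs visible in this diff plus VLOG logging;
// main_program is assumed to be a ProgramDesc.
std::unique_ptr<Graph> graph = ProgramToGraph(main_program);
for (const std::unique_ptr<ir::Node> &n : graph->nodes) {
  if (n->NodeType() == ir::Node::Type::kOperation) {
    VLOG(3) << "op " << n->Op()->Type() << ": " << n->inputs.size()
            << " inputs, " << n->outputs.size() << " outputs";
  }
}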
......@@ -39,8 +39,6 @@ class Graph {
attr_dels_.clear();
}
const ProgramDesc& Program() const { return program_; }
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
......@@ -63,11 +61,30 @@ class Graph {
return attr;
}
ir::Node* CreateVarNode(VarDesc* var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
}
ir::Node* CreateOpNode(OpDesc* op_desc) {
nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get();
}
// TODO(panyx0718): Need to handle CreateOpNode(nullptr).
ir::Node* CreateVarNode(const std::string& var_name) {
var_descs_.emplace_back(new VarDesc(var_name));
nodes.emplace_back(new ir::Node(var_descs_.back().get()));
return nodes.back().get();
}
std::vector<ir::Node*> inputs;
std::vector<ir::Node*> outputs;
std::vector<std::unique_ptr<ir::Node>> nodes;
std::vector<std::unique_ptr<VarDesc>> var_descs_;
private:
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
......
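Note: Graph now owns its nodes and exposes three factories: CreateOpNode from an OpDesc, CreateVarNode from a VarDesc, and CreateVarNode from a bare name (which fabricates a VarDesc kept in var_descs_). A short hedged sketch, where op_desc and var_desc stand in for descriptors taken from the program:

// Sketch only; op_desc and var_desc are assumed to come from main_program.
Graph graph(main_program);
ir::Node *op  = graph.CreateOpNode(op_desc);
ir::Node *out = graph.CreateVarNode(var_desc);
ir::Node *dep = graph.CreateVarNode("dummy");  // placeholder var owned by the graph
op->outputs.push_back(out);
out->inputs.push_back(op);
op->outputs.push_back(dep);
dep->inputs.push_back(op);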
......@@ -21,6 +21,8 @@ limitations under the License. */
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/variant.h"
......@@ -32,10 +34,12 @@ class Node {
public:
enum class Type { kNone = -1, kOperation, kVariable };
Node() : type_(Type::kNone) {}
explicit Node(Type type) : type_(type) {}
virtual ~Node() {
for (auto &attr : attrs_) {
for (auto& attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first]();
}
......@@ -47,23 +51,34 @@ class Node {
Type NodeType() const { return type_; }
template <typename AttrType>
void Set(const std::string &name, AttrType attr) {
void Set(const std::string& name, AttrType attr) {
attrs_[name] = attr;
}
template <typename AttrType>
void Set(const std::string &name, AttrType *attr,
void Set(const std::string& name, AttrType* attr,
std::function<void(void)> attr_del) {
attrs_[name] = attr;
attr_dels_[name] = attr_del;
}
std::vector<Node *> inputs;
std::vector<Node *> outputs;
VarDesc* Var() { return var_desc_; }
OpDesc* Op() { return op_desc_; }
explicit Node(VarDesc* var_desc)
: var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc)
: var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected:
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
VarDesc* var_desc_;
OpDesc* op_desc_;
Type type_;
private:
......
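Note: with the VarDesc/OpDesc constructors and the Var()/Op() accessors, a node exposes whichever descriptor it was built from. A hedged helper sketch (DescribeNode is a made-up name):

// Made-up helper, using only the accessors added in this diff.
std::string DescribeNode(ir::Node *node) {
  if (node->NodeType() == ir::Node::Type::kOperation && node->Op()) {
    return "op: " + node->Op()->Type();
  }
  if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) {
    return "var: " + node->Var()->Name();
  }
  return "node with no desc";
}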
......@@ -20,15 +20,15 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
return std::move(graph);
}
};
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};
} // namespace ir
} // namespace framework
} // namespace paddle
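Note: Apply is now pure virtual and const, so every pass (including the SSA graph builders above) must implement it. A minimal illustrative pass (IdentityPass is a made-up name) that mirrors the default body removed here:

// Illustrative only: the smallest possible pass under the new interface.
class IdentityPass : public ir::Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    return graph;  // no transformation
  }
};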
......@@ -131,13 +131,10 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA.");
#endif
}
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
std::unique_ptr<Graph> graph = builder_->Apply(ProgramToGraph(main_program));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
......
......@@ -148,6 +148,7 @@ class ParallelExecutor(object):
lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
main.list_vars())
]
sys.stderr.write('!!!!!!!!before\n')
self.executor = core.ParallelExecutor(
self._places,
......@@ -158,6 +159,7 @@ class ParallelExecutor(object):
set(self.persistable_vars), main.desc, loss_name
if loss_name else '', scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id)
sys.stderr.write('!!!!!!!!after\n')
self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
......