提交 64eaa4c8 编写于 作者: X Xin Pan

clean

上级 10786a24
......@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
......@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -27,6 +27,8 @@ namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
class OpHandleBase {
public:
explicit OpHandleBase(ir::Node *node) : node_(node) {}
......
......@@ -17,6 +17,9 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
namespace paddle {
namespace framework {
namespace details {
// A SSA graph used by parallel executor.
struct SSAGraph {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a different version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
std::vector<std::unique_ptr<OpHandleBase>> ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -81,9 +81,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
}
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
GraphOps &all_ops = graph->Get<GraphOps>("ops");
for (auto &op : all_ops) {
for (auto &op : graph->Get<GraphOps>("ops")) {
if (!op->Outputs().empty()) {
continue;
}
......
......@@ -18,7 +18,9 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
......@@ -29,10 +31,20 @@ namespace paddle {
namespace framework {
namespace details {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder : public ir::Pass {
......
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......
......@@ -18,8 +18,8 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......
......@@ -14,7 +14,7 @@
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......
......@@ -173,9 +173,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
ir::Node *fetch_n = new ir::Node("fetch");
auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_);
temp_nodes->emplace_back(fetch_n);
temp_nodes->emplace_back(new ir::Node("fetch"));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
&local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......@@ -186,11 +186,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
ir::Node *dummy_n = new ir::Node("fetch");
auto *fetch_dummy = new DummyVarHandle(dummy_n);
temp_nodes->emplace_back(new ir::Node("fetch"));
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
temp_nodes->emplace_back(dummy_n);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
this->InsertPendingOp(pending_ops, op);
}
......
......@@ -28,6 +28,9 @@ namespace framework {
namespace details {
class OpHandleBase;
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
//
// VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
......
......@@ -19,37 +19,35 @@ limitations under the License. */
namespace paddle {
namespace framework {
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
std::unique_ptr<Graph> graph(new Graph(program));
Graph::Graph(const ProgramDesc &program) : program_(program) {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = graph->CreateOpNode(op);
ir::Node *node = CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = graph->CreateVarNode(all_vars.at(each_var_name));
var = CreateVarNode(all_vars.at(each_var_name));
} else {
LOG(ERROR) << "input var not in all_var list: " << each_var_name;
var = graph->CreateEmptyNode(each_var_name);
// TODO(paddle-dev): Seems some assumption doesn't hold?
LOG(ERROR) << op->Type()
<< " input var not in all_var list: " << each_var_name;
var = CreateEmptyNode(each_var_name);
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name));
ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}
return std::move(graph);
}
} // namespace framework
} // namespace paddle
......@@ -29,7 +29,7 @@ namespace framework {
class Graph {
public:
explicit Graph(const ProgramDesc& program) : program_(program) {}
explicit Graph(const ProgramDesc& program);
virtual ~Graph() {
for (auto& attr : attrs_) {
......@@ -46,6 +46,7 @@ class Graph {
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
......@@ -53,14 +54,6 @@ class Graph {
};
}
template <typename AttrType>
AttrType* Erase(const std::string& attr_name) {
AttrType* attr = boost::any_cast<AttrType*>(attrs_[attr_name]);
attrs_.erase(attr_name);
attr_dels_.erase(attr_name);
return attr;
}
ir::Node* CreateVarNode(VarDesc* var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
......@@ -71,14 +64,14 @@ class Graph {
return nodes.back().get();
}
// TODO(panyx0718): Need to handle CreateOpNode(nullptr).
// TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
// node should either be a executable kOperation or a kVariable. kNone
// node is a temporary solution.
ir::Node* CreateEmptyNode(const std::string& name) {
nodes.emplace_back(new ir::Node(name));
return nodes.back().get();
}
std::vector<ir::Node*> inputs;
std::vector<ir::Node*> outputs;
std::vector<std::unique_ptr<ir::Node>> nodes;
private:
......@@ -88,7 +81,5 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_;
};
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program);
} // namespace framework
} // namespace paddle
......@@ -14,17 +14,11 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
......@@ -132,7 +132,8 @@ ParallelExecutor::ParallelExecutor(
#endif
}
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Apply(ProgramToGraph(main_program));
std::unique_ptr<Graph> graph(new Graph(main_program));
graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册