Commit 64eaa4c8 authored by Xin Pan

clean

Parent 10786a24
@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
-cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
+cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
 cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
 cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
 cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
-cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
+cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
     simple_threadpool device_context)
......
@@ -27,6 +27,8 @@ namespace details {
 constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";

+// Wraps ir::Node and provides helper utilities.
+// It is responsible for populating the necessary fields of ir::Node.
 class OpHandleBase {
  public:
   explicit OpHandleBase(ir::Node *node) : node_(node) {}
......
@@ -17,6 +17,9 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/details/execution_strategy.h"
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
 #include "paddle/fluid/framework/scope.h"
......
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/details/ssa_graph.h"
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <map>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/details/op_handle_base.h"
-#include "paddle/fluid/framework/details/var_handle.h"
-
-namespace paddle {
-namespace framework {
-namespace details {
-
-// A SSA graph used by parallel executor.
-struct SSAGraph {
-  // all variable in each devices.
-  // The outside vector is the device vector. Each element of this vector is a
-  // map from variable name to variables. The variables, who have the same name,
-  // will have a different version. The offset in the
-  // `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
-  std::vector<
-      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
-      vars_;
-
-  // aux variables to represent dependency. Useful to resolve data hazard.
-  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
-
-  // all operators. NOTE that even we use a vector here, the operators is
-  // unordered.
-  std::vector<std::unique_ptr<OpHandleBase>> ops_;
-};
-
-}  // namespace details
-}  // namespace framework
-}  // namespace paddle
@@ -81,9 +81,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
 }

 void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
-  GraphOps &all_ops = graph->Get<GraphOps>("ops");
-  for (auto &op : all_ops) {
+  for (auto &op : graph->Get<GraphOps>("ops")) {
     if (!op->Outputs().empty()) {
       continue;
     }
......
@@ -18,7 +18,9 @@
 #include <string>
 #include <vector>
-#include "paddle/fluid/framework/details/ssa_graph.h"
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"
@@ -29,10 +31,20 @@ namespace paddle {
 namespace framework {
 namespace details {

+// All variables in each device.
+// The outside vector is the device vector. Each element of this vector is a
+// map from variable name to variables. Variables that share the same name
+// have different versions. The offset in the
+// `std::vector<std::unique_ptr<VarHandle>>` is the version of the variable.
 typedef std::vector<
     std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
     GraphVars;
+
+// Aux variables to represent dependency. Useful to resolve data hazards.
 typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
+
+// All operators. NOTE that even though we use a vector here, the operators
+// are unordered.
 typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;

 class SSAGraphBuilder : public ir::Pass {
......
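Note on the GraphVars layout described in the new comment above: the outer vector is indexed by device, the map is keyed by variable name, and the inner vector holds one VarHandle per SSA version of that variable, so the offset in the inner vector is the version number. A minimal standalone sketch of that layout follows; MiniVarHandle, the two-device setup, and the variable name "w" are illustrative stand-ins, not Paddle code.

// Standalone sketch (not Paddle code) of the GraphVars layout:
//   GraphVars[device][var_name][version] -> handle
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct MiniVarHandle {
  std::string name;
  size_t version;
};

using MiniGraphVars = std::vector<
    std::unordered_map<std::string, std::vector<std::unique_ptr<MiniVarHandle>>>>;

int main() {
  MiniGraphVars vars(2);  // two devices

  // Writing "w" twice on device 0 yields two versions: SSA keeps one handle per write.
  for (size_t version = 0; version < 2; ++version) {
    vars[0]["w"].emplace_back(new MiniVarHandle{"w", version});
  }

  // The offset inside the innermost vector is the version of the variable.
  const MiniVarHandle &latest = *vars[0]["w"].back();
  std::cout << latest.name << " has " << vars[0]["w"].size()
            << " versions; latest = " << latest.version << "\n";
  return 0;
}

With this layout, vars[d][name].back() is presumably the newest version of name on device d, which is what a builder would wire subsequent readers against.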
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/ssa_graph.h"
-#include <string>
 #include "paddle/fluid/framework/details/ssa_graph_checker.h"
+#include <string>
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace framework {
......
@@ -18,8 +18,8 @@
 #include <string>
 #include <vector>
-#include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace framework {
......
@@ -14,7 +14,7 @@
 #include "paddle/fluid/framework/details/ssa_graph_printer.h"
 #include <string>
-#include "paddle/fluid/framework/details/ssa_graph.h"
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace framework {
......
@@ -173,9 +173,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     auto &var_name = fetch_tensors[i];
     auto &vars = fetched_vars.at(var_name);

-    ir::Node *fetch_n = new ir::Node("fetch");
-    auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_);
-    temp_nodes->emplace_back(fetch_n);
+    temp_nodes->emplace_back(new ir::Node("fetch"));
+    auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
+                                 &local_scopes_);
     fetch_ops->emplace_back(op);

     for (auto &p : places_) {
@@ -186,11 +186,10 @@
       op->AddInput(var);
     }

-    ir::Node *dummy_n = new ir::Node("fetch");
-    auto *fetch_dummy = new DummyVarHandle(dummy_n);
+    temp_nodes->emplace_back(new ir::Node("fetch"));
+    auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
     op->AddOutput(fetch_dummy);
     fetch_dependencies->emplace(fetch_dummy);
-    temp_nodes->emplace_back(dummy_n);
     this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
     this->InsertPendingOp(pending_ops, op);
   }
......
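The rewrite above changes when the temporary ir::Node becomes owned: the node is placed into temp_nodes first, and the handle then receives only a non-owning raw pointer via temp_nodes->back().get(). A minimal standalone sketch of this ownership pattern, with stand-in types rather than the real ir::Node / FetchOpHandle:

// Standalone sketch (not Paddle code): a vector of unique_ptr owns the nodes,
// handles only hold raw, non-owning pointers obtained from back().get().
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct MiniNode {  // stands in for ir::Node
  explicit MiniNode(std::string n) : name(std::move(n)) {}
  std::string name;
};

struct MiniFetchHandle {  // stands in for FetchOpHandle / DummyVarHandle
  explicit MiniFetchHandle(MiniNode *n) : node(n) {}  // non-owning pointer
  MiniNode *node;
};

int main() {
  std::vector<std::unique_ptr<MiniNode>> temp_nodes;        // owns every node
  std::vector<std::unique_ptr<MiniFetchHandle>> fetch_ops;  // owns the handles

  // Same shape as the new code: own the node first, then hand out a raw pointer.
  temp_nodes.emplace_back(new MiniNode("fetch"));
  fetch_ops.emplace_back(new MiniFetchHandle(temp_nodes.back().get()));

  std::cout << fetch_ops.back()->node->name << "\n";
  // temp_nodes must outlive any use of the raw pointers held by the handles.
  return 0;
}

Compared with the old version, the freshly allocated node is owned by temp_nodes before the handle is constructed, so it is never left sitting in a bare local pointer; the observable behavior should be unchanged.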
@@ -28,6 +28,9 @@ namespace framework {
 namespace details {

 class OpHandleBase;
+// Wraps ir::Node and provides helper utilities.
+// It is responsible for populating the necessary fields of ir::Node.
+//
 // VarHandleBase is the var node in the dependency graph.
 // A variable can only be generated by a single operator. i.e.
 // This is a single assignment graph.
......
@@ -19,37 +19,35 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
-  std::unique_ptr<Graph> graph(new Graph(program));
-
+Graph::Graph(const ProgramDesc &program) : program_(program) {
   std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
     all_vars.emplace(var->Name(), var);
   }

   for (auto *op : program.Block(0).AllOps()) {
-    ir::Node *node = graph->CreateOpNode(op);
+    ir::Node *node = CreateOpNode(op);

     for (auto &each_var_name : op->InputArgumentNames()) {
       ir::Node *var = nullptr;
       if (all_vars.count(each_var_name) != 0) {
-        var = graph->CreateVarNode(all_vars.at(each_var_name));
+        var = CreateVarNode(all_vars.at(each_var_name));
       } else {
-        LOG(ERROR) << "input var not in all_var list: " << each_var_name;
-        var = graph->CreateEmptyNode(each_var_name);
+        // TODO(paddle-dev): Seems some assumption doesn't hold?
+        LOG(ERROR) << op->Type()
+                   << " input var not in all_var list: " << each_var_name;
+        var = CreateEmptyNode(each_var_name);
       }
       node->inputs.push_back(var);
       var->outputs.push_back(node);
     }

     for (auto &each_var_name : op->OutputArgumentNames()) {
-      ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name));
+      ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
       node->outputs.push_back(var);
       var->inputs.push_back(node);
     }
   }
-
-  return std::move(graph);
 }

 }  // namespace framework
 }  // namespace paddle
@@ -29,7 +29,7 @@ namespace framework {
 class Graph {
  public:
-  explicit Graph(const ProgramDesc& program) : program_(program) {}
+  explicit Graph(const ProgramDesc& program);

   virtual ~Graph() {
     for (auto& attr : attrs_) {
@@ -46,6 +46,7 @@ class Graph {
   template <typename AttrType>
   void Set(const std::string& attr_name, AttrType* attr) {
+    PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
     attrs_[attr_name] = attr;
     attr_dels_[attr_name] = [attr, attr_name]() {
       VLOG(3) << "deleting " << attr_name;
@@ -53,14 +54,6 @@
     };
   }

-  template <typename AttrType>
-  AttrType* Erase(const std::string& attr_name) {
-    AttrType* attr = boost::any_cast<AttrType*>(attrs_[attr_name]);
-    attrs_.erase(attr_name);
-    attr_dels_.erase(attr_name);
-    return attr;
-  }
-
   ir::Node* CreateVarNode(VarDesc* var_desc) {
     nodes.emplace_back(new ir::Node(var_desc));
     return nodes.back().get();
@@ -71,14 +64,14 @@
     return nodes.back().get();
   }

-  // TODO(panyx0718): Need to handle CreateOpNode(nullptr).
+  // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
+  // A node should either be an executable kOperation or a kVariable. The kNone
+  // node is a temporary solution.
   ir::Node* CreateEmptyNode(const std::string& name) {
     nodes.emplace_back(new ir::Node(name));
     return nodes.back().get();
   }

-  std::vector<ir::Node*> inputs;
-  std::vector<ir::Node*> outputs;
   std::vector<std::unique_ptr<ir::Node>> nodes;

  private:
@@ -88,7 +81,5 @@
   std::map<std::string, std::function<void(void)>> attr_dels_;
 };

-std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program);
-
 }  // namespace framework
 }  // namespace paddle
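The graph.h changes above tighten the attribute API: Set now rejects duplicate attribute names via PADDLE_ENFORCE(attrs_.count(attr_name) == 0), the Erase escape hatch is removed, and the deleter registered in attr_dels_ for each attribute runs when the graph is destroyed. The following standalone sketch mimics that mechanism with std::any and assert; MiniGraph and the int attribute are illustrative stand-ins, not Paddle's ir::Graph.

// Standalone sketch (not Paddle code) of "raw pointer + deleter" attributes
// with a duplicate-name check, as in the Set() above.
#include <any>
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <string>

class MiniGraph {
 public:
  ~MiniGraph() {
    for (auto& kv : attr_dels_) kv.second();  // free every owned attribute
  }

  template <typename AttrType>
  void Set(const std::string& name, AttrType* attr) {
    assert(attrs_.count(name) == 0);          // mirrors the new PADDLE_ENFORCE
    attrs_[name] = attr;
    attr_dels_[name] = [attr, name]() {
      std::cout << "deleting " << name << "\n";
      delete attr;
    };
  }

  template <typename AttrType>
  AttrType& Get(const std::string& name) {
    return *std::any_cast<AttrType*>(attrs_.at(name));
  }

 private:
  std::map<std::string, std::any> attrs_;
  std::map<std::string, std::function<void()>> attr_dels_;
};

int main() {
  MiniGraph g;
  g.Set("ops", new int(42));               // one pass attaches an attribute...
  std::cout << g.Get<int>("ops") << "\n";  // ...and a later pass reads it back
  // g.Set("ops", new int(7));             // would trip the duplicate-name check
  return 0;
}

Because Erase no longer exists, anything registered with Set stays owned by the graph for its whole lifetime, and the new check turns an accidental double registration (which previously would have silently replaced the deleter and leaked the first pointer) into an immediate failure.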
@@ -14,17 +14,11 @@ limitations under the License. */
 #pragma once

-#include <algorithm>
-#include <cstdint>
-#include <functional>
-#include <map>
 #include <string>
-#include <unordered_set>
 #include <vector>

 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/macros.h"
-#include "paddle/fluid/platform/variant.h"

 namespace paddle {
 namespace framework {
......
@@ -18,7 +18,7 @@ limitations under the License. */
 #include <tuple>
 #include <vector>
-#include "paddle/fluid/framework/details/ssa_graph.h"
+#include "paddle/fluid/framework/ir/graph.h"

 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -132,7 +132,8 @@ ParallelExecutor::ParallelExecutor(
 #endif
   }
   builder_ = builder_factory.Create();
-  std::unique_ptr<Graph> graph = builder_->Apply(ProgramToGraph(main_program));
+  std::unique_ptr<Graph> graph(new Graph(main_program));
+  graph = builder_->Apply(std::move(graph));
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places, std::move(graph)));
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
......
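With ProgramToGraph removed, the flow in ParallelExecutor above becomes: build an ir::Graph directly from the ProgramDesc, run it through the SSA graph builder (now an ir::Pass), and hand the resulting graph to the executor. Below is a standalone sketch of that three-step flow with simplified stand-in types; none of these are Paddle's real classes, and the "added_by_builder" op is purely illustrative.

// Standalone sketch (not Paddle code) of: program -> graph -> pass -> executor.
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct MiniProgramDesc {  // stands in for ProgramDesc
  std::vector<std::string> ops;
};

struct MiniGraph {  // stands in for ir::Graph
  explicit MiniGraph(const MiniProgramDesc &p) : ops(p.ops) {}  // ctor replaces ProgramToGraph()
  std::vector<std::string> ops;
};

struct MiniBuilder {  // stands in for the SSAGraphBuilder pass
  std::unique_ptr<MiniGraph> Apply(std::unique_ptr<MiniGraph> g) const {
    g->ops.push_back("added_by_builder");  // pretend transformation
    return g;
  }
};

struct MiniExecutor {  // stands in for ThreadedSSAGraphExecutor
  explicit MiniExecutor(std::unique_ptr<MiniGraph> g) : graph(std::move(g)) {}
  std::unique_ptr<MiniGraph> graph;
};

int main() {
  MiniProgramDesc program{{"mul", "elementwise_add"}};
  std::unique_ptr<MiniGraph> graph(new MiniGraph(program));  // 1. graph from the program
  MiniBuilder builder;
  graph = builder.Apply(std::move(graph));                   // 2. builder acts as a pass
  MiniExecutor executor(std::move(graph));                   // 3. executor takes ownership
  std::cout << executor.graph->ops.size() << " ops after building\n";
  return 0;
}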