提交 d8756913 编写于 作者: S sneaxiy

test=develop

上级 9ff5184f
...@@ -95,6 +95,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -95,6 +95,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (pass->Type() == "multi_devices_pass") { if (pass->Type() == "multi_devices_pass") {
pass->Erase("enable_sequence_execution");
if (enable_sequence_execution_) {
pass->Set("enable_sequence_execution", new bool(true));
}
pass->Erase("places"); pass->Erase("places");
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places); pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
pass->Erase("loss_var_name"); pass->Erase("loss_var_name");
......
...@@ -69,6 +69,8 @@ struct BuildStrategy { ...@@ -69,6 +69,8 @@ struct BuildStrategy {
bool enable_data_balance_{false}; bool enable_data_balance_{false};
bool enable_sequence_execution_{false};
// User normally doesn't need to call this API. // User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes // The PassBuilder allows for more customized insert, remove of passes
// from python side. // from python side.
......
...@@ -20,11 +20,12 @@ namespace paddle { ...@@ -20,11 +20,12 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place) platform::Place place, size_t place_id)
: OpHandleBase(node), : OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())), op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope), scope_(scope),
place_(place) {} place_(place),
place_id_(place_id) {}
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); WaitInputVarGenerated(place_);
......
...@@ -28,7 +28,8 @@ namespace framework { ...@@ -28,7 +28,8 @@ namespace framework {
namespace details { namespace details {
struct ComputationOpHandle : public OpHandleBase { struct ComputationOpHandle : public OpHandleBase {
public: public:
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t place_id);
std::string Name() const override; std::string Name() const override;
...@@ -36,6 +37,10 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -36,6 +37,10 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; } const platform::Place &GetPlace() const { return place_; }
const OperatorBase &GetOp() const { return *op_; }
size_t GetPlaceId() const { return place_id_; }
protected: protected:
void RunImpl() override; void RunImpl() override;
...@@ -45,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -45,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
size_t place_id_;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <map>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -237,8 +238,24 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -237,8 +238,24 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
// some optimizer ops might not depend on any nodes), we manually move all // some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes. // optimizer nodes after last backward nodes.
// However, the assumption by SSAGraphBuilder should be relaxed in the future. // However, the assumption by SSAGraphBuilder should be relaxed in the future.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph); const ir::Graph &graph, bool enable_sequence_execution = false) {
std::vector<ir::Node *> ret;
if (enable_sequence_execution) {
VLOG(10) << "sequential execution mode is enabled";
for (auto *node : graph.Nodes()) {
if (node->IsOp()) {
ret.push_back(node);
}
}
std::sort(ret.begin(), ret.end(),
[](const ir::Node *n1, const ir::Node *n2) {
return n1->id() < n2->id();
});
} else {
ret = ir::TopologySortOperations(graph);
}
size_t last_backward = 0; size_t last_backward = 0;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (boost::get<int>( if (boost::get<int>(
...@@ -287,7 +304,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -287,7 +304,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
Init(); Init();
// Give the topology sort order and rebuild the graph structure. // Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph); bool enable_sequence_execution = Has("enable_sequence_execution") &&
Get<bool>("enable_sequence_execution");
std::vector<ir::Node *> sorted_ops =
SortOpsAndDelayOptimizeOp(*graph, enable_sequence_execution);
auto nodes = graph->ReleaseNodes(); auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph; ir::Graph &result = *graph;
...@@ -443,6 +463,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -443,6 +463,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
} }
} }
} }
// Insert dependencies between computation_ops
if (enable_sequence_execution) {
InsertSequenceDependenciesBetweenComputationOps(graph.get());
}
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
hazards need to be handled. hazards need to be handled.
...@@ -457,6 +483,34 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -457,6 +483,34 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
return graph; return graph;
} }
void MultiDevSSAGraphBuilder::InsertSequenceDependenciesBetweenComputationOps(
ir::Graph *graph) const {
auto &ops = graph->Get<GraphOps>(kGraphOps);
// Use std::map instead of std::unordered_map for better log message
std::map<size_t, std::vector<ComputationOpHandle *>> compute_ops;
for (auto &op : ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr) continue;
compute_ops[compute_op->GetPlaceId()].push_back(compute_op);
}
for (auto &pair : compute_ops) {
auto &ops = pair.second;
for (size_t i = 1; i < ops.size(); ++i) {
if (ops[i - 1]->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
ops[i - 1]->AddOutput(dep_var);
}
ops[i]->AddInput(ops[i - 1]->Outputs().front());
VLOG(10) << "sequential execution mode: device(" << pair.first
<< ") insert dependency between "
<< ops[i - 1]->GetOp().DebugString() << " -> "
<< ops[i]->GetOp().DebugString();
}
}
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0); PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
...@@ -513,7 +567,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ...@@ -513,7 +567,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id])); local_scopes_[dev_id], places_[dev_id], dev_id));
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
} }
...@@ -630,8 +684,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, ...@@ -630,8 +684,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); result->CreateOpNode(node->Op()), s, p, scope_idx));
CreateOpHandleIOs(result, node, scope_idx); CreateOpHandleIOs(result, node, scope_idx);
} }
} }
......
...@@ -86,6 +86,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -86,6 +86,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
void InsertSequenceDependenciesBetweenComputationOps(ir::Graph *graph) const;
mutable std::string loss_var_name_; mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_; mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_; mutable std::vector<Scope *> local_scopes_;
......
...@@ -694,6 +694,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -694,6 +694,13 @@ All parameter, weight, gradient are variables in Paddle.
"enable_data_balance", "enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; }, [](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }) [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property("enable_sequence_execution",
[](const BuildStrategy &self) {
return self.enable_sequence_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequence_execution_ = b;
})
.def_property("fuse_elewise_add_act_ops", .def_property("fuse_elewise_add_act_ops",
[](const BuildStrategy &self) { [](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_; return self.fuse_elewise_add_act_ops_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册