提交 d8756913 编写于 作者: S sneaxiy

test=develop

上级 9ff5184f
......@@ -95,6 +95,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
if (pass->Type() == "multi_devices_pass") {
pass->Erase("enable_sequence_execution");
if (enable_sequence_execution_) {
pass->Set("enable_sequence_execution", new bool(true));
}
pass->Erase("places");
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
pass->Erase("loss_var_name");
......
......@@ -69,6 +69,8 @@ struct BuildStrategy {
bool enable_data_balance_{false};
bool enable_sequence_execution_{false};
// User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes
// from python side.
......
......@@ -20,11 +20,12 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place)
platform::Place place, size_t place_id)
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope),
place_(place) {}
place_(place),
place_id_(place_id) {}
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
......
......@@ -28,7 +28,8 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t place_id);
std::string Name() const override;
......@@ -36,6 +37,10 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; }
const OperatorBase &GetOp() const { return *op_; }
size_t GetPlaceId() const { return place_id_; }
protected:
void RunImpl() override;
......@@ -45,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
size_t place_id_;
};
} // namespace details
} // namespace framework
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <fstream>
#include <map>
#include <string>
#include <utility>
#include <vector>
......@@ -237,8 +238,24 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
// some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes.
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(
const ir::Graph &graph, bool enable_sequence_execution = false) {
std::vector<ir::Node *> ret;
if (enable_sequence_execution) {
VLOG(10) << "sequential execution mode is enabled";
for (auto *node : graph.Nodes()) {
if (node->IsOp()) {
ret.push_back(node);
}
}
std::sort(ret.begin(), ret.end(),
[](const ir::Node *n1, const ir::Node *n2) {
return n1->id() < n2->id();
});
} else {
ret = ir::TopologySortOperations(graph);
}
size_t last_backward = 0;
for (size_t i = 0; i < ret.size(); ++i) {
if (boost::get<int>(
......@@ -287,7 +304,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
Init();
// Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
bool enable_sequence_execution = Has("enable_sequence_execution") &&
Get<bool>("enable_sequence_execution");
std::vector<ir::Node *> sorted_ops =
SortOpsAndDelayOptimizeOp(*graph, enable_sequence_execution);
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
......@@ -443,6 +463,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
}
// Insert dependencies between computation_ops
if (enable_sequence_execution) {
InsertSequenceDependenciesBetweenComputationOps(graph.get());
}
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
......@@ -457,6 +483,34 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
return graph;
}
// Chain every ComputationOpHandle placed on the same device in its current
// order by adding an explicit dependency from each op to its predecessor,
// so ops on one device execute strictly sequentially.
// Called only when sequential-execution mode is enabled.
void MultiDevSSAGraphBuilder::InsertSequenceDependenciesBetweenComputationOps(
    ir::Graph *graph) const {
  auto &all_ops = graph->Get<GraphOps>(kGraphOps);
  // Group computation ops by device id. Use std::map instead of
  // std::unordered_map so devices are visited in ascending id order,
  // which gives deterministic, readable log messages.
  std::map<size_t, std::vector<ComputationOpHandle *>> compute_ops;
  for (auto &op : all_ops) {
    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
    if (compute_op == nullptr) continue;
    compute_ops[compute_op->GetPlaceId()].push_back(compute_op);
  }

  for (auto &pair : compute_ops) {
    // Renamed from `ops` to avoid shadowing the outer `all_ops`/`ops`.
    auto &dev_ops = pair.second;
    for (size_t i = 1; i < dev_ops.size(); ++i) {
      auto *prev_op = dev_ops[i - 1];
      // An op with no outputs cannot be depended upon; attach a dummy
      // control-dependency variable to it first so the next op can
      // consume that variable as its input.
      if (prev_op->Outputs().empty()) {
        auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
        graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
        prev_op->AddOutput(dep_var);
      }
      dev_ops[i]->AddInput(prev_op->Outputs().front());
      VLOG(10) << "sequential execution mode: device(" << pair.first
               << ") insert dependency between "
               << prev_op->GetOp().DebugString() << " -> "
               << dev_ops[i]->GetOp().DebugString();
    }
  }
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
......@@ -513,7 +567,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id]));
local_scopes_[dev_id], places_[dev_id], dev_id));
CreateOpHandleIOs(result, node, dev_id);
}
......@@ -630,8 +684,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
result->CreateOpNode(node->Op()), s, p, scope_idx));
CreateOpHandleIOs(result, node, scope_idx);
}
}
......
......@@ -86,6 +86,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
void InsertSequenceDependenciesBetweenComputationOps(ir::Graph *graph) const;
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
......
......@@ -694,6 +694,13 @@ All parameter, weight, gradient are variables in Paddle.
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property("enable_sequence_execution",
[](const BuildStrategy &self) {
return self.enable_sequence_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequence_execution_ = b;
})
.def_property("fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册