Commit 642fd68c authored by Yancey1989

update by comment test=develop

Parent commit 7cd6de37
......@@ -21,8 +21,6 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
......
......@@ -29,8 +29,6 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// NOTE(dzh): An ordered set for node reuse in memory optimize.
......
......@@ -221,7 +221,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
// result.Erase(kGraphOps);
return graph;
}
......
......@@ -19,12 +19,12 @@ namespace paddle {
namespace framework {
namespace details {
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> graph) {
std::vector<std::unique_ptr<ir::Graph>>
ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
std::unique_ptr<ir::Graph> &&graph) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places.size());
for (size_t i = 0; i < places.size(); ++i) {
graphs.reserve(places_.size());
for (size_t i = 0; i < places_.size(); ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
......@@ -60,7 +60,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
}
}
for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
......@@ -80,14 +80,26 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> &&graphs)
const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
main_prog_(main_prog),
// TODO(Yancey1989): copying graphs is not safe since it deletes the attrs.
graphs_(SeparateMultiDevicesGraph(std::move(graph))) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
}
// Set the correct thread pool size for each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL
......
......@@ -28,16 +28,13 @@ namespace paddle {
namespace framework {
namespace details {
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> graph);
class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> &&graphs);
const framework::ProgramDesc &main_prog,
std::unique_ptr<ir::Graph> &&graph);
~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
......@@ -45,10 +42,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
std::unique_ptr<ir::Graph> &&graph);
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
framework::ProgramDesc main_prog_;
std::vector<std::unique_ptr<ir::Graph>> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
......
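For reference, a minimal sketch of how the updated interface is meant to be called, assuming the surrounding PaddlePaddle types and that the strategy, scopes, places, program, and graph have already been prepared (the actual call site is the parallel_executor.cc hunk further below):

```cpp
// Hypothetical call site matching the new constructor signature above.
// exec_strategy, local_scopes, places, main_program, and graph are assumed to
// be set up by ParallelExecutor as before; the executor now splits the single
// graph per device internally and applies all_reduce_deps_pass itself.
std::unique_ptr<details::ParallelSSAGraphExecutor> executor(
    new details::ParallelSSAGraphExecutor(exec_strategy, local_scopes, places,
                                          main_program, std::move(graph)));
// Run() executes each per-device graph on its own ThreadedSSAGraphExecutor.
FeedFetchList fetch_results = executor->Run(fetch_tensor_names);
```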
......@@ -26,6 +26,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
} // namespace details
namespace ir {
/*
......
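With kAllOpDescs now declared once in pass.h (hunk above) instead of being redefined in all_reduce_deps_pass.cc and the memory-optimize helper (first two hunks), a pass body reads the attribute through the ir::Pass attribute accessor. A rough sketch, assuming the Get<T>() accessor for attributes installed via Set() (this consuming code is not part of this diff):

```cpp
// Inside a pass such as AllReduceDepsPass::ApplyImpl (sketch): consume the op
// list stored under the shared details::kAllOpDescs key set by the caller.
const std::vector<OpDesc *> &all_ops =
    Get<const std::vector<OpDesc *>>(details::kAllOpDescs);
for (OpDesc *op : all_ops) {
  VLOG(10) << "op in program order: " << op->Type();
}
```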
......@@ -305,21 +305,11 @@ ParallelExecutor::ParallelExecutor(
if (build_strategy.enable_parallel_graph_) {
#ifdef PADDLE_WITH_CUDA
auto parallel_graph =
details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
for (size_t i = 0; i < parallel_graph.size(); ++i) {
parallel_graph[i] =
seq_allreduce_pass->Apply(std::move(parallel_graph[i]));
}
// TODO(Yancey1989): Remove passing in the main_program when
// allreduce_seq_pass doesn't need it as the attr.
member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(parallel_graph)));
exec_strategy, member_->local_scopes_, member_->places_, main_program,
std::move(graph)));
#else
PADDLE_THROW(
"Paddle should be compiled with CUDA for ParallelGraph Execution.");
......