“4a3387963583eccad4fa254aa414bc27486fc025”上不存在“paddle/phi/kernels/cpu/selu_grad_kernel.cc”
提交 f3463ecb 编写于 作者: Y Yancey1989

refine pg execution

上级 46a6cac9
......@@ -35,8 +35,8 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang.
return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) ||
strategy.enable_parallel_graph_;
strategy.num_trainers_ > 1) &&
!strategy.enable_parallel_graph_;
}
class ParallelExecutorPassBuilder : public ir::PassBuilder {
......@@ -106,7 +106,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
auto multi_devices_pass = AppendPass("multi_devices_check_pass");
multi_devices_pass->Set<bool>(kEnablePG,
new bool(strategy.enable_parallel_graph_));
if (SeqOnlyAllReduceOps(strategy)) {
AppendPass("all_reduce_deps_pass");
......@@ -180,6 +182,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
&local_scopes);
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(kEnablePG);
pass->Set<bool>(kEnablePG, new bool(true));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......
......@@ -36,11 +36,6 @@ namespace framework {
namespace details {
namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
return boost::get<int>(
......@@ -206,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, p_name, g_name);
InsertCollectiveOp(&result, node, p_name, g_name);
}
} catch (boost::bad_get e) {
}
......@@ -226,7 +221,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
// result.Erase(kGraphOps);
return graph;
}
......@@ -391,20 +386,34 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
}
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const {
ir::Graph *result, ir::Node *node, const std::string &og) const {
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
std::vector<Scope *> &scopes,
std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_));
scopes, places, nccl_ctxs_));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_));
scopes, places));
#endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
return result->Get<GraphOps>(kGraphOps).back();
};
if (!strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(local_scopes_, places_);
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto p = places_[i];
std::vector<Scope *> ss{local_scopes_[i]};
std::vector<platform::Place> ps{p};
if (strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(ss, ps);
SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
......@@ -501,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
}
void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
ir::Graph *result, ir::Node *node, const std::string &p_name,
const std::string &g_name) const {
if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
CreateAllReduceOp(result, node, g_name);
}
}
......@@ -580,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
}
void ReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
ir::Graph *result, ir::Node *node, const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id);
......@@ -900,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id;
}
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = 0;
......@@ -915,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
CreateAllReduceOp(result, node, g_name);
}
break;
default:
......@@ -966,7 +975,8 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks)
.RequirePassAttr(paddle::framework::details::kNRanks) \
.RequirePassAttr(paddle::framework::details::kEnablePG)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder);
......
......@@ -36,6 +36,7 @@ constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
constexpr char kEnablePG[] = "enable_pg";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
......@@ -46,7 +47,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const = 0;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
......@@ -75,7 +77,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, ir::Node *node,
const std::string &og) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
......@@ -106,7 +109,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
......@@ -135,7 +139,8 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
......@@ -164,7 +169,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const;
virtual void ResetState() const;
......
......@@ -36,13 +36,20 @@ namespace details {
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase*> GraphDepVars;
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -70,6 +70,9 @@ class OpHandleBase {
auto it = dev_ctxes_.find(place);
return it != dev_ctxes_.end() ? it->second : nullptr;
}
const std::map<platform::Place, platform::DeviceContext *> &DeviceContext() {
return dev_ctxes_;
}
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
dev_ctxes_[place] = ctx_;
......
......@@ -13,11 +13,74 @@
// limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> graph) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places.size());
for (size_t i = 0; i < places.size(); ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps);
}
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first;
#ifdef PADDLE_WITH_CUDA
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
dev_ops.emplace_back(op);
graphs[dev_id]->AddNode(graph->ReleaseNode(op->Node()).release());
for (auto &var : op->Inputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) {
dev_dummys.insert(var);
if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release());
}
}
for (auto &var : op->Outputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) {
dev_dummys.insert(var);
if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release());
}
}
#else
PADDLE_THROW("Parallel Graph Execution only support CUDAPlace.");
#endif
}
for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
dev_vars.emplace(name_pair.first, name_pair.second);
for (auto &version_pair : name_pair.second) {
if (graph->Nodes().count(version_pair->Node())) {
graphs[dev_id]->AddNode(
graph->ReleaseNode(version_pair->Node()).release());
}
}
}
}
return graphs;
}
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
......@@ -37,7 +100,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i))));
}
}
......
......@@ -14,16 +14,24 @@
#pragma once
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> graph);
class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
......@@ -31,11 +39,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> &&graphs);
~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
// std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph();
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
......
......@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
}
}
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var);
}
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op);
......@@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp(
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--;
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted";
VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
......
......@@ -167,6 +167,14 @@ class Graph {
return ret;
}
std::unique_ptr<ir::Node> ReleaseNode(ir::Node *node) {
std::unique_ptr<ir::Node> ret;
ret.reset(nodes_.at(node).release());
nodes_.erase(node);
node_set_.erase(node);
return ret;
}
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
......@@ -183,13 +191,6 @@ class Graph {
return nullptr;
}
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
......@@ -198,6 +199,17 @@ class Graph {
return node;
}
bool ContainNode(ir::Node *node) {
return node_set_.find(node) != node_set_.end();
}
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc program_;
std::map<std::string, boost::any> attrs_;
......
......@@ -59,7 +59,9 @@ template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
if (n->IsWrappedBy<T>()) {
ret.push_back(&n->Wrapper<T>());
}
}
return ret;
}
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -201,7 +202,6 @@ ParallelExecutor::ParallelExecutor(
member_->use_all_reduce_ =
build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
member_->nranks_ = build_strategy.num_trainers_ * places.size();
if (!member_->use_all_reduce_) {
PADDLE_ENFORCE(places.size() > 1,
"If you set build_strategy.reduce with 'Reduce',"
......@@ -229,9 +229,10 @@ ParallelExecutor::ParallelExecutor(
// choice the execution strategy.
build_strategy.enable_parallel_graph_ =
EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
VLOG(1) << "Enable ParallelGraph Execution: "
<< build_strategy.enable_parallel_graph_;
if (build_strategy.enable_parallel_graph_)
VLOG(0) << "The Executor would execute the graph by ParallelGraph "
"Execution which can get better performance,"
<< "you can force it off by env FLAGS_enable_parallel_graph=0";
if (member_->use_cuda_) {
// Bcast Parameters to all GPUs
......@@ -265,40 +266,25 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
std::vector<std::unique_ptr<ir::Graph>> graphs;
std::unique_ptr<ir::Graph> graph;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.enable_parallel_graph_) {
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
} else {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_, member_->nccl_ctxs_.get());
#else
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->nranks_, member_->use_cuda_);
graphs.push_back(std::move(graph));
graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_);
#endif
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (size_t i = 0; i < graphs.size(); ++i) {
graphs[i] = member_->PrepareGCAndRefCnts(
std::move(graphs[i]), static_cast<size_t>(max_memory_size));
}
graph = member_->PrepareGCAndRefCnts(std::move(graph),
static_cast<size_t>(max_memory_size));
}
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
std::vector<details::VariableInfo> var_infos;
for (auto &graph : graphs) {
for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back();
......@@ -307,16 +293,15 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
}
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
size_t graph_num = ir::GraphNum(*graphs[0]);
size_t graph_num = ir::GraphNum(*graph);
if (graph_num > 1) {
LOG(WARNING)
<< "The number of graph should be only one, "
"but the current graph has "
<< ir::GraphNum(*graphs[0])
<< ir::GraphNum(*graph)
<< " sub_graphs. If you want to see the nodes of the "
"sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
"to specify the output dir. NOTES: if you not do training, "
......@@ -325,18 +310,30 @@ ParallelExecutor::ParallelExecutor(
}
if (build_strategy.enable_parallel_graph_) {
auto parallel_graph =
details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
for (size_t i = 0; i < parallel_graph.size(); ++i) {
parallel_graph[i] =
seq_allreduce_pass->Apply(std::move(parallel_graph[i]));
}
member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs)));
std::move(parallel_graph)));
} else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0])));
std::move(graph)));
} else {
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0])));
std::move(graph)));
}
}
......@@ -487,8 +484,8 @@ bool ParallelExecutor::EnableParallelGraphExecution(
}
}
if (!member_->use_all_reduce_ || !member_->use_cuda_)
enable_parallel_graph = false;
// if (!member_->use_all_reduce_ || !member_->use_cuda_)
if (!member_->use_all_reduce_) enable_parallel_graph = false;
if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
......
......@@ -72,6 +72,7 @@ class TestParallelExecutorBase(unittest.TestCase):
exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
exec_strategy.num_threads = 1
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
......@@ -99,7 +100,7 @@ class TestParallelExecutorBase(unittest.TestCase):
first_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
for _ in range(iter):
run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import os
os.environ['FLAGS_enable_parallel_graph'] = str(1)
import paddle.fluid.core as core
import os
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase
def simple_fc_net(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _init_data(self):
np.random.seed(5)
img = np.random.random(size=[32, 784]).astype(np.float32)
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
# simple_fc
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data()
self.check_network_convergence(
simple_fc_net,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=use_reduce)
def test_simple_fc(self):
# use_cuda
self.check_simple_fc_convergence(True)
def check_simple_fc_parallel_accuracy(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data()
single_first_loss, single_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_parallel_executor=True)
self.assertAlmostEquals(
np.mean(parallel_first_loss),
single_first_loss,
delta=1e-6, )
self.assertAlmostEquals(
np.mean(parallel_last_loss), single_last_loss, delta=1e-6)
def test_simple_fc_parallel_accuracy(self):
self.check_simple_fc_parallel_accuracy(True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册