Commit 845bfd58 authored by Yancey1989

cleanup code

Parent 41a64f6a
......@@ -19,6 +19,13 @@
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
// async nccl allreduce or sync issue:
// https://github.com/PaddlePaddle/Paddle/issues/15049
DEFINE_bool(
sync_nccl_allreduce, true,
"If set true, will call `cudaStreamSynchronize(nccl_stream)`"
"after allreduce, this mode can get better performance in some scenarios.");
namespace paddle {
namespace framework {
namespace details {
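
Note: the flag above follows the standard gflags pattern. A minimal, self-contained sketch (generic gflags usage, not Paddle code; the demo program is hypothetical) of how such a boolean is defined and toggled from the command line:

#include <gflags/gflags.h>

// The same kind of definition as FLAGS_sync_nccl_allreduce above.
DEFINE_bool(sync_nccl_allreduce, true,
            "Synchronize the NCCL stream after each allreduce.");

int main(int argc, char** argv) {
  // e.g.  ./demo --sync_nccl_allreduce=false
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  return FLAGS_sync_nccl_allreduce ? 0 : 1;
}
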
......@@ -48,111 +55,107 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
// FIXME(typhoonzero): If scope0 (the global scope) has NCCL_ID_VAR,
// this is a distributed or inter-process call; find a better way.
#ifdef PADDLE_WITH_CUDA
// All-reduce op_handle can run on the sub-scope, find the nccl id from
// the global scope.
if (NoDummyInputSize() == 1 &&
local_scopes_[0]->FindVar(NCCL_ID_VARNAME) == nullptr) {
#else
if (NoDummyInputSize() == 1) {
#endif
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name_)->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
"The name of input and output should be equal.");
}
// FIXME(typhoonzero): If scope0 (the global scope) has NCCL_ID_VAR,
// this is a distributed or inter-process call; find a better way.
// Wait input done
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name_)->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
"The name of input and output should be equal.");
}
if (platform::is_gpu_place(lod_tensors[0]->place())) {
if (platform::is_gpu_place(lod_tensors[0]->place())) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
auto &lod_tensor = *lod_tensors[i];
void *buffer = const_cast<void *>(lod_tensor.data<void>());
if (dtype == -1) {
dtype = platform::ToNCCLDataType(lod_tensor.type());
}
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
auto &lod_tensor = *lod_tensors[i];
void *buffer = const_cast<void *>(lod_tensor.data<void>());
if (dtype == -1) {
dtype = platform::ToNCCLDataType(lod_tensor.type());
}
if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel());
}
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
});
}
if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel());
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when managing NCCL per thread per device
all_reduce_calls[0]();
} else {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype),
ncclSum, comm, stream));
// TODO(Yancey1989): synchronizing here can get better performance
// if the NCCL group call is not used, but this needs more profiling.
if (local_scopes_.size() == 1UL) cudaStreamSynchronize(stream);
});
cudaStreamSynchronize(stream);
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
all_reduce_calls[0]();
} else {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
}
#else
PADDLE_THROW("Not compiled with CUDA");
PADDLE_THROW("Not compiled with CUDA");
#endif
} else {  // Special handling for CPU-only operators' gradients, e.g. CRF.
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(out_var_handles[0]->name_)
->GetMutable<framework::LoDTensor>();
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
auto &tensor_cpu = trg;
TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
});
}
} else {  // Special handling for CPU-only operators' gradients, e.g. CRF.
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(out_var_handles[0]->name_)
->GetMutable<framework::LoDTensor>();
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
auto &tensor_cpu = trg;
TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
});
}
}
}
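
For context, a minimal sketch of the NCCL call pattern RunImpl relies on: one ncclAllReduce per device issued inside a group call, followed by an optional per-stream synchronize corresponding to FLAGS_sync_nccl_allreduce. This is illustrative only (plain NCCL/CUDA API, no Paddle types; error checking omitted and the helper name is made up):

#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// In-place sum-allreduce of `numel` floats, one communicator/stream per GPU.
void AllReduceAcrossDevices(const std::vector<float*>& buffers,
                            const std::vector<ncclComm_t>& comms,
                            const std::vector<cudaStream_t>& streams,
                            size_t numel, bool sync_after) {
  // Group the per-device calls so NCCL can launch them without deadlocking
  // when a single thread drives several devices.
  ncclGroupStart();
  for (size_t i = 0; i < comms.size(); ++i) {
    ncclAllReduce(buffers[i], buffers[i], numel, ncclFloat, ncclSum,
                  comms[i], streams[i]);
  }
  ncclGroupEnd();
  if (sync_after) {
    // Mirrors the FLAGS_sync_nccl_allreduce branch: block the host until the
    // collectives enqueued on every stream have finished.
    for (cudaStream_t s : streams) cudaStreamSynchronize(s);
  }
}
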
......
......@@ -31,6 +31,8 @@ namespace framework {
namespace details {
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// The allreduce op order should be fixed when scheduling
// them in multiple threads or processes, to avoid hangs.
return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) ||
strategy.enable_parallel_graph_;
......@@ -88,8 +90,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
multi_devices_pass->Set<int>("num_trainers",
new int(strategy_.num_trainers_));
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
......@@ -134,6 +134,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &num_parallel_devices,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
......@@ -152,6 +153,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("local_scopes");
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes);
pass->Set<size_t>("num_parallel_devices",
new size_t(num_parallel_devices));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase("nccl_ctxs");
......
......@@ -112,6 +112,7 @@ struct BuildStrategy {
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &num_parallel_devices_,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
......
......@@ -132,7 +132,7 @@ static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
static const char kNumTrainers[] = "num_trainers";
static const char kNumParallelDevices[] = "num_parallel_devices";
void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear();
......@@ -296,7 +296,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
int num_trainers = Get<int>(kNumTrainers);
size_t num_parallel_devices = Get<size_t>(kNumParallelDevices);
for (auto &node : nodes) {
if (node->IsVar() && node->Var()) {
......@@ -382,16 +382,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps(&result, node, places_.size());
}
// Insert collective ops during backpropagation, and
// insert collective ops if the graph contains multiple places.
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (!is_forwarding &&
(places_.size() > 1 || num_trainers > 1 ||
(nccl_ctxs_ && nccl_ctxs_->contexts_.size() > 1))) {
#else
if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
#endif
if (!is_forwarding && num_parallel_devices > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
......@@ -668,12 +659,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node) const {
size_t num_parallel_devices = Get<size_t>("num_parallel_devices");
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
num_parallel_devices, local_scopes_[i], places_[i], dev_ctx);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only uses device_count as the scale
......@@ -903,4 +895,4 @@ REGISTER_PASS(multi_devices_pass,
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNumTrainers);
.RequirePassAttr(paddle::framework::details::kNumParallelDevices);
......@@ -107,6 +107,7 @@ class ParallelExecutorPrivate {
bool own_local_scope_;
bool use_cuda_;
bool use_all_reduce_;
size_t num_parallel_devices_;
// global_ref_cnts_ is initialized only when ParallelExecutor is constructed,
// and then stays unchanged.
......@@ -202,6 +203,7 @@ ParallelExecutor::ParallelExecutor(
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ =
build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
member_->num_parallel_devices_ = num_trainers * places.size();
if (!member_->use_all_reduce_) {
PADDLE_ENFORCE(places.size() > 1,
......@@ -212,12 +214,12 @@ ParallelExecutor::ParallelExecutor(
if (build_strategy.enable_parallel_graph_) {
PADDLE_ENFORCE(
member_->use_all_reduce_,
"build_strategy.reduce should be `AllReduce` if you want to use"
"ParallelGraph executor.");
"build_strategy.reduce should be `AllReduce` if you want to enable"
"ParallelGraph.");
PADDLE_ENFORCE(
member_->use_cuda_,
"execution_strategy.use_cuda should be True if you want to use"
"ParallelGraph executor.");
"execution_strategy.use_cuda should be True if you want to enable "
"ParallelGraph.");
}
// Step 1. Bcast the bcast_vars to devs.
......@@ -241,27 +243,43 @@ ParallelExecutor::ParallelExecutor(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
ncclUniqueId *nccl_id = nullptr;
// In nccl collective mode, the nccl id is broadcast by the gen_nccl_id operator.
if (nccl_id_var != nullptr) {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
// Parallel graph mode should initialize nccl by ncclCommInitRank since
// it calls nccl operators per device per thread.
if (nccl_id_var == nullptr) {
if (nccl_id == nullptr) {
nccl_id = new ncclUniqueId();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
*member_->global_scope_->Var(NCCL_ID_VARNAME)
->GetMutable<ncclUniqueId>() = *nccl_id;
} else {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
} else if (nccl_id_var != nullptr) { // the other executor type.
// Distributed training with nccl mode initializes the nccl id in the
// startup_program.
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
} else {
// Initialize NCCL by ncclCommInitAll; no need to initialize the nccl_id.
}
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, num_trainers, trainer_id));
/**
if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
// parallel graph mode should initialize nccl by ncclCommInitRank since
// it calls nccl operators per device per thread.
if (nccl_id_var == nullptr) {
nccl_id = new ncclUniqueId();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
*member_->global_scope_->Var(NCCL_ID_VARNAME)
->GetMutable<ncclUniqueId>() = *nccl_id;
} else {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
} else if (nccl_id_var != nullptr) { // the other executor type.
// the distributed training with nccl mode would initialize the nccl id in
// startup_program.
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
} else {
// Initialize NCCL by ncclCommInitAll; no need to initialize the nccl_id.
}
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, num_trainers, trainer_id));
**/
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
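
For reference, the two NCCL bootstrap paths the comments above distinguish, sketched with the plain NCCL API (illustrative only; the function names are made up and error handling is omitted):

#include <nccl.h>

// Single process driving all GPUs: one call creates every communicator and
// no ncclUniqueId is needed (the ncclCommInitAll branch).
void InitSingleProcess(ncclComm_t* comms, int ndev, const int* dev_ids) {
  ncclCommInitAll(comms, ndev, dev_ids);
}

// One communicator per rank (per device/thread/process): rank 0 generates an
// ncclUniqueId with ncclGetUniqueId, every rank receives the same id, and
// each rank then joins the clique (the ncclCommInitRank branch used by
// parallel-graph mode and by distributed training via gen_nccl_id).
void InitPerRank(ncclComm_t* comm, int nranks, ncclUniqueId id, int rank) {
  ncclCommInitRank(comm, nranks, id, rank);
}
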
......@@ -274,25 +292,27 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
std::vector<std::unique_ptr<ir::Graph>> graphs;
member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.enable_parallel_graph_) {
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, {member_->places_[i]},
loss_var_name, {member_->local_scopes_[i]},
member_->use_cuda_, member_->nccl_ctxs_.get());
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, member_->num_parallel_devices_,
member_->use_cuda_, member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
} else {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->use_cuda_, member_->nccl_ctxs_.get());
member_->num_parallel_devices_, member_->use_cuda_,
member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
#else
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name,
member_->local_scopes_, member_->use_cuda_);
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->num_parallel_devices_, member_->use_cuda_);
graphs.push_back(std::move(graph));
#endif
auto max_memory_size = GetEagerDeletionThreshold();
......
......@@ -60,71 +60,69 @@ class TestParallelExecutorBase(unittest.TestCase):
startup = fluid.Program()
startup.random_seed = 1 # Fix random seed
main.random_seed = 1
self.scope = fluid.Scope()
with fluid.scope_guard(self.scope):
with fluid.program_guard(main, startup):
if seed is not None:
startup.random_seed = seed
main.random_seed = seed
loss = method(use_feed=feed_dict is not None)
optimizer().minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
startup_exe = fluid.Executor(place)
startup_exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy.enable_parallel_graph = use_parallel_graph
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda,
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
else:
exe = fluid.Executor(place=place)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time()
first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
last_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
if batch_size is not None:
print("%.4f Instance per second" % (
(batch_size * iter + 2) / (end - begin)))
avg_last_loss_val = np.array(last_loss).mean()
avg_first_loss_val = np.array(first_loss).mean()
if math.isnan(float(avg_last_loss_val)) or math.isnan(
float(avg_first_loss_val)):
sys.exit("got NaN loss, training failed.")
print(first_loss, last_loss)
# self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
with fluid.program_guard(main, startup):
if seed is not None:
startup.random_seed = seed
main.random_seed = seed
loss = method(use_feed=feed_dict is not None)
optimizer().minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
startup_exe = fluid.Executor(place)
startup_exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.enable_parallel_graph = use_parallel_graph
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda,
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
else:
exe = fluid.Executor(place=place)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time()
first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
last_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
if batch_size is not None:
print("%.4f Instance per second" % (
(batch_size * iter + 2) / (end - begin)))
avg_last_loss_val = np.array(last_loss).mean()
avg_first_loss_val = np.array(first_loss).mean()
if math.isnan(float(avg_last_loss_val)) or math.isnan(
float(avg_first_loss_val)):
sys.exit("got NaN loss, training failed.")
print(first_loss, last_loss)
# self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
......@@ -175,44 +175,65 @@ class TestCRFModel(unittest.TestCase):
print(pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name])[0])
def test_update_sparse_parameter_all_reduce(self):
def _new_build_strategy(self, use_reduce=False, use_parallel_graph=False):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
if use_reduce:
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else:
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.enable_parallel_graph = use_parallel_graph
return build_strategy
def test_update_sparse_parameter_all_reduce(self):
if core.is_compiled_with_cuda():
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=True)
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=True)
is_sparse=True,
build_strategy=self._new_build_strategy(),
use_cuda=True)
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
is_sparse=True,
build_strategy=self._new_build_strategy(),
use_cuda=False)
def test_update_dense_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
if core.is_compiled_with_cuda():
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=True)
is_sparse=False,
build_strategy=self._new_build_strategy(),
use_cuda=True)
self.check_network_convergence(
is_sparse=False,
build_strategy=self._new_build_strategy(
use_parallel_graph=True),
use_cuda=True)
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
if core.is_compiled_with_cuda():
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=True)
is_sparse=True,
build_strategy=self._new_build_strategy(use_reduce=True),
use_cuda=True)
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
is_sparse=True,
build_strategy=self._new_build_strategy(use_reduce=True),
use_cuda=False)
def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
if core.is_compiled_with_cuda():
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=True)
is_sparse=False,
build_strategy=self._new_build_strategy(use_reduce=True),
use_cuda=True)
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
is_sparse=False,
build_strategy=self._new_build_strategy(use_reduce=True),
use_cuda=False)
if __name__ == '__main__':
......
......@@ -312,7 +312,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=use_reduce,
optimizer=optimizer(lr_scale=lr_scale),
optimizer=optimizer(),
use_parallel_graph=use_parallel_graph)
self.assertAlmostEquals(
......