未验证 提交 1b031987 编写于 作者: H Haohongxiang 提交者: GitHub

[Dygraph] Support sparse tensor in refactored reducer (#40836)

* [Dygraph] Support sparse tensor in refactored reducer

* add unit tests (UTs)

* refactor

* update

* fix bugs
上级 625dd722
...@@ -360,6 +360,7 @@ void EagerReducer::InitializeGroups( ...@@ -360,6 +360,7 @@ void EagerReducer::InitializeGroups(
is_sparse_gradient_[tensor_indices_.front()]) { is_sparse_gradient_[tensor_indices_.front()]) {
// process the sparse gradient. one sparse, one group // process the sparse gradient. one sparse, one group
group.dtype_ = first_var.dtype(); group.dtype_ = first_var.dtype();
group.is_sparse_ = true;
} else { } else {
// process the dense gradient. // process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group); InitializeDenseGroups(tensor_indices_, &group);
...@@ -391,6 +392,12 @@ void EagerReducer::InitializeDenseGroups( ...@@ -391,6 +392,12 @@ void EagerReducer::InitializeDenseGroups(
auto &tensor = tensors_[tensor_index]; auto &tensor = tensors_[tensor_index];
auto &tensor_name = tensor.name(); auto &tensor_name = tensor.name();
PADDLE_ENFORCE_EQ(is_sparse_gradient_[tensor_index], false,
platform::errors::PreconditionNotMet(
"Tensor %s's GRAD must be Tensor, but received "
"GRAD is SelectedRows",
tensor_name));
PADDLE_ENFORCE_EQ(tensor.is_initialized(), true, PADDLE_ENFORCE_EQ(tensor.is_initialized(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Tensor %s is not initialized.", tensor_name)); "Tensor %s is not initialized.", tensor_name));
...@@ -480,6 +487,7 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) { ...@@ -480,6 +487,7 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
next_group_ = 0; next_group_ = 0;
std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) { std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
group.pending_ = group.tensor_indices_.size(); group.pending_ = group.tensor_indices_.size();
group.sparse_contents_ = Tensor();
}); });
// reinitialize vars_marked_ready_ for next iteration // reinitialize vars_marked_ready_ for next iteration
...@@ -544,9 +552,6 @@ void EagerReducer::AddDistHook(size_t var_index) { ...@@ -544,9 +552,6 @@ void EagerReducer::AddDistHook(size_t var_index) {
return; return;
} }
auto &tensor = tensors_[var_index];
const auto &grad_node = GetGradNodeFromTensor(&tensor);
VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name() VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
<< "@Grad] arrived and triggered disthook"; << "@Grad] arrived and triggered disthook";
...@@ -608,15 +613,18 @@ void EagerReducer::MarkVarReady(const size_t var_index, ...@@ -608,15 +613,18 @@ void EagerReducer::MarkVarReady(const size_t var_index,
auto &group_tensor = group.dense_tensors_[inside_group_index]; auto &group_tensor = group.dense_tensors_[inside_group_index];
const auto length = group.length_[inside_group_index]; const auto length = group.length_[inside_group_index];
if (!group.is_sparse_) {
if (is_used_var) { if (is_used_var) {
auto *autograd_meta = tensors_[var_index].get_autograd_meta(); auto *autograd_meta = tensors_[var_index].get_autograd_meta();
auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad(); auto &grad_tensor =
static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
group_tensor group_tensor
.ShareDataWith( .ShareDataWith(*(
*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl()))) std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
.Resize({grad_tensor.numel()}); .Resize({grad_tensor.numel()});
} else { } else {
// TODO(shenliang03): maybe save the memory by avoiding tensor construction // TODO(shenliang03): maybe save the memory by avoiding tensor
// construction
if (!group_tensor.initialized()) { if (!group_tensor.initialized()) {
group_tensor.Resize({static_cast<int64_t>(length)}); group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(inner_place_, group.dtype_); group_tensor.mutable_data(inner_place_, group.dtype_);
...@@ -625,17 +633,50 @@ void EagerReducer::MarkVarReady(const size_t var_index, ...@@ -625,17 +633,50 @@ void EagerReducer::MarkVarReady(const size_t var_index,
VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad"; VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]); auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
group_tensor group_tensor
.ShareDataWith(*( .ShareDataWith(*(std::dynamic_pointer_cast<phi::DenseTensor>(
std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor->impl()))) grad_tensor->impl())))
.Resize({length}); .Resize({length});
} else { } else {
VLOG(3) << "Tensor[" << tensors_[var_index].name() VLOG(3) << "Tensor[" << tensors_[var_index].name()
<< "] doesn't have grad"; << "] doesn't have grad";
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_); auto *dev_ctx =
platform::DeviceContextPool::Instance().Get(inner_place_);
group_tensor.Resize({static_cast<int64_t>(length)}); group_tensor.Resize({static_cast<int64_t>(length)});
phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0); phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
} }
} }
} else {
auto *autograd_meta = tensors_[var_index].get_autograd_meta();
auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
// process sparse group
PADDLE_ENFORCE_EQ(
HasGrad(var_index), true,
platform::errors::PreconditionNotMet(
"The sparse parameter[%d][%s] should have gradient. "
"Currently, DataParallel does not support sparse "
"parameters without generating gradients during training. "
"For example, if is_sparese=True is used in Embedding, "
"the current step of this parameter cannot generate gradient "
"because of stop_gradient/detatch, where error will occur.",
var_index, tensors_[var_index].name()));
// need to check tensor type
PADDLE_ENFORCE_EQ(
grad_tensor.is_selected_rows(), true,
platform::errors::PreconditionNotMet(
"The sparse parameter[%d][%s] must have a selectedrows gradient. "
"Before forward pass, the parameter type is inferred to be "
"SelectedRows, but after backward pass, its actual type becomes "
"LodTensor. It is currently not supported by DataParallel. "
"For example, if sparse embedding is used, and the weight of "
"embedding is shared with subsequent dense parameters, then "
"the parameter gradient of the embedding will be converted "
"to dense parameters.",
var_index, tensors_[var_index].name()));
group.sparse_contents_.set_impl(grad_tensor.impl());
}
if (--group.pending_ == 0) { if (--group.pending_ == 0) {
// can start allreduce // can start allreduce
...@@ -666,8 +707,12 @@ void EagerReducer::MarkGroupReady(size_t group_index) { ...@@ -666,8 +707,12 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
++next_group_) { ++next_group_) {
UNUSED auto &group = groups_[next_group_]; UNUSED auto &group = groups_[next_group_];
if (group.is_sparse_) {
AllReduceSparse(&group, next_group_);
} else {
FusedAllReduceSchedule(&group, next_group_); FusedAllReduceSchedule(&group, next_group_);
} }
}
} }
bool EagerReducer::HasGrad(size_t var_index) { bool EagerReducer::HasGrad(size_t var_index) {
...@@ -725,6 +770,11 @@ void EagerReducer::ProcessUnusedDenseVars() { ...@@ -725,6 +770,11 @@ void EagerReducer::ProcessUnusedDenseVars() {
const auto inside_group_index = var_locator.inside_group_index; const auto inside_group_index = var_locator.inside_group_index;
auto &src_tensor = group.dense_tensors_[inside_group_index]; auto &src_tensor = group.dense_tensors_[inside_group_index];
// sparse no need to check and no support find_unused_parameters
if (group.is_sparse_) {
continue;
}
Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor)); Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));
auto dest_var_base = tensors_[var_index]; auto dest_var_base = tensors_[var_index];
...@@ -739,12 +789,16 @@ void EagerReducer::FinalizeBackward() { ...@@ -739,12 +789,16 @@ void EagerReducer::FinalizeBackward() {
groups_need_finalize_ = false; groups_need_finalize_ = false;
grad_need_hooks_ = false; grad_need_hooks_ = false;
for (auto &group : groups_) { for (auto &group : groups_) {
if (!group.is_sparse_) {
group.task->Synchronize(); group.task->Synchronize();
} }
}
for (auto &group : groups_) { for (auto &group : groups_) {
if (!group.is_sparse_) {
group.SplitTensors(inner_place_); group.SplitTensors(inner_place_);
} }
}
if (find_unused_vars_each_step_) { if (find_unused_vars_each_step_) {
ProcessUnusedDenseVars(); ProcessUnusedDenseVars();
...@@ -778,6 +832,127 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, ...@@ -778,6 +832,127 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
// split in FinalizeBackward() // split in FinalizeBackward()
} }
// Averages the SelectedRows gradient stored in `group->sparse_contents_`
// across all ranks of `process_group_` and writes the merged result back into
// that same SelectedRows object.
//
// Procedure:
//   1. Scale the local gradient by 1 / nranks_ in place.
//   2. All-reduce (SUM) a vector holding each rank's row count, so every rank
//      learns how many rows each other rank contributes.
//   3. When all ranks contribute the same number of rows, all-gather the row
//      indices and the value tensor and install them into the local gradient.
//
// Throws PreconditionNotMet when the group's gradient is not a SelectedRows,
// PermissionDenied on a GPU place in a build without NCCL/RCCL, and
// Unimplemented for any other unsupported place or when the ranks contribute
// different numbers of rows.
//
// `curr_group_index` is only used for logging and error messages.
void EagerReducer::AllReduceSparse(EagerGroup *group,
                                   const int curr_group_index) {
  auto src = std::dynamic_pointer_cast<phi::SelectedRows>(
      group->sparse_contents_.impl());
  // Fail with a readable error instead of dereferencing a null pointer when
  // the stored gradient is missing or is not a SelectedRows.
  PADDLE_ENFORCE_EQ(
      src != nullptr, true,
      platform::errors::PreconditionNotMet(
          "The gradient of sparse group [%d] must be SelectedRows.",
          curr_group_index));

  // div nranks
  // NOTE(review): the in-place scale is expected to be visible through
  // `group->sparse_contents_`, i.e. Tensor copies presumably share the
  // underlying impl -- confirm against the Tensor copy semantics.
  Tensor sparse_tensor(group->sparse_contents_);
  paddle::experimental::scale_(sparse_tensor, 1.0 / nranks_, 0.0, false);

  VLOG(3) << "sparse_group [" << curr_group_index << "] start allreduce.";

  // The generic device context is sufficient below; this branch only rejects
  // places on which sparse allreduce cannot run.
  auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
  if (platform::is_gpu_place(inner_place_)) {
#if !defined(PADDLE_WITH_NCCL) && !defined(PADDLE_WITH_RCCL)
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't allreduce sparse grad tensors since it's not compiled "
        "with NCCL. Please recompile or reinstall Paddle with NCCL support."));
#endif
  } else if (!platform::is_cpu_place(inner_place_)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Sparse allreduce is not supported on place (%s)", inner_place_));
  }

  const auto &src_rows = src->rows();

  const auto rank = process_group_->GetRank();
  const auto world_size = process_group_->GetSize();

  // Each rank fills only its own slot; after the SUM all-reduce every slot
  // holds the row count of the corresponding rank.
  framework::Vector<int64_t> rows_num_vector(world_size);
  rows_num_vector[rank] = static_cast<int64_t>(src_rows.size());

  Tensor rows_num_tensor = paddle::experimental::empty(
      IntArray({static_cast<int64_t>(world_size)}), DataType::INT64,
      inner_place_);
  auto *rows_num_dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(rows_num_tensor.impl()).get();
  framework::TensorFromVector<int64_t>(rows_num_vector, *dev_ctx,
                                       rows_num_dense_tensor);

  distributed::AllreduceOptions opts;
  opts.reduce_op = ReduceOp::SUM;
  std::vector<Tensor> reduce_tensors = {rows_num_tensor};
  process_group_->AllReduce(reduce_tensors, opts)->Synchronize();

  framework::TensorToVector<int64_t>(*rows_num_dense_tensor, *dev_ctx,
                                     &rows_num_vector);
  // Make sure the copy back to host has finished before reading the counts.
  dev_ctx->Wait();

  const auto *cpu_rows_num_ptr = rows_num_vector.data();
  const auto rows_num =
      std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + world_size,
                      static_cast<int64_t>(0));

  VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
          << ", total rows number: " << rows_num
          << ", height: " << src->height();

  const bool same_rows_on_all_ranks =
      std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + world_size,
                  [&](int64_t row) { return row == cpu_rows_num_ptr[0]; });

  if (same_rows_on_all_ranks) {
    // During sparse communication, the number of each card is same.
    // allgather is used to speed up the allreduce by replacing broadcast.
    VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";

    // --- gather the row indices of every rank ---
    Tensor dst_rows_tensor =
        paddle::experimental::empty(IntArray({static_cast<int64_t>(rows_num)}),
                                    DataType::INT64, inner_place_);
    Tensor src_rows_tensor = paddle::experimental::empty(
        IntArray({static_cast<int64_t>(src_rows.size())}), DataType::INT64,
        inner_place_);
    auto *src_rows_dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(src_rows_tensor.impl())
            .get();
    framework::TensorFromVector<int64_t>(src_rows, *dev_ctx,
                                         src_rows_dense_tensor);

    std::vector<Tensor> src_rows_tensors = {src_rows_tensor};
    std::vector<Tensor> dst_rows_tensors = {dst_rows_tensor};
    process_group_->AllGather(src_rows_tensors, dst_rows_tensors)
        ->Synchronize();

    framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
    auto *dst_rows_dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl())
            .get();
    framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx,
                                       &dst_rows_vector);
    dev_ctx->Wait();

    // --- gather the gradient values of every rank ---
    Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
    std::vector<int64_t> dst_shape = src_value_tensor.shape();
    PADDLE_ENFORCE_EQ(
        dst_shape.empty(), false,
        platform::errors::PreconditionNotMet(
            "The value tensor of sparse group [%d] must have at least one "
            "dimension.",
            curr_group_index));
    // The all-gather concatenates the per-rank buffers back to back, so the
    // gathered rows extend the outermost dimension. For the usual 2-D value
    // tensor this is the same index the previous `size() - 2` form selected.
    dst_shape[0] = rows_num;
    auto dst_dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
        paddle::experimental::full(IntArray(dst_shape), 0,
                                   src_value_tensor.dtype(), inner_place_)
            .impl());

    auto dst =
        std::make_shared<phi::SelectedRows>(dst_rows_vector, src->height());
    *(dst->mutable_value()) = *dst_dense_tensor;
    Tensor dst_value_tensor(std::make_shared<phi::DenseTensor>(dst->value()));

    std::vector<Tensor> src_value_tensors = {src_value_tensor};
    std::vector<Tensor> dst_value_tensors = {dst_value_tensor};
    process_group_->AllGather(src_value_tensors, dst_value_tensors)
        ->Synchronize();

    // Install the gathered rows and values into the local gradient.
    src->set_rows(dst_rows_vector);
    *(src->mutable_value()) =
        *(std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl()));
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Sparse allreduce currently requires every rank to contribute the "
        "same number of rows, but the row counts differ across ranks."));
  }
}
std::ostream &operator<<(std::ostream &out, const EagerGroup &group) { std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
const auto &tensors_ = group.tensor_indices_; const auto &tensors_ = group.tensor_indices_;
out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size() out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size()
......
...@@ -47,6 +47,8 @@ std::vector<std::vector<size_t>> Eager_AssignGroupBySize( ...@@ -47,6 +47,8 @@ std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
class EagerGroup { class EagerGroup {
public: public:
Tensor dense_contents_; Tensor dense_contents_;
Tensor sparse_contents_;
bool is_sparse_ = false;
// for concat kernel // for concat kernel
std::vector<phi::DenseTensor> dense_tensors_; std::vector<phi::DenseTensor> dense_tensors_;
...@@ -104,6 +106,7 @@ class EagerReducer { ...@@ -104,6 +106,7 @@ class EagerReducer {
void MarkVarReady(const size_t var_index, const bool is_used_var); void MarkVarReady(const size_t var_index, const bool is_used_var);
void MarkGroupReady(const size_t group_index); void MarkGroupReady(const size_t group_index);
void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index); void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index);
void AllReduceSparse(EagerGroup *group, const int curr_group_index);
void FinalizeBackward(); void FinalizeBackward();
void TraverseBackwardGraph(const std::vector<Tensor> &outputs); void TraverseBackwardGraph(const std::vector<Tensor> &outputs);
void ProcessUnusedDenseVars(); void ProcessUnusedDenseVars();
......
...@@ -1128,7 +1128,7 @@ set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) ...@@ -1128,7 +1128,7 @@ set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150)
...@@ -1153,8 +1153,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) ...@@ -1153,8 +1153,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300) set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300)
if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 150)
endif() endif()
endif() endif()
......
...@@ -42,7 +42,6 @@ class SimpleNet(fluid.Layer): ...@@ -42,7 +42,6 @@ class SimpleNet(fluid.Layer):
dtype=dtype, dtype=dtype,
is_sparse=is_sparse, is_sparse=is_sparse,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='embedding_param',
initializer=fluid.initializer.UniformInitializer( initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))) low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter( self.softmax_weight = self.create_parameter(
...@@ -103,8 +102,8 @@ class TestSparseEmbedding(TestParallelDyGraphRunnerBase): ...@@ -103,8 +102,8 @@ class TestSparseEmbedding(TestParallelDyGraphRunnerBase):
train_reader = paddle.batch( train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True) fake_sample_reader(), batch_size=batch_size, drop_last=True)
optimizer = fluid.optimizer.SGD(learning_rate=0.001, optimizer = paddle.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters()) parameters=model.parameters())
return model, train_reader, optimizer return model, train_reader, optimizer
......
...@@ -40,7 +40,6 @@ class SimpleNet(Layer): ...@@ -40,7 +40,6 @@ class SimpleNet(Layer):
self.hidden_size, self.hidden_size,
sparse=True, sparse=True,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(
name='embedding_param',
initializer=paddle.nn.initializer.Uniform( initializer=paddle.nn.initializer.Uniform(
low=-init_scale, high=init_scale))) low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter( self.softmax_weight = self.create_parameter(
......
...@@ -39,7 +39,6 @@ class SimpleNet(Layer): ...@@ -39,7 +39,6 @@ class SimpleNet(Layer):
self.hidden_size, self.hidden_size,
sparse=is_sparse, sparse=is_sparse,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(
name='embedding_param',
initializer=paddle.nn.initializer.Uniform( initializer=paddle.nn.initializer.Uniform(
low=-init_scale, high=init_scale))) low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter( self.softmax_weight = self.create_parameter(
......
...@@ -64,5 +64,47 @@ class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner): ...@@ -64,5 +64,47 @@ class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner):
test_class=TestSparseEmbedding, delta=1e-5) test_class=TestSparseEmbedding, delta=1e-5)
class TestParallelDygraphSparseEmdeddingEager(TestDistBase):
    """Sparse-embedding data-parallel test with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._nccl2_mode = True
        self._eager_mode = True
        self._dygraph = True

    def test_sparse_embedding(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_sparse_embedding.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
class TestParallelDygraphSparseEmdeddingFP64Eager(TestDistBase):
    """FP64 sparse-embedding data-parallel test with eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_sparse_embedding_fp64(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_sparse_embedding_fp64.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
class TestParallelDygraphSparseEmdeddingSpawnEager(TestDistSpawnRunner):
    """Spawn-based sparse-embedding test with `eager_mode` enabled on the args."""

    def _args_config(self, args):
        # Switch the run arguments to eager mode.
        args.eager_mode = True

    def test_sparse_embedding_with_spawn(self):
        # Requires a CUDA build and Python >= 3.4; otherwise nothing is run.
        runnable = (fluid.core.is_compiled_with_cuda() and
                    sys.version_info >= (3, 4))
        if not runnable:
            return
        self.check_dist_result_with_spawn(
            test_class=TestSparseEmbedding, delta=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -55,5 +55,35 @@ class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase): ...@@ -55,5 +55,35 @@ class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphSparseEmdeddingEager_GLOO(TestDistBase):
    """Sparse-embedding data-parallel test with the eager and Gloo flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_sparse_embedding(self):
        # No CUDA guard here: the check is always executed.
        model_script = "parallel_dygraph_sparse_embedding.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
class TestParallelDygraphSparseEmdeddingEagerFP64_GLOO(TestDistBase):
    """FP64 sparse-embedding data-parallel test with eager and Gloo flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_sparse_embedding_fp64(self):
        # No CUDA guard here: the check is always executed.
        model_script = "parallel_dygraph_sparse_embedding_fp64.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -48,5 +48,32 @@ class TestParallelDygraphSparseEmdeddingOverHeightSpawn(TestDistSpawnRunner): ...@@ -48,5 +48,32 @@ class TestParallelDygraphSparseEmdeddingOverHeightSpawn(TestDistSpawnRunner):
test_class=TestSparseEmbeddingOverHeight, delta=1e-5) test_class=TestSparseEmbeddingOverHeight, delta=1e-5)
class TestParallelDygraphSparseEmdeddingOverHeightEager(TestDistBase):
    """Over-height sparse-embedding test with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_sparse_embedding(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_sparse_embedding_over_height.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
class TestParallelDygraphSparseEmdeddingOverHeightSpawnEager(
        TestDistSpawnRunner):
    """Spawn-based over-height sparse-embedding test with `eager_mode` enabled."""

    def _args_config(self, args):
        # Switch the run arguments to eager mode.
        args.eager_mode = True

    def test_sparse_embedding_with_spawn(self):
        # Requires a CUDA build and Python >= 3.4; otherwise nothing is run.
        runnable = (fluid.core.is_compiled_with_cuda() and
                    sys.version_info >= (3, 4))
        if not runnable:
            return
        self.check_dist_result_with_spawn(
            test_class=TestSparseEmbeddingOverHeight, delta=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -40,5 +40,20 @@ class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase): ...@@ -40,5 +40,20 @@ class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphSparseEmdeddingOverHeightEager_GLOO(TestDistBase):
    """Over-height sparse-embedding test with the eager and Gloo flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_sparse_embedding(self):
        # No CUDA guard here. Note the tolerance is 1e-7 for this case.
        model_script = "parallel_dygraph_sparse_embedding_over_height.py"
        self.check_with_place(
            model_script,
            delta=1e-7,
            check_error_log=True,
            log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -36,5 +36,21 @@ class TestParallelDygraphMnist(TestDistBase): ...@@ -36,5 +36,21 @@ class TestParallelDygraphMnist(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphMnistEager(TestDistBase):
    """Runs the sync-batch-norm script with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_mnist(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_sync_batch_norm.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -65,5 +65,21 @@ class TestParallelDygraphTransformerAccGrad(TestDistBase): ...@@ -65,5 +65,21 @@ class TestParallelDygraphTransformerAccGrad(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphTransformerEager(TestDistBase):
    """Transformer data-parallel test with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_transformer(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_transformer.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -57,5 +57,20 @@ class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase): ...@@ -57,5 +57,20 @@ class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphTransformerEager_GLOO(TestDistBase):
    """Transformer data-parallel test with the eager and Gloo flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_transformer(self):
        # No CUDA guard here: the check is always executed.
        model_script = "parallel_dygraph_transformer.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -86,5 +86,71 @@ class TestParallelDygraphSharedUnusedVariables(TestDistBase): ...@@ -86,5 +86,71 @@ class TestParallelDygraphSharedUnusedVariables(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphUnusedVarEager(TestDistBase):
    """Unused-variables data-parallel test with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_net(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_unused_variables.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
class TestDygraphUnusedVarEager(TestParallelDygraphUnusedVar):
    """Reuses the tests of TestParallelDygraphUnusedVar with the eager flag set."""

    def _setup_config(self):
        # Only the configuration differs from the parent class; the test
        # methods themselves are inherited.
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True
class TestSparseEmbeddingUnusedVarsSpawnEager(TestDistSpawnRunner):
    """Spawn-based unused-variables test with `eager_mode` enabled on the args."""

    def _args_config(self, args):
        # Switch the run arguments to eager mode.
        args.eager_mode = True

    def test_mnist_with_spawn(self):
        # Requires a CUDA build and Python >= 3.4; otherwise nothing is run.
        runnable = (fluid.core.is_compiled_with_cuda() and
                    sys.version_info >= (3, 4))
        if not runnable:
            return
        self.check_dist_result_with_spawn(
            test_class=TestSparseEmbeddingUnusedVars, delta=1e-5)
class TestParallelDygraphNoVarEager(TestDistBase):
    """Runs the none-var script with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_net(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_none_var.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
class TestParallelDygraphSharedUnusedVariablesEager(TestDistBase):
    """Runs the shared-unused-var script with the eager and NCCL2 flags set."""

    def _setup_config(self):
        # Configuration flags stored on the test case; they are consumed by
        # the TestDistBase machinery (defined outside this view).
        self._sync_mode = False
        self._eager_mode = True
        self._nccl2_mode = True
        self._dygraph = True

    def test_mnist(self):
        # The check is only performed on CUDA-enabled builds.
        if not fluid.core.is_compiled_with_cuda():
            return
        model_script = "parallel_dygraph_shared_unused_var.py"
        self.check_with_place(
            model_script,
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册