From 1b031987c563038dc33370182e978ffe32b54abe Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 4 Apr 2022 18:48:21 +0800 Subject: [PATCH] [Dygraph] Support sparse tensor in refactored reducer (#40836) * [Dygraph] Support sparse tensor in refactored reducer * add uts * refactor * update * fix bugs --- .../fluid/distributed/collective/reducer.cc | 233 +++++++++++++++--- paddle/fluid/distributed/collective/reducer.h | 3 + .../fluid/tests/unittests/CMakeLists.txt | 7 +- .../parallel_dygraph_sparse_embedding.py | 5 +- .../parallel_dygraph_sparse_embedding_fp64.py | 1 - .../parallel_dygraph_unused_variables.py | 1 - .../test_parallel_dygraph_sparse_embedding.py | 42 ++++ ..._parallel_dygraph_sparse_embedding_gloo.py | 30 +++ ...el_dygraph_sparse_embedding_over_height.py | 27 ++ ...graph_sparse_embedding_over_height_gloo.py | 15 ++ .../test_parallel_dygraph_sync_batch_norm.py | 16 ++ .../test_parallel_dygraph_transformer.py | 16 ++ .../test_parallel_dygraph_transformer_gloo.py | 15 ++ .../test_parallel_dygraph_unused_variables.py | 66 +++++ 14 files changed, 440 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index ec02406efc..71741515c9 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -360,6 +360,7 @@ void EagerReducer::InitializeGroups( is_sparse_gradient_[tensor_indices_.front()]) { // process the sparse gradient. one sparse, one group group.dtype_ = first_var.dtype(); + group.is_sparse_ = true; } else { // process the dense gradient. InitializeDenseGroups(tensor_indices_, &group); @@ -391,6 +392,12 @@ void EagerReducer::InitializeDenseGroups( auto &tensor = tensors_[tensor_index]; auto &tensor_name = tensor.name(); + PADDLE_ENFORCE_EQ(is_sparse_gradient_[tensor_index], false, + platform::errors::PreconditionNotMet( + "Tensor %s's GRAD must be Tensor, but received " + "GRAD is SelectedRows", + tensor_name)); + PADDLE_ENFORCE_EQ(tensor.is_initialized(), true, platform::errors::PreconditionNotMet( "Tensor %s is not initialized.", tensor_name)); @@ -480,6 +487,7 @@ void EagerReducer::PrepareForBackward(const std::vector &outputs) { next_group_ = 0; std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) { group.pending_ = group.tensor_indices_.size(); + group.sparse_contents_ = Tensor(); }); // reinitialize vars_marked_ready_ for next iteration @@ -544,9 +552,6 @@ void EagerReducer::AddDistHook(size_t var_index) { return; } - auto &tensor = tensors_[var_index]; - const auto &grad_node = GetGradNodeFromTensor(&tensor); - VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name() << "@Grad] arrived and triggered disthook"; @@ -608,33 +613,69 @@ void EagerReducer::MarkVarReady(const size_t var_index, auto &group_tensor = group.dense_tensors_[inside_group_index]; const auto length = group.length_[inside_group_index]; - if (is_used_var) { - auto *autograd_meta = tensors_[var_index].get_autograd_meta(); - auto &grad_tensor = static_cast(autograd_meta)->Grad(); - group_tensor - .ShareDataWith( - *(std::dynamic_pointer_cast(grad_tensor.impl()))) - .Resize({grad_tensor.numel()}); - } else { - // TODO(shenliang03): maybe save the memory by avoiding tensor construction - if (!group_tensor.initialized()) { - group_tensor.Resize({static_cast(length)}); - group_tensor.mutable_data(inner_place_, group.dtype_); - } - if (HasGrad(var_index)) { - VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad"; - auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]); + if (!group.is_sparse_) { + if (is_used_var) { + auto *autograd_meta = tensors_[var_index].get_autograd_meta(); + auto &grad_tensor = + static_cast(autograd_meta)->Grad(); group_tensor .ShareDataWith(*( - std::dynamic_pointer_cast(grad_tensor->impl()))) - .Resize({length}); + std::dynamic_pointer_cast(grad_tensor.impl()))) + .Resize({grad_tensor.numel()}); } else { - VLOG(3) << "Tensor[" << tensors_[var_index].name() - << "] doesn't have grad"; - auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_); - group_tensor.Resize({static_cast(length)}); - phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0); + // TODO(shenliang03): maybe save the memory by avoiding tensor + // construction + if (!group_tensor.initialized()) { + group_tensor.Resize({static_cast(length)}); + group_tensor.mutable_data(inner_place_, group.dtype_); + } + if (HasGrad(var_index)) { + VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad"; + auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]); + group_tensor + .ShareDataWith(*(std::dynamic_pointer_cast( + grad_tensor->impl()))) + .Resize({length}); + } else { + VLOG(3) << "Tensor[" << tensors_[var_index].name() + << "] doesn't have grad"; + auto *dev_ctx = + platform::DeviceContextPool::Instance().Get(inner_place_); + group_tensor.Resize({static_cast(length)}); + phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0); + } } + } else { + auto *autograd_meta = tensors_[var_index].get_autograd_meta(); + auto &grad_tensor = static_cast(autograd_meta)->Grad(); + + // process sparse group + PADDLE_ENFORCE_EQ( + HasGrad(var_index), true, + platform::errors::PreconditionNotMet( + "The sparse parameter[%d][%s] should have gradient. " + "Currently, DataParallel does not support sparse " + "parameters without generating gradients during training. " + "For example, if is_sparese=True is used in Embedding, " + "the current step of this parameter cannot generate gradient " + "because of stop_gradient/detatch, where error will occur.", + var_index, tensors_[var_index].name())); + + // need to check tensor type + PADDLE_ENFORCE_EQ( + grad_tensor.is_selected_rows(), true, + platform::errors::PreconditionNotMet( + "The sparse parameter[%d][%s] must have a selectedrows gradient. " + "Before forward pass, the parameter type is inferred to be " + "SelectedRows, but after backward pass, its actual type becomes " + "LodTensor. It is currently not supported by DataParallel. " + "For example, if sparse embedding is used, and the weight of " + "embedding is shared with subsequent dense parameters, then " + "the parameter gradient of the embedding will be converted " + "to dense parameters.", + var_index, tensors_[var_index].name())); + + group.sparse_contents_.set_impl(grad_tensor.impl()); } if (--group.pending_ == 0) { @@ -666,7 +707,11 @@ void EagerReducer::MarkGroupReady(size_t group_index) { for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; ++next_group_) { UNUSED auto &group = groups_[next_group_]; - FusedAllReduceSchedule(&group, next_group_); + if (group.is_sparse_) { + AllReduceSparse(&group, next_group_); + } else { + FusedAllReduceSchedule(&group, next_group_); + } } } @@ -725,6 +770,11 @@ void EagerReducer::ProcessUnusedDenseVars() { const auto inside_group_index = var_locator.inside_group_index; auto &src_tensor = group.dense_tensors_[inside_group_index]; + // sparse no need to check and no support find_unused_parameters + if (group.is_sparse_) { + continue; + } + Tensor grad_value(std::make_shared(src_tensor)); auto dest_var_base = tensors_[var_index]; @@ -739,11 +789,15 @@ void EagerReducer::FinalizeBackward() { groups_need_finalize_ = false; grad_need_hooks_ = false; for (auto &group : groups_) { - group.task->Synchronize(); + if (!group.is_sparse_) { + group.task->Synchronize(); + } } for (auto &group : groups_) { - group.SplitTensors(inner_place_); + if (!group.is_sparse_) { + group.SplitTensors(inner_place_); + } } if (find_unused_vars_each_step_) { @@ -778,6 +832,127 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, // split in FinalizeBackward() } +void EagerReducer::AllReduceSparse(EagerGroup *group, + const int curr_group_index) { + // div nranks + Tensor sparse_tensor(group->sparse_contents_); + paddle::experimental::scale_(sparse_tensor, 1.0 / nranks_, 0.0, false); + + VLOG(3) << "sparse_group [" << curr_group_index << "] start allreduce."; + + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_); + if (platform::is_gpu_place(inner_place_)) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(inner_place_)); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat grad tensors since it's not compiled with NCCL," + "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_cpu_place(inner_place_)) { + dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(inner_place_)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Split grad tensor not supported on place (%s)", inner_place_)); + } + + auto src = std::dynamic_pointer_cast( + group->sparse_contents_.impl()); + const auto &src_rows = src->rows(); + + const auto &rank_ = process_group_->GetRank(); + const auto &size_ = process_group_->GetSize(); + + framework::Vector rows_num_vector(size_); + rows_num_vector[rank_] = static_cast(src_rows.size()); + + Tensor rows_num_tensor = paddle::experimental::empty( + IntArray({static_cast(size_)}), DataType::INT64, inner_place_); + auto *rows_num_dense_tensor = + std::dynamic_pointer_cast(rows_num_tensor.impl()).get(); + framework::TensorFromVector(rows_num_vector, *dev_ctx, + rows_num_dense_tensor); + + distributed::AllreduceOptions opts; + opts.reduce_op = ReduceOp::SUM; + std::vector reduce_tensors = {rows_num_tensor}; + process_group_->AllReduce(reduce_tensors, opts)->Synchronize(); + + framework::TensorToVector(*rows_num_dense_tensor, *dev_ctx, + &rows_num_vector); + dev_ctx->Wait(); + + const auto *cpu_rows_num_ptr = rows_num_vector.data(); + auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + size_, + static_cast(0)); + + VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',') + << ", total rows number: " << rows_num + << ", height: " << src->height(); + + dev_ctx->Wait(); + + if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_, + [&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) { + // During sparse communication, the number of each card is same. + // allgather is used to speed up the allreduce by replacing broadcast. + + VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce"; + + Tensor dst_rows_tensor = + paddle::experimental::empty(IntArray({static_cast(rows_num)}), + DataType::INT64, inner_place_); + Tensor src_rows_tensor = paddle::experimental::empty( + IntArray({static_cast((*src).rows().size())}), DataType::INT64, + inner_place_); + auto *src_rows_dense_tensor = + std::dynamic_pointer_cast(src_rows_tensor.impl()) + .get(); + framework::TensorFromVector((*src).rows(), *dev_ctx, + src_rows_dense_tensor); + + std::vector src_rows_tensors = {src_rows_tensor}; + std::vector dst_rows_tensors = {dst_rows_tensor}; + process_group_->AllGather(src_rows_tensors, dst_rows_tensors) + ->Synchronize(); + + framework::Vector dst_rows_vector(rows_num, 0); + auto *dst_rows_dense_tensor = + std::dynamic_pointer_cast(dst_rows_tensor.impl()) + .get(); + framework::TensorToVector(*dst_rows_dense_tensor, *dev_ctx, + &dst_rows_vector); + dev_ctx->Wait(); + + Tensor src_value_tensor(std::make_shared(src->value())); + std::vector dst_shape = src_value_tensor.shape(); + dst_shape[dst_shape.size() - 2] = rows_num; + auto dst_dense_tensor = std::dynamic_pointer_cast( + paddle::experimental::full(IntArray(dst_shape), 0, + src_value_tensor.dtype(), inner_place_) + .impl()); + + auto dst = + std::make_shared(dst_rows_vector, (*src).height()); + *(dst->mutable_value()) = *dst_dense_tensor; + Tensor dst_value_tensor(std::make_shared(dst->value())); + + std::vector src_value_tensors = {src_value_tensor}; + std::vector dst_value_tensors = {dst_value_tensor}; + process_group_->AllGather(src_value_tensors, dst_value_tensors) + ->Synchronize(); + + src->set_rows(dst_rows_vector); + *(src->mutable_value()) = + *(std::dynamic_pointer_cast(dst_value_tensor.impl())); + } else { + PADDLE_THROW( + platform::errors::Unimplemented("This case is not supported.")); + } +} + std::ostream &operator<<(std::ostream &out, const EagerGroup &group) { const auto &tensors_ = group.tensor_indices_; out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size() diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h index 848277f5fa..12c0250988 100644 --- a/paddle/fluid/distributed/collective/reducer.h +++ b/paddle/fluid/distributed/collective/reducer.h @@ -47,6 +47,8 @@ std::vector> Eager_AssignGroupBySize( class EagerGroup { public: Tensor dense_contents_; + Tensor sparse_contents_; + bool is_sparse_ = false; // for concat kernel std::vector dense_tensors_; @@ -104,6 +106,7 @@ class EagerReducer { void MarkVarReady(const size_t var_index, const bool is_used_var); void MarkGroupReady(const size_t group_index); void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index); + void AllReduceSparse(EagerGroup *group, const int curr_group_index); void FinalizeBackward(); void TraverseBackwardGraph(const std::vector &outputs); void ProcessUnusedDenseVars(); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 8184960637..663dd9b9e1 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1128,7 +1128,7 @@ set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 150) + set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 300) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150) @@ -1153,8 +1153,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) - set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200) + set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height PROPERTIES TIMEOUT 150) + set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 150) endif() endif() diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py index 226f1293ef..33ae0acf43 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py @@ -42,7 +42,6 @@ class SimpleNet(fluid.Layer): dtype=dtype, is_sparse=is_sparse, param_attr=fluid.ParamAttr( - name='embedding_param', initializer=fluid.initializer.UniformInitializer( low=-init_scale, high=init_scale))) self.softmax_weight = self.create_parameter( @@ -103,8 +102,8 @@ class TestSparseEmbedding(TestParallelDyGraphRunnerBase): train_reader = paddle.batch( fake_sample_reader(), batch_size=batch_size, drop_last=True) - optimizer = fluid.optimizer.SGD(learning_rate=0.001, - parameter_list=model.parameters()) + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) return model, train_reader, optimizer diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py index a15b263a29..b341a22728 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py @@ -40,7 +40,6 @@ class SimpleNet(Layer): self.hidden_size, sparse=True, weight_attr=paddle.ParamAttr( - name='embedding_param', initializer=paddle.nn.initializer.Uniform( low=-init_scale, high=init_scale))) self.softmax_weight = self.create_parameter( diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py index 9f87738110..b4dd03aecf 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py @@ -39,7 +39,6 @@ class SimpleNet(Layer): self.hidden_size, sparse=is_sparse, weight_attr=paddle.ParamAttr( - name='embedding_param', initializer=paddle.nn.initializer.Uniform( low=-init_scale, high=init_scale))) self.softmax_weight = self.create_parameter( diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py index 43907da609..30349270b9 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py @@ -64,5 +64,47 @@ class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner): test_class=TestSparseEmbedding, delta=1e-5) +class TestParallelDygraphSparseEmdeddingEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._eager_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingFP64Eager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_sparse_embedding_fp64(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding_fp64.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingSpawnEager(TestDistSpawnRunner): + def _args_config(self, args): + args.eager_mode = True + + def test_sparse_embedding_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbedding, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py index 56fcf806c4..e461bf2a26 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py @@ -55,5 +55,35 @@ class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase): log_name=flag_name) +class TestParallelDygraphSparseEmdeddingEager_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingEagerFP64_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding_fp64(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding_fp64.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py index 9aca448f16..fb4c992d35 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py @@ -48,5 +48,32 @@ class TestParallelDygraphSparseEmdeddingOverHeightSpawn(TestDistSpawnRunner): test_class=TestSparseEmbeddingOverHeight, delta=1e-5) +class TestParallelDygraphSparseEmdeddingOverHeightEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding_over_height.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingOverHeightSpawnEager( + TestDistSpawnRunner): + def _args_config(self, args): + args.eager_mode = True + + def test_sparse_embedding_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbeddingOverHeight, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py index ba43e26e23..0acec54ca6 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py @@ -40,5 +40,20 @@ class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase): log_name=flag_name) +class TestParallelDygraphSparseEmdeddingOverHeightEager_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding_over_height.py", + delta=1e-7, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py index 7cf1e9711b..3a7a32c2ec 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py @@ -36,5 +36,21 @@ class TestParallelDygraphMnist(TestDistBase): log_name=flag_name) +class TestParallelDygraphMnistEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sync_batch_norm.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py index e0aab8541a..2141cceb79 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py @@ -65,5 +65,21 @@ class TestParallelDygraphTransformerAccGrad(TestDistBase): log_name=flag_name) +class TestParallelDygraphTransformerEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_transformer(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py index d3619cc1b9..6d4dd6433a 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py @@ -57,5 +57,20 @@ class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase): log_name=flag_name) +class TestParallelDygraphTransformerEager_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_transformer(self): + self.check_with_place( + "parallel_dygraph_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py index 75fa6f7c71..f2225111d1 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py @@ -86,5 +86,71 @@ class TestParallelDygraphSharedUnusedVariables(TestDistBase): log_name=flag_name) +class TestParallelDygraphUnusedVarEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_unused_variables.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestDygraphUnusedVarEager(TestParallelDygraphUnusedVar): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + +class TestSparseEmbeddingUnusedVarsSpawnEager(TestDistSpawnRunner): + def _args_config(self, args): + args.eager_mode = True + + def test_mnist_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbeddingUnusedVars, delta=1e-5) + + +class TestParallelDygraphNoVarEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_none_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSharedUnusedVariablesEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_shared_unused_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() -- GitLab