From a60f17b89d4de1fe6d9175af2954fa62b92b7a39 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Wed, 13 Jan 2021 11:36:36 +0800 Subject: [PATCH] Support unused parameters in dynamic graph distributed (#30224) --- paddle/fluid/imperative/reducer.cc | 278 ++++++++++++++---- paddle/fluid/imperative/reducer.h | 33 ++- paddle/fluid/pybind/imperative.cc | 22 +- python/paddle/fluid/dygraph/parallel.py | 22 +- .../fluid/tests/unittests/CMakeLists.txt | 3 + .../parallel_dygraph_sparse_embedding_fp64.py | 8 + .../parallel_dygraph_unused_variables.py | 133 +++++++++ .../test_parallel_dygraph_unused_variables.py | 68 +++++ 8 files changed, 483 insertions(+), 84 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 85f2831a06..10e8b39831 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -22,6 +22,11 @@ std::shared_ptr Reducer::s_instance_ = NULL; // context is used to select the stream for concat void Group::ConcatTensors(const platform::CUDADeviceContext &context) { + VLOG(3) << "Before concat, set output tensor size is " << all_length_; + auto tensor = dense_contents_.GetMutable(); + tensor->Resize(framework::make_ddim({all_length_})) + .mutable_data(context.GetPlace(), dtype_); + switch (dtype_) { case framework::proto::VarType::FP16: ConcatTensorsForAllReduce(context, dense_tensors_, @@ -88,23 +93,27 @@ Reducer::Reducer(const std::vector> &vars, const std::vector> &group_indices, const std::vector &is_sparse_gradient, std::shared_ptr parallel_ctx, - const std::vector &group_size_limits) + const std::vector &group_size_limits, + bool find_unused_vars) : vars_(vars), group_indices_(group_indices), is_sparse_gradient_(is_sparse_gradient), parallel_ctx_(parallel_ctx), - group_size_limits_(group_size_limits) { + group_size_limits_(group_size_limits), + find_unused_vars_(find_unused_vars) { VLOG(3) << "Start construct the Reducer ..."; nrings_ = parallel_ctx->GetNRings(); // initialize groups InitializeGroups(group_indices); for (size_t global_var_index = 0; global_var_index < vars_.size(); ++global_var_index) { - vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook( + auto var = vars_[global_var_index]; + var->SharedVar()->AddGradVarLeafBackwardHook( std::unique_ptr( new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) { - this->AddDistHook(grad, global_var_index); + this->AddDistHook(global_var_index); }))); + var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index; } // create streams compute_stream_ = static_cast( @@ -169,8 +178,6 @@ void Reducer::InitializeDenseGroups( all_length += size; p_group->length_.push_back(size); - // for concat operator - p_group->dense_tensors_.push_back(framework::Tensor()); // check the dtype and place, it must be same. auto dtype = var->DataType(); @@ -193,7 +200,6 @@ void Reducer::InitializeDenseGroups( place_ = place; } } - p_group->all_length_ = all_length; } // Each parameter will be initialized according to the group information. @@ -228,10 +234,6 @@ void Reducer::InitializeGroups( } else { // process the dense gradient. InitializeDenseGroups(variable_indices_, &group); - // Alloc the continuous space - auto tensor = group.dense_contents_.GetMutable(); - tensor->Resize(framework::make_ddim({group.all_length_})) - .mutable_data(place_, group.dtype_); } // map variables to this group by VariableLocator @@ -244,21 +246,144 @@ void Reducer::InitializeGroups( } group.variable_indices_ = std::move(variable_indices_); groups_.emplace_back(std::move(group)); - // Debug Message For Reducer VLOG(3) << "The Group[" << group_index << "]:"; VLOG(3) << groups_.back(); } } +void Reducer::PrepareDeps(const std::unordered_set &init_nodes) { + PADDLE_ENFORCE_EQ( + node_deps_.empty(), true, + platform::errors::AlreadyExists("Op deps must be initialized here")); + + std::queue q; + std::unordered_set visited; + + for (auto pos = init_nodes.begin(); pos != init_nodes.end(); pos++) { + q.push(*pos); + visited.insert(*pos); + } + + while (!q.empty()) { + auto *cur_node = q.front(); + q.pop(); + + for (auto &cur_op : *cur_node) { + cur_op.EnforceHasInOut(); + } + + const auto &grad_pending_nodes = cur_node->GradPendingNodes(); + for (auto &grad_pending_node : grad_pending_nodes) { + PADDLE_ENFORCE_NOT_NULL( + grad_pending_node, + platform::errors::NotFound("Grad pending node should not be null")); + ++node_deps_[grad_pending_node.get()]; + if (visited.count(grad_pending_node.get()) == 0) { + visited.insert(grad_pending_node.get()); + q.push(grad_pending_node.get()); + } + } + } +} + // After each batch is calculated, the counter of each group(group.pending_) // and allreudce sequence counter(next_group_) will be cleaned up again. -void Reducer::PrepareForBackward() { +void Reducer::PrepareForBackward( + const std::vector> &outputs) { VLOG(3) << "start reseting count.."; next_group_ = 0; std::for_each(groups_.begin(), groups_.end(), [](Group &group) { group.pending_ = group.variable_indices_.size(); + group.all_length_ = 0; + group.dense_tensors_.clear(); + group.dense_tensors_.reserve(group.pending_); + group.sparse_contents_ = nullptr; }); + + PADDLE_ENFORCE_EQ( + all_group_ready_, false, + platform::errors::PreconditionNotMet( + "Please note that all ``forward`` outputs derived from the module " + "parameters must participate in the calculation of losses and " + "subsequent gradient calculations. If not, the wrapper will hang, " + "waiting for autograd to generate gradients for these parameters. " + "you can use detach or stop_gradient to make the unused parameters " + "detached from the autograd graph.")); + + // The first var to trigger the unused parameter + has_marked_unused_vars_ = false; + if (!find_unused_vars_) { + return; + } + + // TODO(shenliang03) "find_unused_vars" interface will be exposed in the + // future to handle control flow to process unused parameters + find_unused_vars_ = false; + + unused_vars_.clear(); + node_deps_.clear(); + std::queue> q; + std::unordered_set var_visited; + std::unordered_set init_nodes; + + for (const auto &output : outputs) { + const auto &grad_node = output->GradVarBase()->GradNode(); + if (grad_node == nullptr || output->OverridedStopGradient()) { + VLOG(3) << "Skip auto grad since there is no grad op or output is " + "stop_gradient=True: " + << output->Name(); + continue; + } else { + init_nodes.insert(grad_node.get()); + var_visited.insert(output->SharedVar().get()); + q.push(grad_node); + } + } + + PrepareDeps(init_nodes); + // Traverse the autograd graph starting at the specified output + while (!q.empty()) { + auto cur_node = q.front(); + q.pop(); + + for (const auto &cur_op : *cur_node) { + cur_op.EnforceHasInOut(); + auto &bwd_outs = cur_op.GetOutsMap(); + for (const auto &pair : bwd_outs) { + if (!pair.second.IsGrad()) { + continue; + } + for (auto &var : pair.second) { + if (!var || var->OverridedStopGradient()) { + continue; + } else { + var_visited.insert(var.get()); + } + } + } + } + for (const auto &grad_pending_node : cur_node->GradPendingNodes()) { + PADDLE_ENFORCE_NOT_NULL(grad_pending_node, + platform::errors::NotFound( + "Grad pending node should not be nullptr")); + auto iter = node_deps_.find(grad_pending_node.get()); + if (iter == node_deps_.end()) { + continue; + } + if (--(iter->second) == 0) { + q.push(grad_pending_node); + } + } + } + + for (const auto &it : var_index_map_) { + if (var_visited.count(it.first) == 0) { + unused_vars_.push_back(it.second); + VLOG(3) << "Var[" << it.second << "] [" << it.first->Name() + << "] is not used"; + } + } } // Add hook function to each leaf node. When the gradient of a leaf node is @@ -270,23 +395,50 @@ void Reducer::PrepareForBackward() { // counter is 0, it means that allreduce can be emitted, and // concat + allreduce + split is emitted in turn according to next_group_. // 3, FinalizeBackward: after the end, synchronize each stream. -void Reducer::AddDistHook(VariableWrapper *var_warpper, size_t var_index) { - const auto &var_locator = variable_locators_[var_index]; - auto group_index = var_locator.group_index; - auto &group = groups_[group_index]; +void Reducer::AddDistHook(size_t var_index) { + VLOG(3) << "Var[" << var_index << "] [" + << vars_[var_index]->GradVarBase()->Name() + << "] arrived and triggered disthook"; + if (!has_marked_unused_vars_) { + has_marked_unused_vars_ = true; + for (auto unused_index : unused_vars_) { + if (NeedRebuildGroup()) { + rebuild_vars_.push_back(vars_[unused_index]); + rebuild_var_indices_.push_back(unused_index); + } + MarkVarReady(unused_index, false); + } + } - if (!has_rebuilt_group_) { + if (NeedRebuildGroup()) { rebuild_vars_.push_back(vars_[var_index]); rebuild_var_indices_.push_back(var_index); } + MarkVarReady(var_index, true); +} - if (!group.is_sparse_) { - // Only dense_contents_ need memory copy - MarkDenseVarReady(var_index, var_warpper); - } else { - MarkSparseVarReady(var_index, var_warpper); - } +void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { + all_group_ready_ = true; + const auto &var_locator = variable_locators_[var_index]; + auto group_index = var_locator.group_index; + auto &group = groups_[group_index]; + if (is_used_var) { + auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar(); + if (!group.is_sparse_) { + auto grad = var_warpper->MutableVar(); + auto inside_group_index = var_locator.inside_group_index; + auto length = group.length_[inside_group_index]; + + auto tensor = grad->GetMutable(); + framework::Tensor tmp; + tmp.ShareDataWith(*tensor).Resize({static_cast(length)}); + group.dense_tensors_.push_back(std::move(tmp)); + group.all_length_ += length; + } else { + group.sparse_contents_ = var_warpper->MutableVar(); + } + } if (--group.pending_ == 0) { // can start allreduce MarkGroupReady(group_index); @@ -297,27 +449,6 @@ void Reducer::AddDistHook(VariableWrapper *var_warpper, size_t var_index) { } } -void Reducer::MarkDenseVarReady(size_t var_index, - VariableWrapper *var_warpper) { - const auto &var_locator = variable_locators_[var_index]; - auto group_index = var_locator.group_index; - auto inside_group_index = var_locator.inside_group_index; - auto &group = groups_[group_index]; - auto length = group.length_[inside_group_index]; - - auto tensor = var_warpper->MutableVar()->GetMutable(); - group.dense_tensors_[inside_group_index].ShareDataWith(*tensor).Resize( - {static_cast(length)}); -} - -void Reducer::MarkSparseVarReady(size_t var_index, - VariableWrapper *var_warpper) { - const auto &var_locator = variable_locators_[var_index]; - auto group_index = var_locator.group_index; - auto &group = groups_[group_index]; - group.sparse_contents_ = var_warpper->MutableVar(); -} - void Reducer::MarkGroupReady(size_t group_index) { if (group_index > next_group_) { VLOG(3) << "It will adjust the order of group in next batch automatically"; @@ -326,6 +457,7 @@ void Reducer::MarkGroupReady(size_t group_index) { PADDLE_ENFORCE_CUDA_SUCCESS( cudaEventRecord(group_events_[group_index].get(), compute_stream_)); + for (int i = 0; i < nrings_; ++i) { PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent( comm_streams_[i], group_events_[group_index].get(), 0)); @@ -336,29 +468,48 @@ void Reducer::MarkGroupReady(size_t group_index) { auto &group = groups_[next_group_]; int run_order = next_group_ % nrings_; if (group.is_sparse_) { - VLOG(3) << "sparse group [" << next_group_ << "] start allreduce in ring[" - << run_order << "]"; - parallel_ctx_->AllReduceByStream( - *group.sparse_contents_, group.sparse_contents_, run_order, false); + if (group.sparse_contents_ != nullptr) { + VLOG(3) << "sparse group [" << next_group_ + << "] start allreduce in ring[" << run_order << "]"; + parallel_ctx_->AllReduceByStream( + *group.sparse_contents_, group.sparse_contents_, run_order, false); + } else { + VLOG(3) << "The sparse group[" << next_group_ + << "] has no var to allreduce"; + } } else { - VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring[" - << run_order << "]"; - // Select common commstream to concat tensors - // group.dense_tensors ---> group.dense_contents_ - group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); - - // Start allreduce - parallel_ctx_->AllReduceByStream( - group.dense_contents_, &(group.dense_contents_), run_order, false); - - // Select common commstream to split tensors - // group.dense_contents_ ---> group.dense_tensors - group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); + if (!group.dense_tensors_.empty()) { + VLOG(3) << "dense group [" << next_group_ + << "] start allreduce in ring[" << run_order << "]"; + // Select common commstream to concat tensors + // group.dense_tensors ---> group.dense_contents_ + group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); + + // Start allreduce + parallel_ctx_->AllReduceByStream( + group.dense_contents_, &(group.dense_contents_), run_order, false); + + // Select common commstream to split tensors + // group.dense_contents_ ---> group.dense_tensors + group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); + } else { + VLOG(3) << "The dense group[" << next_group_ + << "] has no var to allreduce"; + } } } } std::vector> Reducer::RebuildGruops() { + VLOG(3) << "The order of parameter arrival: " + << string::join_strings(rebuild_var_indices_, ','); + + PADDLE_ENFORCE_EQ( + rebuild_vars_.size(), vars_.size(), + platform::errors::PreconditionNotMet( + "Rebuild vars's number should be equal to original vars'number, " + "expect it to be %d, but got %d.", + vars_.size(), rebuild_vars_.size())); std::reverse(rebuild_vars_.begin(), rebuild_vars_.end()); std::reverse(rebuild_var_indices_.begin(), rebuild_var_indices_.end()); auto rebuild_group_indices = @@ -372,6 +523,7 @@ std::vector> Reducer::RebuildGruops() { } void Reducer::FinalizeBackward() { + all_group_ready_ = false; // Must prevent compute_stream_ starting until all comm streams have finished for (int i = 0; i < nrings_; ++i) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -382,7 +534,7 @@ void Reducer::FinalizeBackward() { cudaStreamWaitEvent(compute_stream_, comm_events_[i].get(), 0)); } - if (!has_rebuilt_group_) { + if (NeedRebuildGroup()) { VLOG(3) << "Start rebuilding the groups"; auto rebuild_group_indices = RebuildGruops(); auto rebuild_group_number = rebuild_group_indices.size(); diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index 2bfc308de0..62b6161602 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -18,14 +18,18 @@ #include #include #include +#include #include #include +#include #include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/string/string_helper.h" #if defined(PADDLE_WITH_NCCL) #include "paddle/fluid/imperative/all_reduce.h" @@ -121,7 +125,7 @@ class Reducer { const std::vector>& group_indices, const std::vector& is_sparse_gradient, std::shared_ptr parallel_ctx, - const std::vector& group_size_limits); + const std::vector& group_size_limits, bool find_unused_vars); virtual ~Reducer() {} @@ -130,13 +134,18 @@ class Reducer { void InitializeDenseGroups(const std::vector& variable_indices_, Group* p_group); - void PrepareForBackward(); + void PrepareDeps(const std::unordered_set& init_nodes); - void AddDistHook(VariableWrapper* var_warpper, size_t var_index); + void PrepareForBackward( + const std::vector>& outputs); - void MarkDenseVarReady(size_t var_index, VariableWrapper* var_warpper); + void AddDistHook(size_t var_index); - void MarkSparseVarReady(size_t var_index, VariableWrapper* var_warpper); + // void MarkDenseVarReady(size_t var_index); + + // void MarkSparseVarReady(size_t var_index); + + void MarkVarReady(const size_t var_index, const bool is_used_var); void MarkGroupReady(size_t group_index); @@ -148,17 +157,19 @@ class Reducer { void CreateGroupEvents(int group_num); + inline bool NeedRebuildGroup() { return !has_rebuilt_group_; } + // Reducer Singleton static std::shared_ptr SetInstance( const std::vector>& vars, const std::vector>& group_indices, const std::vector& is_sparse_gradient, std::shared_ptr parallel_ctx, - const std::vector& group_size_limits) { + const std::vector& group_size_limits, bool find_unused_vars) { if (NULL == s_instance_) { s_instance_.reset(new paddle::imperative::Reducer( vars, group_indices, is_sparse_gradient, parallel_ctx, - group_size_limits)); + group_size_limits, find_unused_vars)); } return s_instance_; } @@ -194,6 +205,14 @@ class Reducer { std::vector> rebuild_vars_; std::vector rebuild_var_indices_; const std::vector group_size_limits_; + + // Following variables are to help unused vars + std::unordered_map node_deps_; + std::unordered_map var_index_map_; + std::vector unused_vars_; + bool has_marked_unused_vars_{false}; + bool find_unused_vars_{false}; + bool all_group_ready_{false}; }; std::vector> AssignGroupBySize( diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 505d94559d..c4377b3140 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1358,18 +1358,18 @@ void BindImperative(py::module *m_ptr) { py::class_>( m, "Reducer", R"DOC()DOC") - .def(py::init( - [](const std::vector> &vars, - const std::vector> &group_indices, - const std::vector &is_sparse_gradient, - std::shared_ptr parallel_ctx, - const std::vector &group_size_limits) { - return imperative::Reducer::SetInstance( - vars, group_indices, is_sparse_gradient, parallel_ctx, - group_size_limits); - })) + .def(py::init([]( + const std::vector> &vars, + const std::vector> &group_indices, + const std::vector &is_sparse_gradient, + std::shared_ptr parallel_ctx, + const std::vector &group_size_limits, bool find_unused_vars) { + return imperative::Reducer::SetInstance( + vars, group_indices, is_sparse_gradient, parallel_ctx, + group_size_limits, find_unused_vars); + })) .def("prepare_for_backward", &imperative::Reducer::PrepareForBackward, - py::call_guard()); + py::arg("vars"), py::call_guard()); m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"), py::arg("is_sparse_gradient"), diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index a9ed2f9f52..a80f6b3f49 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -26,6 +26,7 @@ from paddle.fluid.dygraph import to_variable, no_grad from paddle.utils import deprecated import warnings import paddle +import itertools __all__ = ["prepare_context", "ParallelEnv", "DataParallel"] @@ -465,17 +466,32 @@ class DataParallel(layers.Layer): "ParallelContext must be initialized before. You should use init_parallel_env() before" \ "constructing the DataParallel." + # TODO(shenliang03) "find_unused_vars" interface will be exposed in the future + # to handle control flow to process unused parameters + find_unused_vars = True self._reducer = core.Reducer( trainable_parameters, list(reversed(self.group_indices)), is_sparse_gradient, parallel_helper.__parallel_ctx__clz__, - [self.last_comm_buffer_size, self.comm_buffer_size]) + [self.last_comm_buffer_size, self.comm_buffer_size], + find_unused_vars) + + def _find_varbase(self, obj): + if isinstance(obj, core.VarBase): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain(*map(self._find_varbase, obj)) + if isinstance(obj, dict): + return itertools.chain(*map(self._find_varbase, obj.values())) + return [] def forward(self, *inputs, **kwargs): + outputs = self._layers(*inputs, **kwargs) if self._strategy.nranks > 1: - self._reducer.prepare_for_backward() + self._reducer.prepare_for_backward( + list(self._find_varbase(outputs))) - return self._layers(*inputs, **kwargs) + return outputs @deprecated( since="2.0.0", reason="This method does not need to be called anymore.") diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2ec2ea2872..269cb8d28b 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -18,6 +18,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -151,6 +152,7 @@ if (NOT ${WITH_GPU}) LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future LIST(REMOVE_ITEM TEST_OPS test_batch_fc_op) # TODO(shenliang03): batch_fc_op support CPU device in future LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_unused_variables) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_se_resnext) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) @@ -813,6 +815,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) endif() endif() if(WITH_GPU AND NOT WIN32) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py index 47050b7bfc..65c242a702 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py @@ -55,10 +55,18 @@ class SimpleNet(Layer): dtype=dtype, default_initializer=paddle.nn.initializer.Uniform( low=-self.init_scale, high=self.init_scale)) + self.tmp = self.create_parameter( + attr=paddle.ParamAttr(), + shape=[self.hidden_size, self.vocab_size], + dtype=dtype, + default_initializer=paddle.nn.initializer.Uniform( + low=-self.init_scale, high=self.init_scale)) def forward(self, input, label): x_emb = self.embedding(input) fc = paddle.matmul(x_emb, self.softmax_weight) + # use detach to stop gradient + fc = fc.detach() fc = paddle.add(fc, self.softmax_bias) projection = paddle.reshape(fc, shape=[-1, self.vocab_size]) loss = paddle.nn.functional.softmax_with_cross_entropy( diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py new file mode 100644 index 0000000000..1884eef15e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py @@ -0,0 +1,133 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import paddle + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase +from paddle.nn import Layer, Embedding + + +class SimpleNet(Layer): + def __init__(self, + hidden_size, + vocab_size, + num_steps=20, + init_scale=0.1, + is_sparse=False, + dtype="float32"): + super(SimpleNet, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.init_scale = init_scale + self.num_steps = num_steps + self.embedding = Embedding( + self.vocab_size, + self.hidden_size, + sparse=True, + weight_attr=paddle.ParamAttr( + name='embedding_param', + initializer=paddle.nn.initializer.Uniform( + low=-init_scale, high=init_scale))) + self.softmax_weight = self.create_parameter( + attr=paddle.ParamAttr(), + shape=[self.hidden_size, self.vocab_size], + dtype=dtype, + default_initializer=paddle.nn.initializer.Uniform( + low=-self.init_scale, high=self.init_scale)) + self.softmax_bias = self.create_parameter( + attr=paddle.ParamAttr(), + shape=[self.vocab_size], + dtype=dtype, + default_initializer=paddle.nn.initializer.Uniform( + low=-self.init_scale, high=self.init_scale)) + # add tmp var + self.tmp = self.create_parameter( + attr=paddle.ParamAttr(), + shape=[self.vocab_size], + dtype=dtype, + default_initializer=paddle.nn.initializer.Uniform( + low=-self.init_scale, high=self.init_scale)) + + def forward(self, input, label): + x_emb = self.embedding(input) + fc = paddle.matmul(x_emb, self.softmax_weight) + + # it use stop gradient to block gradient return + fc.stop_gradient = True + fc = paddle.add(fc, self.softmax_bias) + projection = paddle.reshape(fc, shape=[-1, self.vocab_size]) + loss = paddle.nn.functional.softmax_with_cross_entropy( + logits=projection, label=label, soft_label=False) + loss = paddle.reshape(loss, shape=[-1, self.num_steps]) + loss = paddle.mean(loss, axis=[0]) + loss = paddle.sum(loss) + + return {"loss": loss} + + +# global configs +batch_size = 4 +batch_num = 200 +hidden_size = 10 +vocab_size = 1000 +num_steps = 3 +init_scale = 0.1 + + +def fake_sample_reader(): + def __reader__(): + for i in range(batch_num): + x_data = np.arange(num_steps).astype('int64') + y_data = np.arange(1, 1 + num_steps).astype('int64') + yield x_data, y_data + + return __reader__ + + +class TestSparseEmbeddingUnusedVars(TestParallelDyGraphRunnerBase): + def get_model(self): + model = SimpleNet( + hidden_size=hidden_size, + vocab_size=vocab_size, + num_steps=num_steps, + init_scale=init_scale, + is_sparse=True) + + train_reader = paddle.batch( + fake_sample_reader(), batch_size=batch_size, drop_last=True) + + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + + return model, train_reader, optimizer + + def run_one_loop(self, model, optimizer, batch): + x_data = np.array([x[0].reshape(3) for x in batch]).astype('int64') + y_data = np.array([x[1].reshape(3) for x in batch]).astype('int64') + x_data = x_data.reshape((-1, num_steps, 1)) + y_data = y_data.reshape((-1, 1)) + + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + + dy_loss = model(x, y) + + return dy_loss["loss"] + + +if __name__ == "__main__": + runtime_main(TestSparseEmbeddingUnusedVars) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py new file mode 100644 index 0000000000..d7f8b61ac5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py @@ -0,0 +1,68 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_unused_variables import TestSparseEmbeddingUnusedVars + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphMnist(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_unused_variables.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestSparseEmbeddingUnusedVarsSpawn(TestDistSpawnRunner): + def test_mnist_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbeddingUnusedVars, delta=1e-5) + + +class TestFleetDygraphMnist(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + self._gpu_fleet_api = True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_unused_variables.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() -- GitLab