diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index cd6e6835e5df30124d4fbf1c706a1e3b0688aa15..148eaef38bf539412c9cce3b14dd33d39f999c5a 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -2,7 +2,7 @@
 cc_library(imperative_flag SRCS flags.cc DEPS gflags)
 cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
-cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows var_type_traits layer)
+cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer)
 add_subdirectory(jit)

 cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer)
diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h
index b1f4bb1b6c9f60627427b65146a65c4ceb32bb93..ad5a2f09bee0f1ffcd6aa61e909d26ddd3c90bb4 100644
--- a/paddle/fluid/imperative/dygraph_grad_maker.h
+++ b/paddle/fluid/imperative/dygraph_grad_maker.h
@@ -131,11 +131,12 @@ class GradOpBaseMakerBase {
                        "VarBase grad of OP [%s] should not be null",
                        fw_op_base_->Type());
         auto grad_var_base_tmp = var_base_temp->GradVarBase();
-        auto* tensor = grad_var_base_tmp->MutableVar()
-                           ->GetMutable<framework::LoDTensor>();
-        tensor->Resize(
-            var_base_temp->Var().Get<framework::LoDTensor>().dims());
-
+        if (!is_input) {
+          auto* tensor = grad_var_base_tmp->MutableVar()
+                             ->GetMutable<framework::LoDTensor>();
+          tensor->Resize(
+              var_base_temp->Var().Get<framework::LoDTensor>().dims());
+        }
         vec_temp.emplace_back(grad_var_base_tmp);
       } else {
         vec_temp.emplace_back(var_base_temp);
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 873164fc28773c144f5f97c1498b732b7b0800e4..9da6a4da921d5de057a1054db062d717d28987c1 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -16,11 +16,13 @@
 #include <algorithm>
 #include <memory>
 #include <utility>
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/imperative/layer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -84,7 +86,7 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
   auto data_type = src_tensor.type();
   auto place = src_tensor.place();

-#define PADDLE_TENSOR_ADD_MACRO(cpp_type)                            \
+#define PADDLE_TENSOR_ADD(cpp_type)                                  \
   if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
     TensorAddFunctor<cpp_type> func(                                 \
         numel, src_tensor.data<cpp_type>(),                          \
@@ -93,25 +95,155 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
     return;                                                          \
   }

-  PADDLE_TENSOR_ADD_MACRO(float);
-  PADDLE_TENSOR_ADD_MACRO(double);
+  PADDLE_TENSOR_ADD(float);
+  PADDLE_TENSOR_ADD(double);

-#undef PADDLE_TENSOR_ADD_MACRO
+#undef PADDLE_TENSOR_ADD

   PADDLE_THROW("Not supported data type %s for AddTo",
                framework::DataTypeToString(data_type));
 }

+void SelectedRowsAddToTensor(const framework::Variable& src,
+                             framework::Variable* dst) {
+  auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
+  auto& src_selected_rows = src.Get<framework::SelectedRows>();
+  auto place = dst_tensor->place();
+  auto data_type = src_selected_rows.value().type();
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+#define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)       \
+  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {     \
+    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);          \
+    paddle::operators::math::SelectedRowsAddToTensor<dev_ctx_type,       \
+                                                     cpp_type>           \
+        functor;                                                         \
+    functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows,  \
+            dst_tensor);                                                 \
+    return;                                                              \
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  if (paddle::platform::is_gpu_place(place)) {
+    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, float);
+    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, double);
+  } else {
+#endif
+    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, float);
+    PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, double);
+#ifdef PADDLE_WITH_CUDA
+  }
+#endif
+
+#undef PADDLE_SELECTED_ROWS_ADD_TO_TENSOR
+
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "Not supported data type %s for SelectedRowsAddToTensor",
+      framework::DataTypeToString(data_type)));
+}
+
+// Note(chenweihang): when two SelectedRows need to be added, adding one
+// into the other is not equivalent to merging both into an empty
+// SelectedRows and taking the merged result; the latter is correct.
+std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
+                                           const framework::Variable& src2) {
+  auto& src_selected_rows1 = src1.Get<framework::SelectedRows>();
+  auto& src_selected_rows2 = src2.Get<framework::SelectedRows>();
+  auto place = src_selected_rows1.value().place();
+  auto data_type = src_selected_rows1.value().type();
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+  std::vector<const framework::SelectedRows*> src_selected_rows;
+  src_selected_rows.emplace_back(&src_selected_rows1);
+  src_selected_rows.emplace_back(&src_selected_rows2);
+  auto dst_var = std::make_shared<VarBase>(false, "Temp");
+  auto* dst_selected_rows =
+      dst_var->MutableVar()->GetMutable<framework::SelectedRows>();
+
+#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type)                  \
+  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) {      \
+    paddle::platform::DeviceContext* dev_ctx = pool.Get(place);           \
+    paddle::operators::math::scatter::MergeAdd<dev_ctx_type, cpp_type>    \
+        merge_add;                                                        \
+    merge_add(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows, \
+              dst_selected_rows);                                         \
+    return dst_var;                                                       \
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  if (paddle::platform::is_gpu_place(place)) {
+    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, float);
+    PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, double);
+  } else {
+#endif
+    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, float);
+    PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, double);
+#ifdef PADDLE_WITH_CUDA
+  }
+#endif
+
+#undef PADDLE_SELECTED_ROWS_ADD
+
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "Not supported data type %s for SelectedRowsMerge",
+      framework::DataTypeToString(data_type)));
+}
+
+void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
+  auto& src = var->Var();
+  auto* dst = var_->MutableVar();
+  if (dst->IsType<framework::LoDTensor>()) {
+    if (src.IsType<framework::LoDTensor>()) {
+      TensorAdd(src, dst);
+    } else if (src.IsType<framework::SelectedRows>()) {
+      SelectedRowsAddToTensor(src, dst);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unexpected branch, output variable type is %s",
+          framework::ToTypeName(dst->Type())));
+    }
+  } else {
+    if (src.IsType<framework::LoDTensor>()) {
+      auto* src_mutable = var->MutableVar();
+      SelectedRowsAddToTensor(*dst, src_mutable);
+      *dst = std::move(*(var->MutableVar()));
+      var_->SetType(framework::proto::VarType::LOD_TENSOR);
+    } else if (src.IsType<framework::SelectedRows>()) {
+      std::shared_ptr<VarBase> temp = SelectedRowsMerge(src, *dst);
+      *dst = std::move(*(temp->MutableVar()));
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unexpected branch, output variable type is %s",
+          framework::ToTypeName(dst->Type())));
+    }
+  }
+}
+
+platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
+  platform::Place place;
+  if (var->Var().IsType<framework::LoDTensor>()) {
+    place = var->Var().Get<framework::LoDTensor>().place();
+  } else if (var->Var().IsType<framework::SelectedRows>()) {
+    place = var->Var().Get<framework::SelectedRows>().place();
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "only support LoDTensor and SelectedRows in dygraph"));
+  }
+  return place;
+}
+
 void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
                                    size_t trace_id) {
   auto* dst_var = var_->MutableVar();
-  auto place = var->Var().Get<framework::LoDTensor>().place();
+  platform::Place place = GetPlaceOfVarBase(var);
   if (!var_->OverridedStopGradient()) {
     VLOG(3) << "Sum Gradient for: " << var_->Name();
     if (cur_cnt_ == 0) {
+      if (var->Var().IsType<framework::SelectedRows>()) {
+        var_->SetType(framework::proto::VarType::SELECTED_ROWS);
+      }
       *dst_var = std::move(*(var->MutableVar()));
     } else {
-      TensorAdd(var->Var(), dst_var);
+      VarBaseAdd(var, var_);
     }
   } else {
     if (!var_->Var().IsInitialized() ||
@@ -139,10 +271,15 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
 void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
                                     size_t trace_id) {
   auto* dst_var = var_->MutableVar();
-  auto place = var->Var().Get<framework::LoDTensor>().place();
+  platform::Place place = GetPlaceOfVarBase(var);
   if (!var_->OverridedStopGradient()) {
     if (ref_cnt_ == 1) {
-      *dst_var = std::move(*(var->MutableVar()));
+      if (var->Var().IsType<framework::SelectedRows>()) {
+        var_->SetType(framework::proto::VarType::SELECTED_ROWS);
+        *dst_var = std::move(*(var->MutableVar()));
+      } else {
+        *dst_var = std::move(*(var->MutableVar()));
+      }
     } else {
       if (tmp_grad_vars_.empty()) {
         tmp_grad_vars_.reserve(ref_cnt_);
@@ -160,11 +297,47 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
                   return p1.second > p2.second;
                 });

-      *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
-      for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
-        TensorAdd(tmp_grad_vars_[i].first->Var(), dst_var);
+#ifdef PADDLE_WITH_CUDA
+      if (paddle::platform::is_gpu_place(place)) {
+        bool dst_varbase_is_initialized = false;
+        // accumulate selected rows firstly
+        for (size_t i = 0; i < tmp_grad_vars_.size(); ++i) {
+          if (tmp_grad_vars_[i]
+                  .first->Var()
+                  .IsType<framework::SelectedRows>()) {
+            if (!dst_varbase_is_initialized) {
+              dst_varbase_is_initialized = true;
+              var_->SetType(framework::proto::VarType::SELECTED_ROWS);
+              *dst_var = std::move(*(tmp_grad_vars_[i].first->MutableVar()));
+            } else {
+              VarBaseAdd(tmp_grad_vars_[i].first, var_);
+            }
+          }
+        }
+        // accumulate lod tensor
+        for (size_t i = 0; i < tmp_grad_vars_.size(); ++i) {
+          if (!dst_varbase_is_initialized) {
+            dst_varbase_is_initialized = true;
+            *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
+          }
+          if (tmp_grad_vars_[i].first->Var().IsType<framework::LoDTensor>()) {
+            VarBaseAdd(tmp_grad_vars_[i].first, var_);
+          }
+        }
+      } else {
+#endif
+        if (tmp_grad_vars_[0].first->Var().IsType<framework::SelectedRows>()) {
+          var_->SetType(framework::proto::VarType::SELECTED_ROWS);
+          *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
+        } else {
+          *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
+        }
+        for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
+          VarBaseAdd(tmp_grad_vars_[i].first, var_);
+        }
+#ifdef PADDLE_WITH_CUDA
       }
-
+#endif
       tmp_grad_vars_.clear();
     }
   } else {
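The note on `SelectedRowsMerge` is easiest to see with concrete numbers. Below is a minimal illustration in plain NumPy (not Paddle API; the row ids and values are made up) of the `MergeAdd` semantics: duplicate row ids across the two sparse gradients must be summed, exactly as if both were scattered into one dense table.

```python
import numpy as np

# Two sparse gradients over a height-4 table; row id 2 appears in both.
rows1, vals1 = [0, 2], np.array([1.0, 2.0])
rows2, vals2 = [2, 3], np.array([3.0, 4.0])

# MergeAdd semantics: concatenate rows and values, then sum duplicates.
dense = np.zeros(4)
for r, v in zip(rows1 + rows2, np.concatenate([vals1, vals2])):
    dense[r] += v
print(dense)  # [1. 0. 5. 4.] -- row 2 holds 2.0 + 3.0
```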
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index c309c671b0d6f53d71ce01f0a91ce63c8731ead1..180c29c6559484889cdccec1df1fe3f42201a7b6 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <queue>
 #include <utility>
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/imperative/prepared_operator.h"
@@ -205,46 +206,79 @@ void VarBase::AddGradOps(const std::weak_ptr<OpBase>& op) {

 void VarBase::ClearGradient() {
   if (grad_var_) {
-    auto* grad_t = grad_var_->var_.GetMutable<framework::LoDTensor>();
-
-    if (grad_t->IsInitialized()) {
-      auto* dev_ctx =
-          platform::DeviceContextPool::Instance().Get(grad_t->place());
-      operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+    if (grad_var_->var_.IsType<framework::SelectedRows>()) {
+      auto* grad_t = grad_var_->var_.GetMutable<framework::SelectedRows>();
+      if (grad_t->mutable_value()->IsInitialized()) {
+        grad_t->mutable_rows()->clear();
+        grad_t->mutable_value()->clear();
+      }
+    } else {
+      auto* grad_t = grad_var_->var_.GetMutable<framework::LoDTensor>();
+      if (grad_t->IsInitialized()) {
+        auto* dev_ctx =
+            platform::DeviceContextPool::Instance().Get(grad_t->place());
+        operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+      }
     }
   }
 }

 std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                              const bool blocking) const {
-  PADDLE_ENFORCE_EQ(var_.IsInitialized() && var_.IsType<framework::LoDTensor>(),
-                    true,
-                    "Variable must be initialized and type of LoDTensor when "
-                    "getting numpy tensor");
-
-  auto& src_tensor = var_.Get<framework::LoDTensor>();
-
-  // TODO(Jiabin): change this after move unique_name generator to CXX
-  auto new_var = std::make_shared<VarBase>(
-      false, "Itmp" + std::to_string(copied_counter_++));
-
-  auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
-  dst_tensor->set_lod(src_tensor.lod());
-
-  framework::TensorCopy(src_tensor, dst_place, dst_tensor);
-  if (blocking) {
-    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
-    auto src_place = src_tensor.place();
-    if (!(src_place == dst_place)) {
-      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+  PADDLE_ENFORCE_EQ(
+      var_.IsInitialized() && (var_.IsType<framework::LoDTensor>() ||
+                               var_.IsType<framework::SelectedRows>()),
+      true, platform::errors::InvalidArgument(
+                "Variable is not initialized or Variable's type is not "
+                "LoDTensor or SelectedRows when getting numpy tensor"));
+  if (var_.IsType<framework::LoDTensor>()) {
+    auto& src_tensor = var_.Get<framework::LoDTensor>();
+
+    // TODO(Jiabin): change this after move unique_name generator to CXX
+    auto new_var = std::make_shared<VarBase>(
+        false, "Itmp" + std::to_string(copied_counter_++));
+
+    auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
+    dst_tensor->set_lod(src_tensor.lod());
+
+    framework::TensorCopy(src_tensor, dst_place, dst_tensor);
+    if (blocking) {
+      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+      auto src_place = src_tensor.place();
+      if (!(src_place == dst_place)) {
+        platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+      }
     }
-  }

-  if (platform::is_gpu_place(dst_place)) {
-    VLOG(3) << "copy tensor " << Name() << " from gpu";
-  }
+    if (platform::is_gpu_place(dst_place)) {
+      VLOG(3) << "copy tensor " << Name() << " from gpu";
+    }

-  return new_var;
+    return new_var;
+  } else {
+    auto& src_selected_rows = var_.Get<framework::SelectedRows>();
+    auto new_var = std::make_shared<VarBase>(
+        false, "Itmp" + std::to_string(copied_counter_++));
+    new_var->SetType(framework::proto::VarType::SELECTED_ROWS);
+    auto* dst_selected_rows =
+        new_var->var_.GetMutable<framework::SelectedRows>();
+
+    framework::TensorCopy(src_selected_rows.value(), dst_place,
+                          dst_selected_rows->mutable_value());
+    if (blocking) {
+      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+      auto src_place = src_selected_rows.place();
+      if (!(src_place == dst_place)) {
+        platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+      }
+    }
+    dst_selected_rows->set_height(src_selected_rows.height());
+    dst_selected_rows->set_rows(src_selected_rows.rows());
+    if (platform::is_gpu_place(dst_place)) {
+      VLOG(3) << "copy selected rows " << Name() << " from gpu";
+    }
+    return new_var;
+  }
 }

 // create OpBase from optype
 OpBase::OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
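A hedged sketch of what the `ClearGradient` change means from Python (assuming a fluid 1.x build; the API calls mirror the tests later in this patch): clearing a SelectedRows gradient drops its rows and value instead of zero-filling a dense tensor, so a cleared sparse gradient reads as uninitialized again.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    emb = fluid.dygraph.Embedding("emb", size=[10, 4], is_sparse=True)
    ids = fluid.dygraph.to_variable(np.array([[1], [2]]).astype('int64'))
    out = emb(ids)
    out.backward()
    values, rows = emb._w.gradient()  # SelectedRows grad: (values, rows)
    emb.clear_gradients()             # rows/value dropped, not zeroed
    try:
        emb._w.gradient()             # grad is now empty again
    except ValueError:
        pass
```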
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 4bc7955c2506eee5e247e93682c4e8549f5c27ca..9bf9366c535ebc823d9ab240a599aab11e0999a0 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -453,6 +453,10 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
       VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
     } else {
       var_set_[name]->SetType(type);
+      if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
+          (var_set_[name]->MutableVar()->Type() != type)) {
+        var_set_[name]->MutableVar()->Clear();
+      }
     }
   }

@@ -766,9 +770,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
                       platform::errors::PreconditionNotMet(
                           "The type of %s and %s is not the same.", in, out));

-    auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
-    auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
-    out_lod_tensor->Resize(in_lod_tensor.dims());
+    if (in_var->IsType<framework::LoDTensor>()) {
+      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
+      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
+      out_lod_tensor->Resize(in_lod_tensor.dims());
+    } else {
+      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
+      auto* out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
+      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
+      out_sele_rows->set_rows(in_sele_rows.rows());
+      out_sele_rows->set_height(in_sele_rows.height());
+    }
   }

   void ShareAllLoD(const std::string& in,
diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt
index 8d583f36f2b0abed5c33782ff5767b9761838ac5..67e6294f8bb1d2a40f63f3cff68af1be575f462a 100644
--- a/paddle/fluid/imperative/tests/CMakeLists.txt
+++ b/paddle/fluid/imperative/tests/CMakeLists.txt
@@ -5,7 +5,7 @@
 else()
 endif(WIN32)

-cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS gradient_accumulator memcpy)
+cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator)
 cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
 cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split assign_op place)
 cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 92b4bf1522be522a9bc60570ae6d9c88c2f064ee..b9e876b9b8a0d4bf7c549cd565b1829ca7ce19a6 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -277,14 +277,19 @@ void BindImperative(py::module *m_ptr) {
       .def("_grad_ivar",
            [](const imperative::VarBase &self) {
              auto &grad_var = self.GradVarBase();
-             auto *tensor =
-                 grad_var->MutableVar()->GetMutable<framework::LoDTensor>();
-             if (grad_var && grad_var->Var().IsInitialized() &&
-                 tensor->IsInitialized()) {
-               return grad_var;
-             } else {
-               return std::shared_ptr<imperative::VarBase>(nullptr);
+             if (grad_var && grad_var->Var().IsInitialized()) {
+               auto *tensor =
+                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
+                       ? grad_var->MutableVar()
+                             ->GetMutable<framework::LoDTensor>()
+                       : grad_var->MutableVar()
+                             ->GetMutable<framework::SelectedRows>()
+                             ->mutable_value();
+               if (tensor->IsInitialized()) {
+                 return grad_var;
+               }
              }
+             return std::shared_ptr<imperative::VarBase>(nullptr);
            },
            py::return_value_policy::copy)
       .def("_copy_to",
@@ -305,6 +310,9 @@ void BindImperative(py::module *m_ptr) {
              if (self.Var().IsType<framework::LoDTensor>()) {
                return framework::vectorize(
                    self.Var().Get<framework::LoDTensor>().dims());
+             } else if (self.Var().IsType<framework::SelectedRows>()) {
+               return framework::vectorize(
+                   self.Var().Get<framework::SelectedRows>().value().dims());
              } else {
                VLOG(2) << "It is meaningless to get shape of variable type "
                        << GetTypeName(self);
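The `_grad_ivar` change above is what makes sparse gradients visible to the Python side at all. A hedged sketch (using the internal `_ivar` attributes the same way `framework.py` and `optimizer.py` do in this patch):

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    emb = fluid.dygraph.Embedding("emb", size=[10, 4], is_sparse=True)
    ids = fluid.dygraph.to_variable(np.array([[1], [2]]).astype('int64'))
    out = emb(ids)
    out.backward()
    # Before this patch, _grad_ivar() unconditionally fetched a LoDTensor,
    # so a SelectedRows gradient looked uninitialized; now it is reported.
    assert emb._w._ivar._grad_ivar() is not None
    print(emb._w._ivar._grad_ivar().type)  # VarType.SELECTED_ROWS
```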
diff --git a/python/paddle/fluid/dygraph_grad_clip.py b/python/paddle/fluid/dygraph_grad_clip.py
index 826f918f36ece2eab5ddf17c1c0b3c86ca4e6438..4fdfc0bc9ded771f695923a7d3e33ca8eb94a1b7 100644
--- a/python/paddle/fluid/dygraph_grad_clip.py
+++ b/python/paddle/fluid/dygraph_grad_clip.py
@@ -263,7 +263,11 @@ class GradClipByGlobalNorm(GradClipBase):
         for p, g in para_and_grad:
             if g is None:
                 continue
-            power = layers.square(g)
+            merge_grad = g
+            if g._ivar.type == core.VarDesc.VarType.SELECTED_ROWS:
+                merge_grad = layers.merge_selected_rows(g)
+                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
+            power = layers.square(merge_grad)
             sum_t = layers.reduce_sum(power)
             norm_arr.append(sum_t)

@@ -280,7 +284,7 @@ class GradClipByGlobalNorm(GradClipBase):
             if g is None:
                 out.append((p, g))
                 continue
-            new_grad = g * clip_scale
+            new_grad = layers.elementwise_mul(x=g, y=clip_scale)

             out.append((p, new_grad))
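A hedged usage sketch of the new clipping path (mirroring `test_imperative_selected_rows.py` later in this patch): a SelectedRows gradient is merged and densified via `merge_selected_rows` / `get_tensor_from_selected_rows` before its squared norm is computed, so global-norm clipping now works with sparse embedding gradients.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0)
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=0.001)
    emb = fluid.dygraph.Embedding("emb", size=[20, 32], is_sparse=True)
    ids = fluid.dygraph.to_variable(np.array([[1], [2], [1]]).astype('int64'))
    out = emb(ids)
    out.backward()
    # The sparse embedding gradient is clipped by global norm, then applied.
    sgd.minimize(out, grad_clip=grad_clip)
```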
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index ec6f8198e42ac137baeb59ab153a802e92fcddfc..4a840c04aa3e5e4fd39ac580837a0e0b7e2c8519 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -902,7 +902,7 @@ class Variable(object):
         Get the Gradient of Current Variable

         Returns:
-            ndarray: Numpy value of the gradient of current Variable
+            ndarray or tuple of ndarray: if the Variable holds a LoDTensor, the numpy value of its gradient; if it holds SelectedRows, a tuple whose first element is the numpy value of the gradient and whose second element is the numpy value of its rows

         Examples:
             .. code-block:: python
@@ -929,12 +929,12 @@ class Variable(object):
            raise ValueError("%s has no grad, Please set Variable.stop_gradient=False, or " \
                    "check if this is the first and only variable need grad, if so, please set its pre-Variable's " \
                    "stop_gradient=False, to make sure it has gradient " % self.name)
-        if not self._ivar._grad_ivar().value().get_tensor()._is_initialized():
-            raise ValueError(
-                "%s's Grad is Empty, Please check if it has no data in" %
-                self.name)
         new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True)
-        return np.array(new_ivar.value().get_tensor())
+        if self._ivar._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
+            return (np.array(new_ivar.value().get_selected_rows().get_tensor()),
+                    np.array(new_ivar.value().get_selected_rows().rows()))
+        else:
+            return np.array(new_ivar.value().get_tensor())

     @dygraph_only
     def clear_gradient(self):
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 904490b5df093d25457ba419f68014b3b478efaa..f0d2d24eb64394275b539b6d9605f6a6053fc145 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -529,9 +529,11 @@ class Optimizer(object):
                 if not param.trainable:
                     continue
                 if param._ivar._grad_ivar() is not None:
+                    ivar_type = param._ivar._grad_ivar().type
                     # create gradient variable
                     grad_var = Variable(
                         block=loss.block,
+                        type=ivar_type,
                         name=param._ivar._grad_name(),
                         stop_gradient=True,
                         ivar=param._ivar._grad_ivar())
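A hedged sketch of the documented `gradient()` behavior (shapes and names are illustrative): for a parameter whose gradient is stored as SelectedRows, the call returns a `(values, rows)` pair rather than a single ndarray.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    emb = fluid.dygraph.Embedding("emb", size=[20, 32], is_sparse=True)
    ids = fluid.dygraph.to_variable(np.array([[3], [5]]).astype('int64'))
    out = emb(ids)
    out.backward()
    values, rows = emb._w.gradient()  # SelectedRows -> (values, rows)
    print(values.shape, rows)         # e.g. (2, 32) and the looked-up ids
```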
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb06daa0c532b117992a2faa2bd6418e0f678df1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py
@@ -0,0 +1,201 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.dygraph.nn import Embedding
+import paddle.fluid.framework as framework
+from paddle.fluid.optimizer import SGDOptimizer
+from paddle.fluid.dygraph.base import to_variable
+from test_imperative_base import new_program_scope
+import numpy as np
+import six
+from utils import DyGraphProgramDescTracerTestHelper
+
+
+class SimpleNet(fluid.Layer):
+    def __init__(self,
+                 name_scope,
+                 hidden_size,
+                 vocab_size,
+                 num_steps=20,
+                 init_scale=0.1,
+                 is_sparse=False,
+                 dtype='float32'):
+        super(SimpleNet, self).__init__(name_scope)
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.init_scale = init_scale
+        self.num_steps = num_steps
+        self.embedding = Embedding(
+            self.full_name(),
+            size=[vocab_size, hidden_size],
+            dtype=dtype,
+            is_sparse=is_sparse,
+            param_attr=fluid.ParamAttr(
+                name='embedding_para',
+                initializer=fluid.initializer.UniformInitializer(
+                    low=-init_scale, high=init_scale)))
+        self.softmax_bias = self.create_parameter(
+            attr=fluid.ParamAttr(),
+            shape=[self.vocab_size],
+            dtype=dtype,
+            default_initializer=fluid.initializer.UniformInitializer(
+                low=-self.init_scale, high=self.init_scale))
+
+    def forward(self, input, label):
+        x_emb = self.embedding(input)
+        projection = fluid.layers.matmul(
+            x_emb, fluid.layers.transpose(
+                self.embedding._w, perm=[1, 0]))
+        projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
+        projection = fluid.layers.reshape(
+            projection, shape=[-1, self.vocab_size])
+        loss = fluid.layers.softmax_with_cross_entropy(
+            logits=projection, label=label, soft_label=False)
+        loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
+        loss = fluid.layers.reduce_mean(loss, dim=[0])
+        loss = fluid.layers.reduce_sum(loss)
+        loss.permissions = True
+
+        return loss
+
+
+class TestDygraphSimpleNet(unittest.TestCase):
+    def test_simple_net(self):
+        for is_sparse in [True, False]:
+            for dtype in ["float32", "float64"]:
+                self.simple_net_float32(is_sparse, dtype)
+
+    def simple_net_float32(self, is_sparse, dtype):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for place in places:
+            seed = 90
+            hidden_size = 10
+            vocab_size = 1000
+            num_steps = 3
+            init_scale = 0.1
+            batch_size = 4
+            batch_num = 200
+
+            for is_sort_sum_gradient in [True, False]:
+                with fluid.dygraph.guard(place):
+                    fluid.default_startup_program().random_seed = seed
+                    fluid.default_main_program().random_seed = seed
+
+                    simple_net = SimpleNet(
+                        "simple_net",
+                        hidden_size=hidden_size,
+                        vocab_size=vocab_size,
+                        num_steps=num_steps,
+                        init_scale=init_scale,
+                        is_sparse=is_sparse,
+                        dtype=dtype)
+
+                    sgd = SGDOptimizer(learning_rate=1e-3)
+                    dy_param_updated = dict()
+                    dy_param_init = dict()
+                    dy_loss = None
+
+                    helper = DyGraphProgramDescTracerTestHelper(self)
+                    backward_strategy = fluid.dygraph.BackwardStrategy()
+                    backward_strategy.sort_sum_gradient = is_sort_sum_gradient
+
+                    for i in range(batch_num):
+                        x_data = np.arange(12).reshape(4, 3).astype('int64')
+                        y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                        x_data = x_data.reshape((-1, num_steps, 1))
+                        y_data = y_data.reshape((-1, 1))
+
+                        x = to_variable(x_data)
+                        y = to_variable(y_data)
+                        outs = simple_net(x, y)
+                        dy_loss = outs
+                        if i == 0:
+                            for param in simple_net.parameters():
+                                dy_param_init[param.name] = param.numpy()
+                        dy_loss.backward(backward_strategy)
+                        sgd.minimize(dy_loss)
+                        simple_net.clear_gradients()
+                        if i == batch_num - 1:
+                            for param in simple_net.parameters():
+                                dy_param_updated[param.name] = param.numpy()
+                    dy_loss_value = dy_loss.numpy()
+
+                with new_program_scope():
+                    fluid.default_startup_program().random_seed = seed
+                    fluid.default_main_program().random_seed = seed
+
+                    simple_net = SimpleNet(
+                        "simple_net",
+                        hidden_size=hidden_size,
+                        vocab_size=vocab_size,
+                        num_steps=num_steps,
+                        is_sparse=is_sparse,
+                        dtype=dtype)
+
+                    exe = fluid.Executor(place)
+                    sgd = SGDOptimizer(learning_rate=1e-3)
+                    x = fluid.layers.data(
+                        name="x", shape=[-1, num_steps, 1], dtype='int64')
+                    y = fluid.layers.data(name="y", shape=[-1, 1], dtype=dtype)
+
+                    static_loss = simple_net(x, y)
+                    sgd.minimize(static_loss)
+                    static_param_updated = dict()
+                    static_param_init = dict()
+                    static_param_name_list = list()
+                    for param in simple_net.parameters():
+                        static_param_name_list.append(param.name)
+
+                    out = exe.run(fluid.default_startup_program(),
+                                  fetch_list=static_param_name_list)
+                    for i in range(len(static_param_name_list)):
+                        static_param_init[static_param_name_list[i]] = out[i]
+                    static_loss_value = None
+                    for i in range(batch_num):
+                        x_data = np.arange(12).reshape(4, 3).astype('int64')
+                        y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                        x_data = x_data.reshape((-1, num_steps, 1))
+                        y_data = y_data.reshape((-1, 1))
+                        fetch_list = [static_loss]
+                        fetch_list.extend(static_param_name_list)
+                        out = exe.run(fluid.default_main_program(),
+                                      feed={"x": x_data,
+                                            "y": y_data},
+                                      fetch_list=fetch_list)
+                        static_loss_value = out[0]
+
+                        if i == batch_num - 1:
+                            for k in range(3, len(out)):
+                                static_param_updated[static_param_name_list[
+                                    k - 1]] = out[k]
+
+                self.assertTrue(
+                    np.array_equal(static_loss_value, dy_loss_value))
+                for key, value in six.iteritems(static_param_init):
+                    self.assertTrue(np.array_equal(value, dy_param_init[key]))
+                for key, value in six.iteritems(static_param_updated):
+                    self.assertTrue(
+                        np.array_equal(value, dy_param_updated[key]))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py
index 9dcea95aa97660673214dc37f77a7c3c4d8fe65a..1ef318194f3b78e50c15c5738f4a1ca4ba84e16e 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py
@@ -141,6 +141,7 @@ class PtbModel(fluid.Layer):
                  num_layers=2,
                  num_steps=20,
                  init_scale=0.1,
+                 is_sparse=False,
                  dropout=None):
         super(PtbModel, self).__init__(name_scope)
         self.hidden_size = hidden_size
@@ -160,7 +161,7 @@ class PtbModel(fluid.Layer):
             self.full_name(),
             size=[vocab_size, hidden_size],
             dtype='float32',
-            is_sparse=False,
+            is_sparse=is_sparse,
             param_attr=fluid.ParamAttr(
                 name='embedding_para',
                 initializer=fluid.initializer.UniformInitializer(
@@ -212,7 +213,11 @@ class PtbModel(fluid.Layer):


 class TestDygraphPtbRnn(unittest.TestCase):
-    def test_ptb_rnn_cpu_float32(self):
+    def test_ptb_rnn(self):
+        for is_sparse in [True, False]:
+            self.ptb_rnn_cpu_float32(is_sparse)
+
+    def ptb_rnn_cpu_float32(self, is_sparse):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -233,7 +238,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 vocab_size=vocab_size,
                 num_layers=num_layers,
                 num_steps=num_steps,
-                init_scale=init_scale)
+                init_scale=init_scale,
+                is_sparse=is_sparse)

             sgd = SGDOptimizer(learning_rate=1e-3)
             dy_param_updated = dict()
@@ -300,7 +306,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 vocab_size=vocab_size,
                 num_layers=num_layers,
                 num_steps=num_steps,
-                init_scale=init_scale)
+                init_scale=init_scale,
+                is_sparse=is_sparse)

             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py
index ca0b03c60ab2f41a13eb7431688976346c530a34..9e90f0f12a063dc91c96a6ef9b515189d1bc8e54 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py
@@ -28,7 +28,11 @@ import six


 class TestDygraphPtbRnnSortGradient(unittest.TestCase):
-    def test_ptb_rnn_sort_gradient_cpu_float32(self):
+    def test_ptb_rnn_sort_gradient(self):
+        for is_sparse in [True, False]:
+            self.ptb_rnn_sort_gradient_cpu_float32(is_sparse)
+
+    def ptb_rnn_sort_gradient_cpu_float32(self, is_sparse):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -50,7 +54,8 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase):
                 vocab_size=vocab_size,
                 num_layers=num_layers,
                 num_steps=num_steps,
-                init_scale=init_scale)
+                init_scale=init_scale,
+                is_sparse=is_sparse)

             sgd = SGDOptimizer(learning_rate=1e-3)
             dy_param_updated = dict()
@@ -97,7 +102,8 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase):
                 vocab_size=vocab_size,
                 num_layers=num_layers,
                 num_steps=num_steps,
-                init_scale=init_scale)
+                init_scale=init_scale,
+                is_sparse=is_sparse)

             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec68ff9be969849df90367bb0425feee8619d809
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.base import to_variable
+from paddle.fluid.dygraph.nn import Embedding
+from paddle.fluid.optimizer import SGDOptimizer
+import numpy as np
+import paddle.fluid.core as core
+
+
+class SimpleNet(fluid.Layer):
+    def __init__(self, name_scope, vocab_size, hidden_size, dtype):
+        super(SimpleNet, self).__init__(name_scope)
+        self.emb = fluid.dygraph.Embedding(
+            self.full_name(),
+            size=[vocab_size, hidden_size],
+            dtype=dtype,
+            param_attr='emb.w',
+            is_sparse=True)
+
+    def forward(self, input):
+        input_emb = self.emb(input)
+        return input_emb, self.emb
+
+
+class TestSimpleNet(unittest.TestCase):
+    def test_selectedrows_gradient1(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for place in places:
+            for dtype in ["float32", "float64"]:
+                for sort_sum_gradient in [True, False]:
+                    with fluid.dygraph.guard(place):
+                        backward_strategy = fluid.dygraph.BackwardStrategy()
+                        backward_strategy.sort_sum_gradient = sort_sum_gradient
+                        adam = SGDOptimizer(learning_rate=0.001)
+                        # grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0)
+
+                        input_word = np.array(
+                            [[[1], [2]], [[2], [1]]]).astype('int64')
+                        input = to_variable(input_word)
+
+                        simplenet = SimpleNet("SimpleNet", 20, 32, dtype)
+                        input_emb, emb = simplenet(input)
+
+                        try:
+                            emb._w.gradient()
+                        except ValueError as e:
+                            pass
+                        try:
+                            input_emb.gradient()
+                        except ValueError as e:
+                            pass
+
+                        input_emb.backward(backward_strategy)
+                        adam.minimize(input_emb)  # grad_clip=grad_clip
+                        emb._w.gradient()
+
+                        emb.clear_gradients()
+                        try:
+                            emb._w.gradient()
+                        except ValueError as e:
+                            pass
+
+                        input_emb.clear_gradient()
+                        try:
+                            input_emb.gradient()
+                        except ValueError as e:
+                            pass
+
+    def test_selectedrows_gradient2(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for place in places:
+            for sort_sum_gradient in [True, False]:
+                with fluid.dygraph.guard(place):
+                    backward_strategy = fluid.dygraph.BackwardStrategy()
+                    backward_strategy.sort_sum_gradient = sort_sum_gradient
+                    adam = SGDOptimizer(learning_rate=0.001)
+                    grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
+                        5.0)
+
+                    input_word = np.array(
+                        [[[1], [2]], [[2], [1]]]).astype('int64')
+                    input = to_variable(input_word)
+
+                    simplenet = SimpleNet("SimpleNet", 20, 32, "float32")
+                    input_emb, emb = simplenet(input)
+
+                    try:
+                        emb._w.gradient()
+                    except ValueError as e:
+                        pass
+                    try:
+                        input_emb.gradient()
+                    except ValueError as e:
+                        pass
+
+                    input_emb.backward(backward_strategy)
+                    adam.minimize(input_emb, grad_clip=grad_clip)
+                    emb._w.gradient()
+
+                    emb.clear_gradients()
+                    try:
+                        emb._w.gradient()
+                    except ValueError as e:
+                        pass
+
+                    input_emb.clear_gradient()
+                    try:
+                        input_emb.gradient()
+                    except ValueError as e:
+                        pass
+
+
+if __name__ == '__main__':
+    unittest.main()
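The tests around here exercise both accumulator paths this patch adds. The core behavior can be sketched as follows (hedged; the embedding is looked up twice, so its weight receives two SelectedRows gradients that `EagerGradientAccumulator` or `SortedGradientAccumulator` must combine, depending on the strategy):

```python
import numpy as np
import paddle.fluid as fluid

for sort_sum_gradient in [True, False]:
    with fluid.dygraph.guard():
        strategy = fluid.dygraph.BackwardStrategy()
        strategy.sort_sum_gradient = sort_sum_gradient

        emb = fluid.dygraph.Embedding("emb", size=[10, 4], is_sparse=True)
        ids = fluid.dygraph.to_variable(np.array([[1], [3]]).astype('int64'))
        # Two uses of the same weight -> two sparse grads to accumulate.
        out1 = fluid.layers.reduce_sum(emb(ids))
        out2 = fluid.layers.reduce_sum(emb(ids))
        loss = fluid.layers.elementwise_add(out1, out2)
        loss.backward(strategy)
        values, rows = emb._w.gradient()
        # Each looked-up row carries both contributions, either strategy.
```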
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7054c15017d48650267899005cd3a634f03aa2e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.dygraph.nn import Embedding
+import paddle.fluid.framework as framework
+from paddle.fluid.optimizer import SGDOptimizer
+from paddle.fluid.dygraph.base import to_variable
+from test_imperative_base import new_program_scope
+import numpy as np
+import six
+from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
+from paddle.fluid.dygraph.jit import TracedLayer
+
+
+class SimpleNet(fluid.Layer):
+    def __init__(self,
+                 name_scope,
+                 hidden_size,
+                 vocab_size,
+                 num_steps=20,
+                 init_scale=0.1,
+                 is_sparse=False,
+                 dtype='float32'):
+        super(SimpleNet, self).__init__(name_scope)
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.init_scale = init_scale
+        self.num_steps = num_steps
+        self.embedding = Embedding(
+            self.full_name(),
+            size=[vocab_size, hidden_size],
+            dtype=dtype,
+            is_sparse=is_sparse,
+            param_attr=fluid.ParamAttr(
+                name='embedding_para',
+                initializer=fluid.initializer.UniformInitializer(
+                    low=-init_scale, high=init_scale)))
+        self.softmax_weight = self.create_parameter(
+            attr=fluid.ParamAttr(),
+            shape=[self.hidden_size, self.hidden_size],
+            dtype=dtype,
+            default_initializer=fluid.initializer.UniformInitializer(
+                low=-self.init_scale, high=self.init_scale))
+        self.softmax_bias = self.create_parameter(
+            attr=fluid.ParamAttr(),
+            shape=[self.hidden_size],
+            dtype=dtype,
+            default_initializer=fluid.initializer.UniformInitializer(
+                low=-self.init_scale, high=self.init_scale))
+
+    def forward(self, input, label):
+        x_emb = self.embedding(input)
+        fc = fluid.layers.matmul(x_emb, self.softmax_weight)
+        fc = fluid.layers.elementwise_add(fc, self.softmax_bias)
+        projection = fluid.layers.matmul(
+            fc, fluid.layers.transpose(
+                self.embedding._w, perm=[1, 0]))
+        projection = fluid.layers.reshape(
+            projection, shape=[-1, self.vocab_size])
+        loss = fluid.layers.softmax_with_cross_entropy(
+            logits=projection, label=label, soft_label=False)
+        loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
+        loss = fluid.layers.reduce_mean(loss, dim=[0])
+        loss = fluid.layers.reduce_sum(loss)
+        loss.permissions = True
+
+        return loss
+
+
+class TestDygraphSimpleNet(unittest.TestCase):
+    def test_simple_net(self):
+        for is_sparse in [True, False]:
+            for dtype in ["float32", "float64"]:
+                self.simple_net_float(is_sparse, dtype)
+
+    def simple_net_float(self, is_sparse, dtype):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for place in places:
+            seed = 90
+            hidden_size = 10
+            vocab_size = 1000
+            num_steps = 3
+            init_scale = 0.1
+            batch_size = 4
+            batch_num = 200
+
+            for is_sort_sum_gradient in [True, False]:
+                traced_layer = None
+                with fluid.dygraph.guard(place):
+                    fluid.default_startup_program().random_seed = seed
+                    fluid.default_main_program().random_seed = seed
+
+                    simple_net = SimpleNet(
+                        "simple_net",
+                        hidden_size=hidden_size,
+                        vocab_size=vocab_size,
+                        num_steps=num_steps,
+                        init_scale=init_scale,
+                        is_sparse=is_sparse,
+                        dtype=dtype)
+
+                    sgd = SGDOptimizer(learning_rate=1e-3)
+                    dy_param_updated = dict()
+                    dy_param_init = dict()
+                    dy_loss = None
+
+                    helper = DyGraphProgramDescTracerTestHelper(self)
+                    program = None
+                    backward_strategy = fluid.dygraph.BackwardStrategy()
+                    backward_strategy.sort_sum_gradient = is_sort_sum_gradient
+
+                    for i in range(batch_num):
+                        x_data = np.arange(12).reshape(4, 3).astype('int64')
+                        y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                        x_data = x_data.reshape((-1, num_steps, 1))
+                        y_data = y_data.reshape((-1, 1))
+
+                        x = to_variable(x_data)
+                        y = to_variable(y_data)
+                        outs = simple_net(x, y)
+                        dy_loss = outs
+                        if i == 0:
+                            for param in simple_net.parameters():
+                                dy_param_init[param.name] = param.numpy()
+                        dy_loss.backward(backward_strategy)
+                        sgd.minimize(dy_loss)
+                        simple_net.clear_gradients()
+                        if i == batch_num - 1:
+                            for param in simple_net.parameters():
+                                dy_param_updated[param.name] = param.numpy()
+                    dy_loss_value = dy_loss.numpy()
+
+                with new_program_scope():
+                    fluid.default_startup_program().random_seed = seed
+                    fluid.default_main_program().random_seed = seed
+
+                    simple_net = SimpleNet(
+                        "simple_net",
+                        hidden_size=hidden_size,
+                        vocab_size=vocab_size,
+                        num_steps=num_steps,
+                        is_sparse=is_sparse,
+                        dtype=dtype)
+
+                    exe = fluid.Executor(place)
+                    sgd = SGDOptimizer(learning_rate=1e-3)
+                    x = fluid.layers.data(
+                        name="x", shape=[-1, num_steps, 1], dtype='int64')
+                    y = fluid.layers.data(name="y", shape=[-1, 1], dtype=dtype)
+
+                    static_loss = simple_net(x, y)
+                    sgd.minimize(static_loss)
+                    static_param_updated = dict()
+                    static_param_init = dict()
+                    static_param_name_list = list()
+                    for param in simple_net.parameters():
+                        static_param_name_list.append(param.name)
+
+                    out = exe.run(framework.default_startup_program(),
+                                  fetch_list=static_param_name_list)
+                    for i in range(len(static_param_name_list)):
+                        static_param_init[static_param_name_list[i]] = out[i]
+                    static_loss_value = None
+                    for i in range(batch_num):
+                        x_data = np.arange(12).reshape(4, 3).astype('int64')
+                        y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                        x_data = x_data.reshape((-1, num_steps, 1))
+                        y_data = y_data.reshape((-1, 1))
+                        fetch_list = [static_loss]
+                        fetch_list.extend(static_param_name_list)
+                        out = exe.run(fluid.default_main_program(),
+                                      feed={"x": x_data,
+                                            "y": y_data},
+                                      fetch_list=fetch_list)
+                        static_loss_value = out[0]
+
+                        if i == batch_num - 1:
+                            for k in range(3, len(out)):
+                                static_param_updated[static_param_name_list[
+                                    k - 1]] = out[k]
+
+                self.assertTrue(
+                    np.array_equal(static_loss_value, dy_loss_value))
+                for key, value in six.iteritems(static_param_init):
+                    self.assertTrue(np.array_equal(value, dy_param_init[key]))
+                for key, value in six.iteritems(static_param_updated):
+                    self.assertTrue(
+                        np.array_equal(value, dy_param_updated[key]))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py
index c542235e15f0586f4ea6ecd04e99ce23c24105e4..56c7189e31aa5ef741817fa40f28eb6f2abea4f5 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py
@@ -586,6 +586,7 @@ class PrepareEncoderDecoderLayer(Layer):
                  src_emb_dim,
                  src_max_len,
                  dropout_rate,
+                 is_sparse=False,
                  word_emb_param_name=None,
                  pos_enc_param_name=None):
         super(PrepareEncoderDecoderLayer, self).__init__(name_scope)
@@ -596,6 +597,7 @@ class PrepareEncoderDecoderLayer(Layer):
         self._input_emb = Embedding(
             name_scope=self.full_name(),
             size=[src_vocab_size, src_emb_dim],
+            is_sparse=is_sparse,
             padding_idx=0,
             param_attr=fluid.ParamAttr(
                 name=word_emb_param_name,
@@ -608,6 +610,7 @@ class PrepareEncoderDecoderLayer(Layer):
             self._pos_emb = Embedding(
                 name_scope=self.full_name(),
                 size=[self._src_max_len, src_emb_dim],
+                is_sparse=is_sparse,
                 param_attr=fluid.ParamAttr(
                     name=pos_enc_param_name,
                     initializer=fluid.initializer.NumpyArrayInitializer(pos_inp),
@@ -633,10 +636,23 @@ class PrepareEncoderDecoderLayer(Layer):


 class WrapEncoderLayer(Layer):
-    def __init__(self, name_cope, src_vocab_size, max_length, n_layer, n_head,
-                 d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
-                 attention_dropout, relu_dropout, preprocess_cmd,
-                 postprocess_cmd, weight_sharing):
+    def __init__(self,
+                 name_cope,
+                 src_vocab_size,
+                 max_length,
+                 n_layer,
+                 n_head,
+                 d_key,
+                 d_value,
+                 d_model,
+                 d_inner_hid,
+                 prepostprocess_dropout,
+                 attention_dropout,
+                 relu_dropout,
+                 preprocess_cmd,
+                 postprocess_cmd,
+                 weight_sharing,
+                 is_sparse=False):
         """
         The wrapper assembles together all needed layers for the encoder.
         """
@@ -648,6 +664,7 @@ class WrapEncoderLayer(Layer):
             d_model,
             max_length,
             prepostprocess_dropout,
+            is_sparse=is_sparse,
             word_emb_param_name=word_emb_param_names[0],
             pos_enc_param_name=pos_enc_param_names[0])
         self._encoder = EncoderLayer(
@@ -814,7 +831,8 @@ class WrapDecoderLayer(Layer):
                  postprocess_cmd,
                  weight_sharing,
                  caches=None,
-                 gather_idx=None):
+                 gather_idx=None,
+                 is_sparse=False):
         """
         The wrapper assembles together all needed layers for the encoder.
         """
@@ -826,6 +844,7 @@ class WrapDecoderLayer(Layer):
             d_model,
             max_length,
             prepostprocess_dropout,
+            is_sparse=is_sparse,
             word_emb_param_name=word_emb_param_names[1],
             pos_enc_param_name=pos_enc_param_names[1])
         self._decoder_layer = DecoderLayer(
@@ -893,7 +912,8 @@ class TransFormer(Layer):
                  weight_sharing,
                  label_smooth_eps,
                  use_py_reader=False,
-                 is_test=False):
+                 is_test=False,
+                 is_sparse=False):
         super(TransFormer, self).__init__(name_scope)
         self._label_smooth_eps = label_smooth_eps
         self._trg_vocab_size = trg_vocab_size
@@ -902,15 +922,39 @@ class TransFormer(Layer):
                 "Vocabularies in source and target should be same for weight sharing."
             )
         self._wrap_encoder_layer = WrapEncoderLayer(
-            self.full_name(), src_vocab_size, max_length, n_layer, n_head,
-            d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
-            attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
-            weight_sharing)
+            self.full_name(),
+            src_vocab_size,
+            max_length,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            prepostprocess_dropout,
+            attention_dropout,
+            relu_dropout,
+            preprocess_cmd,
+            postprocess_cmd,
+            weight_sharing,
+            is_sparse=is_sparse)
         self._wrap_decoder_layer = WrapDecoderLayer(
-            self.full_name(), trg_vocab_size, max_length, n_layer, n_head,
-            d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
-            attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
-            weight_sharing)
+            self.full_name(),
+            trg_vocab_size,
+            max_length,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            prepostprocess_dropout,
+            attention_dropout,
+            relu_dropout,
+            preprocess_cmd,
+            postprocess_cmd,
+            weight_sharing,
+            is_sparse=is_sparse)

         if weight_sharing:
             self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
@@ -937,7 +981,11 @@ class TransFormer(Layer):


 class TestDygraphTransformerSortGradient(unittest.TestCase):
-    def test_transformer_sort_gradient_float32(self):
+    def test_transformer_sort_gradient(self):
+        for is_sparse in [True, False]:
+            self.transformer_sort_gradient_float32(is_sparse)
+
+    def transformer_sort_gradient_float32(self, is_sparse):
         seed = 90

         with guard():
@@ -964,7 +1012,8 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
                 ModelHyperParams.weight_sharing,
                 TrainTaskConfig.label_smooth_eps,
                 use_py_reader=use_py_reader,
-                is_test=False)
+                is_test=False,
+                is_sparse=is_sparse)
             if sync:
                 lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                     ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
@@ -1045,7 +1094,8 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
                 ModelHyperParams.weight_sharing,
                 TrainTaskConfig.label_smooth_eps,
                 use_py_reader=use_py_reader,
-                is_test=False)
+                is_test=False,
+                is_sparse=is_sparse)
             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             optimizer = fluid.optimizer.SGD(learning_rate=0.003)
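Taken together, a hedged end-to-end sketch of what this patch enables in dygraph (names mirror the tests above; the snippet itself is illustrative, not part of the patch): sparse embedding gradients flow through sorted accumulation, global-norm clipping, a sparse optimizer update, and sparse gradient clearing.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    strategy = fluid.dygraph.BackwardStrategy()
    strategy.sort_sum_gradient = True
    grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0)
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=1e-3)

    emb = fluid.dygraph.Embedding("emb", size=[20, 32], is_sparse=True)
    ids = fluid.dygraph.to_variable(np.array([[1], [2], [1]]).astype('int64'))
    loss = fluid.layers.reduce_sum(emb(ids))

    loss.backward(strategy)                  # sorted sparse accumulation
    sgd.minimize(loss, grad_clip=grad_clip)  # clip + apply SelectedRows grad
    emb.clear_gradients()                    # drops rows/value of the grad
```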