diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h
index 3e724f52e4fcb2818f961181a2830dfd653d2733..3455d1ee54e8e6e498d0b0e6932ec099af9c0b30 100644
--- a/paddle/fluid/operators/adam_op.h
+++ b/paddle/fluid/operators/adam_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <math.h>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
         row_numel_(row_numel),
         row_count_(row_count) {}
 
-  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
-    int64_t beg = 0, end = row_count_ - 1;
-    while (beg <= end) {
-      auto mid = ((beg + end) >> 1);
-      if (rows_[mid] == row)
-        return mid;
-      else if (rows_[mid] < row)
-        beg = mid + 1;
-      else
-        end = mid - 1;
-    }
-    return -1;
-  }
-
   inline HOSTDEVICE void operator()(size_t i) const {
-    int64_t row = i / row_numel_;
-    auto row_idx = BinarySearchInRows(row);
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
     T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
 
     // The following code is the same as dense
diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h
new file mode 100644
index 0000000000000000000000000000000000000000..262469beea7449eb5820b86de1ac4f790a833e79
--- /dev/null
+++ b/paddle/fluid/operators/math/algorithm.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>  // for int64_t
+#include <vector>
+
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
+  int64_t beg = 0, end = num - 1;
+  while (beg <= end) {
+    auto mid = ((beg + end) >> 1);
+    if (x[mid] == val)
+      return mid;
+    else if (x[mid] < val)
+      beg = mid + 1;
+    else
+      end = mid - 1;
+  }
+  return -1;
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
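For reference, a minimal host-side sketch of how the extracted helper behaves once this header exists; the `main` driver and sample row ids below are illustrative, not part of the patch:

```cpp
#include <cstdint>
#include <iostream>

#include "paddle/fluid/operators/math/algorithm.h"

int main() {
  // Sorted, de-duplicated row ids, as MergeAdd produces below.
  const int64_t rows[] = {2, 5, 7, 11};

  // Hit: returns the position of the value within `rows`.
  std::cout << paddle::operators::math::BinarySearch<int64_t>(rows, 4, 7)
            << "\n";  // prints 2

  // Miss: returns -1, which SparseAdamFunctor maps to a zero gradient.
  std::cout << paddle::operators::math::BinarySearch<int64_t>(rows, 4, 3)
            << "\n";  // prints -1
  return 0;
}
```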
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index 6810c24227ed074d4ba747399348aa12f55467a0..08f57dd45ad76946cbcafb98a3414003ed9d67a9 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <map>
 #include <vector>
 
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 
 namespace paddle {
@@ -245,40 +247,42 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
                   const framework::SelectedRows& input,
                   framework::SelectedRows* output) {
     framework::SelectedRows& out = *output;
-    auto input_rows = input.rows();
-    std::vector<int64_t> merge_rows;
-    merge_rows.reserve(input_rows.size());
-    std::unordered_map<int64_t, size_t> rows_pos_map;
-    rows_pos_map.reserve(input_rows.size());
-    size_t idx = 0u;
-    for (std::vector<int64_t>::iterator iter = input_rows.begin();
-         iter != input_rows.end(); ++iter) {
-      if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
-        rows_pos_map[*iter] = idx++;
-        merge_rows.emplace_back(*iter);
-      }
+    std::vector<int64_t> input_rows(input.rows());
+
+    std::map<int64_t, std::vector<size_t>> merge_row_map;
+    for (size_t i = 0; i < input_rows.size(); ++i) {
+      merge_row_map[input_rows[i]].push_back(i);
     }
-    auto input_width = input.value().dims()[1];
-    out.set_rows(merge_rows);
+
+    std::vector<int64_t> merge_rows(merge_row_map.size());
+    size_t idx = 0;
+    int64_t input_width = input.value().dims()[1];
     out.set_height(input.height());
-    out.mutable_value()->mutable_data<T>(
+
+    T* out_data = out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
-
-    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
-
-    auto* out_data = out.mutable_value()->data<T>();
-    auto* input_data = input.value().data<T>();
-
-    for (size_t i = 0; i < input_rows.size(); i++) {
-      size_t out_i = rows_pos_map[input_rows[i]];
-      for (int64_t j = 0; j < input_width; j++) {
-        out_data[out_i * input_width + j] += input_data[i * input_width + j];
+    const T* in_data = input.value().data<T>();
+
+    for (auto& row_pair : merge_row_map) {
+      auto* out_ptr = out_data + idx * input_width;
+      auto& rows = row_pair.second;
+      merge_rows[idx] = row_pair.first;
+      ++idx;
+      // rows.size() is always larger than 0
+      std::memcpy(out_ptr, in_data + rows[0] * input_width,
+                  sizeof(T) * input_width);
+
+      for (size_t i = 1; i < rows.size(); ++i) {
+        auto* in_ptr = in_data + rows[i] * input_width;
+        for (int64_t j = 0; j < input_width; ++j) {
+          out_ptr[j] += in_ptr[j];
+        }
       }
     }
+
+    out.set_rows(merge_rows);
   }
 };
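Stripped of the SelectedRows plumbing, the merge strategy above is: bucket the positions of duplicate row ids in an ordered map, copy the first occurrence of each row into the output, then accumulate the remaining occurrences on top. A self-contained sketch of the same idea (the `MergeRows` helper and its shapes are hypothetical, for illustration only):

```cpp
#include <cstdint>
#include <cstring>
#include <map>
#include <vector>

// Merge duplicate rows of a [rows.size() x width] row-major matrix `in`,
// summing the values of duplicates. Returns the sorted unique row ids;
// merged values are written to `out`.
std::vector<int64_t> MergeRows(const std::vector<int64_t>& rows,
                               const float* in, int64_t width,
                               std::vector<float>* out) {
  std::map<int64_t, std::vector<size_t>> merge_row_map;
  for (size_t i = 0; i < rows.size(); ++i) merge_row_map[rows[i]].push_back(i);

  std::vector<int64_t> merge_rows(merge_row_map.size());
  out->resize(merge_row_map.size() * width);
  float* out_data = out->data();

  size_t idx = 0;
  for (auto& row_pair : merge_row_map) {
    float* out_ptr = out_data + idx * width;
    auto& dup = row_pair.second;
    merge_rows[idx++] = row_pair.first;
    // Copy the first duplicate, then add the remaining ones on top.
    std::memcpy(out_ptr, in + dup[0] * width, sizeof(float) * width);
    for (size_t i = 1; i < dup.size(); ++i) {
      const float* in_ptr = in + dup[i] * width;
      for (int64_t j = 0; j < width; ++j) out_ptr[j] += in_ptr[j];
    }
  }
  return merge_rows;
}
```

Because `std::map` iterates keys in ascending order, the merged row ids come out sorted, which is what lets the optimizers' sparse functors locate rows with `math::BinarySearch` instead of a hash lookup.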
diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h
index 3d99c9b3f2303c941a150f77516703b8103c6c25..900be86f91c6658a5265189a6745316c6471209e 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.h
+++ b/paddle/fluid/operators/math/selected_rows_functor.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <map>
 #include <vector>
 
 #include "paddle/fluid/framework/eigen.h"
@@ -97,41 +98,39 @@ struct MergeAdd<platform::CPUDeviceContext, float> {
                   const framework::SelectedRows& input,
                   framework::SelectedRows* output) {
     framework::SelectedRows& out = *output;
-    auto input_rows = input.rows();
-    std::vector<int64_t> merge_rows;
-    merge_rows.reserve(input_rows.size());
-    std::unordered_map<int64_t, size_t> rows_pos_map;
-    rows_pos_map.reserve(input_rows.size());
-    size_t idx = 0u;
-    for (std::vector<int64_t>::iterator iter = input_rows.begin();
-         iter != input_rows.end(); ++iter) {
-      if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
-        rows_pos_map[*iter] = idx++;
-        merge_rows.emplace_back(*iter);
-      }
+    std::vector<int64_t> input_rows(input.rows());
+
+    std::map<int64_t, std::vector<size_t>> merge_row_map;
+    for (size_t i = 0; i < input_rows.size(); ++i) {
+      merge_row_map[input_rows[i]].push_back(i);
     }
-    auto input_width = input.value().dims()[1];
-    out.set_rows(merge_rows);
+
+    std::vector<int64_t> merge_rows(merge_row_map.size());
+    size_t idx = 0;
+    int64_t input_width = input.value().dims()[1];
     out.set_height(input.height());
-    out.mutable_value()->mutable_data<float>(
+
+    auto* out_data = out.mutable_value()->mutable_data<float>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
-
-    math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
-
-    auto* out_data = out.mutable_value()->data<float>();
-    auto* input_data = input.value().data<float>();
+    auto* in_data = input.value().data<float>();
 
     auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
-    for (size_t i = 0; i < input_rows.size(); i++) {
-      size_t out_i = rows_pos_map[input_rows[i]];
-      float* y = out_data + out_i * input_width;
-      const float* x = input_data + i * input_width;
-      blas.AXPY(input_width, 1., x, y);
+    for (auto& row_pair : merge_row_map) {
+      auto* out_ptr = out_data + idx * input_width;
+      auto& rows = row_pair.second;
+      merge_rows[idx] = row_pair.first;
+      ++idx;
+      // rows.size() is always larger than 0
+      blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
+
+      for (size_t i = 1; i < rows.size(); ++i) {
+        blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
+      }
     }
+
+    out.set_rows(merge_rows);
   }
 };
@@ -148,41 +147,39 @@ struct MergeAdd<platform::CPUDeviceContext, double> {
                   const framework::SelectedRows& input,
                   framework::SelectedRows* output) {
     framework::SelectedRows& out = *output;
-    auto input_rows = input.rows();
-    std::vector<int64_t> merge_rows;
-    merge_rows.reserve(input_rows.size());
-    std::unordered_map<int64_t, size_t> rows_pos_map;
-    rows_pos_map.reserve(input_rows.size());
-    size_t idx = 0u;
-    for (std::vector<int64_t>::iterator iter = input_rows.begin();
-         iter != input_rows.end(); ++iter) {
-      if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
-        rows_pos_map[*iter] = idx++;
-        merge_rows.emplace_back(*iter);
-      }
+    std::vector<int64_t> input_rows(input.rows());
+
+    std::map<int64_t, std::vector<size_t>> merge_row_map;
+    for (size_t i = 0; i < input_rows.size(); ++i) {
+      merge_row_map[input_rows[i]].push_back(i);
     }
-    auto input_width = input.value().dims()[1];
-    out.set_rows(merge_rows);
+
+    std::vector<int64_t> merge_rows(merge_row_map.size());
+    size_t idx = 0;
+    int64_t input_width = input.value().dims()[1];
     out.set_height(input.height());
-    out.mutable_value()->mutable_data<double>(
+
+    auto* out_data = out.mutable_value()->mutable_data<double>(
        framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
-
-    math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
-
-    auto* out_data = out.mutable_value()->data<double>();
-    auto* input_data = input.value().data<double>();
+    auto* in_data = input.value().data<double>();
 
     auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
-    for (size_t i = 0; i < input_rows.size(); i++) {
-      size_t out_i = rows_pos_map[input_rows[i]];
-      double* y = out_data + out_i * input_width;
-      const double* x = input_data + i * input_width;
-      blas.AXPY(input_width, 1., x, y);
+    for (auto& row_pair : merge_row_map) {
+      auto* out_ptr = out_data + idx * input_width;
+      auto& rows = row_pair.second;
+      merge_rows[idx] = row_pair.first;
+      ++idx;
+      // rows.size() is always larger than 0
+      blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
+
+      for (size_t i = 1; i < rows.size(); ++i) {
+        blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
+      }
     }
+
+    out.set_rows(merge_rows);
   }
 };
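The float/double specializations follow the same map-based grouping but delegate the arithmetic to BLAS: the first duplicate of each row is written with `VCOPY`, replacing the removed `SetConstant` zero-fill, and the remaining duplicates are accumulated with `AXPY`. In loop form the two calls compute roughly the following (scalar stand-ins, not Paddle's actual `Blas` implementation):

```cpp
#include <cstdint>

// y[i] = x[i]  for i in [0, n)  -- what blas.VCOPY(n, x, y) computes.
void vcopy(int64_t n, const float* x, float* y) {
  for (int64_t i = 0; i < n; ++i) y[i] = x[i];
}

// y[i] += alpha * x[i]  for i in [0, n)  -- what blas.AXPY(n, alpha, x, y)
// computes.
void axpy(int64_t n, float alpha, const float* x, float* y) {
  for (int64_t i = 0; i < n; ++i) y[i] += alpha * x[i];
}
```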
diff --git a/paddle/fluid/operators/rmsprop_op.h b/paddle/fluid/operators/rmsprop_op.h
index a04d1bd2ca5128714d10155a9679d921519cfe07..797cd45fdcdbd5c3567d1676f37e148304ee6e2d 100644
--- a/paddle/fluid/operators/rmsprop_op.h
+++ b/paddle/fluid/operators/rmsprop_op.h
@@ -13,72 +13,254 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <math.h>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/algorithm.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
+template <typename T>
+struct DenseRmspropGradFunctor {
+  inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}
+
+  HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; }
+
+  const T *grad_;
+};
+
+template <typename T>
+struct SparseRmspropGradFunctor {
+  inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows,
+                                  int64_t row_numel, int64_t row_count)
+      : grad_(grad),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  HOSTDEVICE inline T operator()(int64_t idx) const {
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_count_, idx / row_numel_);
+    return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0;
+  }
+
+  const T *grad_;
+  const int64_t *rows_;
+  int64_t row_numel_;
+  int64_t row_count_;
+};
+
+template <typename T, typename GradFunctor>
+struct UncenteredRmspropFunctor {
+  UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho,
+                           T epsilon, T momentum,
+                           const GradFunctor &grad_functor)
+      : param_(param),
+        ms_(ms),
+        mom_(mom),
+        lr_(lr),
+        rho_(rho),
+        epsilon_(epsilon),
+        momentum_(momentum),
+        grad_functor_(grad_functor) {}
+
+  HOSTDEVICE inline void operator()(int64_t idx) const {
+    T g = grad_functor_(idx);
+    T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
+    T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_);
+    param_[idx] -= mom_out;
+    ms_[idx] = ms_out;
+    mom_[idx] = mom_out;
+  }
+
+  T *param_;
+  T *ms_;
+  T *mom_;
+  const T *lr_;
+  T rho_;
+  T epsilon_;
+  T momentum_;
+  GradFunctor grad_functor_;
+};
+
+template <typename T, typename GradFunctor>
+struct CenteredRmspropFunctor {
+  CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr,
+                         T rho, T epsilon, T momentum,
+                         const GradFunctor &grad_functor)
+      : param_(param),
+        ms_(ms),
+        mom_(mom),
+        mean_grad_(mean_grad),
+        lr_(lr),
+        rho_(rho),
+        epsilon_(epsilon),
+        momentum_(momentum),
+        grad_functor_(grad_functor) {}
+
+  HOSTDEVICE inline void operator()(int64_t idx) const {
+    T g = grad_functor_(idx);
+    T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
+    T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
+    T mom_out = momentum_ * mom_[idx] +
+                lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
+    param_[idx] -= mom_out;
+    ms_[idx] = ms_out;
+    mom_[idx] = mom_out;
+    mean_grad_[idx] = mg_out;
+  }
+
+  T *param_;
+  T *ms_;
+  T *mom_;
+  T *mean_grad_;
+  const T *lr_;
+  T rho_;
+  T epsilon_;
+  T momentum_;
+  GradFunctor grad_functor_;
+};
+
 template <typename DeviceContext, typename T>
 class RmspropOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto* param_var = ctx.InputVar("Param");
-    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
-
-    auto* param_out = ctx.Output<Tensor>("ParamOut");
-    auto* moment_out = ctx.Output<Tensor>("MomentOut");
-    auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
-
-    auto grad = ctx.Input<Tensor>("Grad");
-
-    param_out->mutable_data<T>(ctx.GetPlace());
-    moment_out->mutable_data<T>(ctx.GetPlace());
-    mean_square_out->mutable_data<T>(ctx.GetPlace());
-
-    float epsilon = ctx.Attr<float>("epsilon");
-    float rho = ctx.Attr<float>("decay");
-    float momentum = ctx.Attr<float>("momentum");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    using LoDTensor = framework::LoDTensor;
+    auto *grad_var = ctx.InputVar("Grad");
+    auto *param_out = ctx.Output<LoDTensor>("ParamOut");
+    auto *moment_out = ctx.Output<LoDTensor>("MomentOut");
+    auto *mean_square_out = ctx.Output<LoDTensor>("MeanSquareOut");
+
+    auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+    auto rho = static_cast<T>(ctx.Attr<float>("decay"));
+    auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
     bool centered = ctx.Attr<bool>("centered");
 
-    auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
-    auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
-    auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
-    auto g = EigenVector<T>::Flatten(*grad);
-    auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
-
-    auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto mom_out = EigenVector<T>::Flatten(*moment_out);
-    auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
-    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-
-    Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
-
-    ms_out.device(place) = rho * ms + (1 - rho) * g * g;
-    if (centered) {
-      auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
-      auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
-      mean_grad_out->mutable_data<T>(ctx.GetPlace());
-      auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
-
-      mg_out.device(place) = rho * mg + (1 - rho) * g;
-      mom_out.device(place) = momentum * mom +
-                              lr.broadcast(grad_dsize) * g /
-                                  (ms_out - mg_out.square() + epsilon).sqrt();
+    auto &p_tensor = *ctx.Input<LoDTensor>("Param");
+    auto &ms_tensor = *ctx.Input<LoDTensor>("MeanSquare");
+    auto &lr_tensor = *ctx.Input<LoDTensor>("LearningRate");
+    auto &mom_tensor = *ctx.Input<LoDTensor>("Moment");
+
+    PADDLE_ENFORCE_EQ(&p_tensor, param_out,
+                      "Param and ParamOut must be the same Tensor");
+    PADDLE_ENFORCE_EQ(&mom_tensor, moment_out,
+                      "Moment and MomentOut must be the same Tensor");
+    PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out,
+                      "MeanSquare and MeanSquareOut must be the same Tensor");
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    size_t limit = static_cast<size_t>(ms_tensor.numel());
+
+    if (grad_var->IsType<LoDTensor>()) {
+      auto &grad_tensor = grad_var->Get<LoDTensor>();
+
+      if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
+        auto &place =
+            *ctx.template device_context<DeviceContext>().eigen_device();
+        auto lr_value = lr_tensor.data<T>()[0];
+
+        auto p = EigenVector<T>::Flatten(p_tensor);
+        auto ms = EigenVector<T>::Flatten(ms_tensor);
+        auto g = EigenVector<T>::Flatten(grad_tensor);
+        auto mom = EigenVector<T>::Flatten(mom_tensor);
+
+        auto p_out = EigenVector<T>::Flatten(*param_out);
+        auto mom_out = EigenVector<T>::Flatten(*moment_out);
+        auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
+
+        ms_out.device(place) = rho * ms + (1 - rho) * g * g;
+        if (centered) {
+          auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
+          auto mg = EigenVector<T>::Flatten(mg_tensor);
+          auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
+          PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
+                            "MeanGrad and MeanGradOut must be the same Tensor");
+          auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
+
+          mg_out.device(place) = rho * mg + (1 - rho) * g;
+          mom_out.device(place) =
+              momentum * mom +
+              lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
+        } else {
+          mom_out.device(place) =
+              momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
+        }
+        p_out.device(place) = p - mom_out;
+      } else {
+        DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
+        platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+        if (centered) {
+          auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
+          auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
+          PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
+                            "MeanGrad and MeanGradOut must be the same Tensor");
+          for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
+              param_out->mutable_data<T>(ctx.GetPlace()),
+              mean_square_out->mutable_data<T>(ctx.GetPlace()),
+              moment_out->mutable_data<T>(ctx.GetPlace()),
+              mean_grad_out->mutable_data<T>(ctx.GetPlace()),
+              lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
+        } else {
+          for_range(UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
+              param_out->mutable_data<T>(ctx.GetPlace()),
+              mean_square_out->mutable_data<T>(ctx.GetPlace()),
+              moment_out->mutable_data<T>(ctx.GetPlace()),
+              lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
+        }
+      }
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      auto &grad = grad_var->Get<framework::SelectedRows>();
+      auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
+                              .Var()
+                              ->GetMutable<framework::SelectedRows>();
+
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(dev_ctx, grad, merged_grad);
+
+      platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+      const int64_t *rows;
+#ifdef PADDLE_WITH_CUDA
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        rows = merged_grad->rows().CUDAData(ctx.GetPlace());
+      } else {
+#endif
+        rows = merged_grad->rows().data();
+#ifdef PADDLE_WITH_CUDA
+      }
+#endif
+      auto &merged_tensor = merged_grad->value();
+      int64_t row_count = merged_grad->rows().size();
+      int64_t row_numel = merged_tensor.numel() / row_count;
+      SparseRmspropGradFunctor<T> grad_func(merged_tensor.data<T>(), rows,
+                                            row_numel, row_count);
+
+      if (centered) {
+        auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
+        auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
+        PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
+                          "MeanGrad and MeanGradOut must be the same Tensor");
+        for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
+            param_out->mutable_data<T>(ctx.GetPlace()),
+            mean_square_out->mutable_data<T>(ctx.GetPlace()),
+            moment_out->mutable_data<T>(ctx.GetPlace()),
+            mean_grad_out->mutable_data<T>(ctx.GetPlace()),
+            lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
+      } else {
+        for_range(UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
+            param_out->mutable_data<T>(ctx.GetPlace()),
+            mean_square_out->mutable_data<T>(ctx.GetPlace()),
+            moment_out->mutable_data<T>(ctx.GetPlace()),
+            lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
+      }
     } else {
-      mom_out.device(place) =
-          momentum * mom +
-          lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
+      PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient");
     }
-    p_out.device(place) = p - mom_out;
   }
 };
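Both the Eigen path and the two functors above implement the same update. Written out, with decay rho, base learning rate alpha (`lr[0]`), momentum factor mu, and g_t the dense or merged-sparse gradient:

```latex
\begin{aligned}
ms_t  &= \rho\, ms_{t-1} + (1-\rho)\, g_t^2 && \text{(both variants)} \\
mg_t  &= \rho\, mg_{t-1} + (1-\rho)\, g_t   && \text{(centered only)} \\
mom_t &= \mu\, mom_{t-1} + \frac{\alpha\, g_t}{\sqrt{ms_t + \epsilon}} && \text{(uncentered)} \\
mom_t &= \mu\, mom_{t-1} + \frac{\alpha\, g_t}{\sqrt{ms_t - mg_t^2 + \epsilon}} && \text{(centered)} \\
p_t   &= p_{t-1} - mom_t
\end{aligned}
```

The functors update `param_`, `ms_`, `mom_` (and `mean_grad_`) in place, which is why the kernel now insists that each input/output pair such as `Param`/`ParamOut` refers to the same tensor.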
diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py
index 70848e4e2239e2be160bb0c1a28a5aecd01a87dc..eb12bc741767340a3e7e3580a8b95065d4267693 100644
--- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py
+++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py
@@ -19,33 +19,76 @@ import unittest
 import numpy as np
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
+import paddle.fluid as fluid
+
+
+def create_selected_rows_and_tensor(scope, place, height, row_num,
+                                    embedding_size):
+    sr = scope.var("@selected_rows@").get_selected_rows()
+    tensor = scope.var("grad").get_tensor()
+
+    rows = np.random.random_integers(
+        low=0, high=height - 1, size=[row_num, ]).astype('int64')
+    sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32')
+
+    sr.set_height(height)
+    sr.set_rows(rows)
+    sr.get_tensor().set(sr_val, place)
+
+    tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
+    for i in range(row_num):
+        row = rows[i]
+        tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :]
+
+    tensor.set(tensor_val, place)
+    return tensor_val, sr_val
 
 
 class TestBase(unittest.TestCase):
-    def setup(self, centered, epsilon=1e-6):
+    def setup(self,
+              place,
+              is_sparse,
+              centered,
+              size,
+              row_num=None,
+              epsilon=1e-6):
         np.random.seed(5)  # fix seed
 
+        self.scope = fluid.global_scope()
+        self.place = place
+
         self.param_name = "param"
-        self.param = np.random.random((123, 321)).astype("float32")
+        self.param = np.random.random(size).astype("float32")
 
         self.mean_square_name = "mean_square"
-        self.mean_square = np.random.random((123, 321)).astype("float32")
+        self.mean_square = np.random.uniform(
+            low=1, high=2, size=size).astype("float32")
+
         self.mean_grad_name = "mean_grad"
-        self.mean_grad = np.random.random((123, 321)).astype("float32")
+        self.mean_grad = np.random.random(size).astype("float32")
 
         self.lr_name = "lr"
         self.learning_rate = np.array([0.01]).astype("float32")
 
         self.grad_name = "grad"
-        self.grad = np.random.random((123, 321)).astype("float32")
+
+        self.is_sparse = is_sparse
+        if self.is_sparse:
+            self.grad_sr_name = "@selected_rows@"
+            self.grad, self.grad_sr = create_selected_rows_and_tensor(
+                self.scope, place, size[0], row_num, size[1])
+        else:
+            self.grad = np.random.random(size).astype("float32")
+            grad_tensor = self.scope.var(self.grad_name).get_tensor()
+            grad_tensor.set(self.grad, place)
 
         self.moment_name = "moment"
-        self.moment = np.zeros((123, 321)).astype("float32")
+        self.moment = np.random.uniform(
+            low=0, high=1, size=size).astype("float32")
 
         self.epsilon = epsilon
         self.decay = 0.9
-        self.momentum = 0.0
+        self.momentum = 0.1
         self.centered = centered
 
         self.ms_out = self.decay * self.mean_square + (1 - self.decay
@@ -61,118 +104,122 @@ class TestBase(unittest.TestCase):
 
         self.param_out = self.param - self.moment_out
 
-    def check(self,
-              actual_t,
-              expect_t,
-              place,
-              out_name,
-              atol=1e-5,
-              equal_nan=False):
-        self.assertTrue(
-            np.allclose(
-                actual_t, expect_t, atol=atol, equal_nan=equal_nan),
-            "Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
-            + str(expect_t) + "\n" + "But Got" + str(actual_t))
-
-
-class TestRmspropOp(TestBase):
-    def check_with_place(self, place, centered, epsilon):
-        self.setup(centered, epsilon)
-        scope = core.Scope()
-
-        # create and initialize Param Variable
-        param = scope.var(self.param_name).get_tensor()
-        param.set(self.param, place)
+        self.param_tensor = self.scope.var(self.param_name).get_tensor()
+        self.param_tensor.set(self.param, place)
 
-        mean_square = scope.var(self.mean_square_name).get_tensor()
-        mean_square.set(self.mean_square, place)
+        self.mean_square_tensor = self.scope.var(
+            self.mean_square_name).get_tensor()
+        self.mean_square_tensor.set(self.mean_square, place)
 
-        lr = scope.var(self.lr_name).get_tensor()
+        lr = self.scope.var(self.lr_name).get_tensor()
         lr.set(self.learning_rate, place)
 
-        grad = scope.var(self.grad_name).get_tensor()
-        grad.set(self.grad, place)
+        self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
+        self.moment_tensor.set(self.moment, place)
 
-        moment = scope.var(self.moment_name).get_tensor()
-        moment.set(self.moment, place)
+        if self.centered:
+            self.mean_grad_tensor = self.scope.var(
+                self.mean_grad_name).get_tensor()
+            self.mean_grad_tensor.set(self.mean_grad, place)
 
-        # create and run sgd operator
+    def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
+        self.assertTrue(
+            np.allclose(
+                actual_t, expect_t, atol=atol),
+            "Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+            + str(expect_t) + "\n" + "But Got" + str(actual_t))
 
-        if self.centered:
-            mean_grad = scope.var(self.mean_grad_name).get_tensor()
-            mean_grad.set(self.mean_grad, place)
-
-            rmsprop_op = Operator(
-                "rmsprop",
-                Param=self.param_name,
-                Grad=self.grad_name,
-                MeanSquare=self.mean_square_name,
-                MeanGrad=self.mean_grad_name,
-                Moment=self.moment_name,
-                LearningRate=self.lr_name,
-                ParamOut=self.param_name,
-                MeanSquareOut=self.mean_square_name,
-                MomentOut=self.moment_name,
-                MeanGradOut=self.mean_grad_name,
-                epsilon=self.epsilon,
-                decay=self.decay,
-                momentum=self.momentum,
-                centered=True)
-        else:
-            rmsprop_op = Operator(
-                "rmsprop",
-                Param=self.param_name,
-                Grad=self.grad_name,
-                MeanSquare=self.mean_square_name,
-                Moment=self.moment_name,
-                LearningRate=self.lr_name,
-                ParamOut=self.param_name,
-                MeanSquareOut=self.mean_square_name,
-                MomentOut=self.moment_name,
-                epsilon=self.epsilon,
-                decay=self.decay,
-                momentum=self.momentum,
-                centered=False)
-
-        rmsprop_op.run(scope, place)
-
-        atol = 1e-5
-        equal_nan = False
+
+class TestRmspropOp(TestBase):
+    def check_with_place(self,
+                         place,
+                         is_sparse,
+                         centered,
+                         size,
+                         row_num=None,
+                         epsilon=1e-6):
+        self.setup(place, is_sparse, centered, size, row_num, epsilon)
+        self.run_and_check()
+
+    def run_and_check(self):
+        grad_name = self.grad_sr_name if self.is_sparse else self.grad_name
+
+        kwargs = {
+            'Param': self.param_name,
+            'Grad': grad_name,
+            'MeanSquare': self.mean_square_name,
+            'Moment': self.moment_name,
+            'LearningRate': self.lr_name,
+            'ParamOut': self.param_name,
+            'MeanSquareOut': self.mean_square_name,
+            'MomentOut': self.moment_name,
+            'epsilon': self.epsilon,
+            'decay': self.decay,
+            'momentum': self.momentum,
+            'centered': self.centered
+        }
 
         if self.centered:
-            atol = 1e-3
-            equal_nan = True
+            kwargs['MeanGrad'] = self.mean_grad_name
+            kwargs['MeanGradOut'] = self.mean_grad_name
+
+        rmsprop_op = Operator('rmsprop', **kwargs)
+        atol = 1e-6
+
+        rmsprop_op.run(self.scope, self.place)
 
         self.check(
-            np.array(mean_square), self.ms_out, place, self.mean_square_name)
+            np.array(self.mean_square_tensor),
+            self.ms_out,
+            self.place,
+            self.mean_square_name,
+            atol=atol)
 
         self.check(
-            np.array(moment),
+            np.array(self.moment_tensor),
             self.moment_out,
-            place,
+            self.place,
             self.moment_name,
-            atol=atol,
-            equal_nan=equal_nan)
+            atol=atol)
 
         self.check(
-            np.array(param),
+            np.array(self.param_tensor),
             self.param_out,
-            place,
+            self.place,
             self.param_name,
-            atol=atol,
-            equal_nan=equal_nan)
+            atol=atol)
 
         if self.centered:
             self.check(
-                np.array(mean_grad), self.mg_out, place, self.mean_grad_name)
+                np.array(self.mean_grad_tensor), self.mg_out, self.place,
+                self.mean_grad_name)
 
     def test_rmsprop(self):
         places = [core.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(core.CUDAPlace(0))
+
+        size = (128, 320)
         for place in places:
-            self.check_with_place(place, False, 1e-6)
-            self.check_with_place(place, False, 1e-10)
-            self.check_with_place(place, True, 1e-6)
-            self.check_with_place(place, True, 1e-10)
+            for centered in [False, True]:
+                with fluid.scope_guard(core.Scope()):
+                    self.check_with_place(
+                        place, is_sparse=False, centered=centered, size=size)
+
+                with fluid.scope_guard(core.Scope()):
+                    self.check_with_place(
+                        place,
+                        is_sparse=True,
+                        centered=centered,
+                        row_num=512,
+                        size=size)
+
+                with fluid.scope_guard(core.Scope()):
+                    self.check_with_place(
+                        place,
+                        is_sparse=True,
+                        centered=centered,
+                        row_num=60,
+                        size=size)
 
 
 if __name__ == "__main__":
     unittest.main()