diff --git a/paddle/fluid/operators/optimizers/sparse_momentum_op.cc b/paddle/fluid/operators/optimizers/sparse_momentum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c38545df173115112afbad58d941112fee61f40f
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/sparse_momentum_op.cc
@@ -0,0 +1,122 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/optimizers/sparse_momentum_op.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class SparseMomentumOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto in_var_type = ctx->GetInputType("Param");
+    PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::LOD_TENSOR,
+                      true,
+                      platform::errors::InvalidArgument(
+                          "Only LoDTensor is supported as Input(Param); "
+                          "received an unexpected input type."));
+
+    ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
+  }
+};
+
+void SparseMomentumOpMaker::Make() {
+  AddInput("Param",
+           "(Tensor, default Tensor<float>) "
+           "Input parameter that has to be updated");
+  AddInput("Grad",
+           "(Tensor, default Tensor<float>) "
+           "Input gradient of the parameter");
+  AddInput("Velocity",
+           "(Tensor, default Tensor<float>) "
+           "Input velocity (corresponding to the parameter) "
+           "that has to be updated");
+  AddInput("Index",
+           "(Tensor, default Tensor<int>) "
+           "Input index of Param to do the update operation");
+  AddInput("Axis",
+           "(Tensor) The axis along which the update operation is performed. "
+           "When provided, it overrides the `axis` attribute.")
+      .AsDispensable();
+  AddInput("LearningRate",
+           "(Tensor, default Tensor<float>) "
+           "Input learning rate");
+  AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
+  AddOutput("ParamOut",
+            "(Tensor) This output is the updated parameter. "
+            "It shares memory with Input(Param).");
+  AddOutput("VelocityOut",
+            "(Tensor) This output is the updated velocity. "
+            "It shares memory with Input(Velocity).");
+  AddOutput("MasterParamOut",
+            "The updated FP32 master weight for AMP. "
+            "It shares memory with Input(MasterParam).")
+      .AsDispensable();
+
+  AddAttr<float>("mu", "(float) Momentum coefficient");
+  AddAttr<bool>("use_nesterov",
+                "(bool, default false) "
+                "Use Nesterov Momentum")
+      .SetDefault(false);
+  AddAttr<std::string>(
+      "regularization_method",
+      "(string) regularization method; currently only l2decay or none is "
+      "supported")
+      .SetDefault("");
+  AddAttr<float>("regularization_coeff", "(float) regularization coefficient")
+      .SetDefault(0.0f);
+  AddAttr<bool>("multi_precision",
+                "(bool, default false) "
+                "Whether to use multi-precision during weight updating.")
+      .SetDefault(false);
+  AddAttr<float>(
+      "rescale_grad",
+      "(float, default 1.0) Multiply the gradient with `rescale_grad` "
+      "before updating. Often chosen to be `1.0/batch_size`.")
+      .SetDefault(1.0f);
+  AddAttr<int>("axis",
+               "(int, default 0) The integer that specifies the axis along "
+               "which the update operation is performed.")
+      .SetDefault(0);
+
+  AddComment(R"DOC(
+Sparse Momentum Optimizer.
+
+This optimizer has a flag for Nesterov Momentum.
+The update equations are as follows:
+
+$$
+velocity = mu * velocity + gradient \\
+if (use\_nesterov): \\
+  param = param - (gradient + mu * velocity) * learning\_rate \\
+else: \\
+  param = param - learning\_rate * velocity. \\
+$$
+
+)DOC");
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    sparse_momentum, ops::SparseMomentumOp, ops::SparseMomentumOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::SparseMomentumOpInferVarType);
+REGISTER_OP_CPU_KERNEL(
+    sparse_momentum,
+    ops::SparseMomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SparseMomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
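For readers skimming the diff, the DOC equations above condense to the following NumPy sketch. Note that `velocity` on the Nesterov line denotes the freshly updated velocity, which is also what the kernel in sparse_momentum_op.h computes; all names here are illustrative, not part of the patch.

```python
import numpy as np

def momentum_step(param, grad, velocity, mu, lr, use_nesterov=False):
    """Dense momentum update matching the op's documented equations."""
    velocity_out = mu * velocity + grad
    if use_nesterov:
        # Nesterov uses the *updated* velocity, as the kernel does.
        param_out = param - (grad + mu * velocity_out) * lr
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out
```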
diff --git a/paddle/fluid/operators/optimizers/sparse_momentum_op.cu b/paddle/fluid/operators/optimizers/sparse_momentum_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..64dcbc88a1e14a7bef8cc8ccec08abcd4a844f1a
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/sparse_momentum_op.cu
@@ -0,0 +1,25 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/optimizers/sparse_momentum_op.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sparse_momentum,
+    ops::SparseMomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SparseMomentumOpKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SparseMomentumOpKernel<paddle::platform::CUDADeviceContext,
+                                paddle::platform::float16>);
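The float16 kernel registered above relies on the `MasterParam`/`multi_precision` path declared in the header: with a plain FP16 parameter, an update smaller than the FP16 spacing is rounded away entirely, so the op keeps and updates an FP32 master copy. A minimal NumPy illustration of the rationale (not Paddle API, just arithmetic):

```python
import numpy as np

# In fp16, an update smaller than the gap to the next representable
# value is lost completely:
p16 = np.float16(1.0)
print(p16 - np.float16(1e-4))   # -> 1.0, the step vanished

# An fp32 master copy accumulates such steps and is cast back afterwards:
master = np.float32(1.0)
for _ in range(1000):
    master -= np.float32(1e-4)  # each step survives in fp32
print(np.float16(master))       # -> ~0.9, the accumulated update is kept
```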
diff --git a/paddle/fluid/operators/optimizers/sparse_momentum_op.h b/paddle/fluid/operators/optimizers/sparse_momentum_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f1516320ec5733a4ce3b1f7ccabea3409fc05bb7
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/sparse_momentum_op.h
@@ -0,0 +1,492 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/math/algorithm.h"
+#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/for_range.h"
+
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+template <typename T>
+using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+
+enum class RegularizationType {
+  kNONE = 0,
+  kL1DECAY = 1,  // no need to support this right now
+  kL2DECAY = 2,
+};
+
+template <typename T>
+struct NoNesterov {
+  HOSTDEVICE inline T operator()(const T& grad, const T& velocity,
+                                 const T& mu) const {
+    return velocity;
+  }
+};
+
+template <typename T>
+struct UseNesterov {
+  HOSTDEVICE inline T operator()(const T& grad, const T& velocity,
+                                 const T& mu) const {
+    return grad + velocity * mu;
+  }
+};
+
+// The following code is adapted from
+// https://en.cppreference.com/w/cpp/algorithm/lower_bound
+// https://en.cppreference.com/w/cpp/algorithm/upper_bound
+// It computes the index range of `value` in the sorted array `x`; when
+// `value` is absent the resulting range is empty (bounds of -1, or
+// lower > upper), which callers simply skip.
+template <typename T>
+HOSTDEVICE inline void BinarySearchLowerUpperBound(const T* x, int64_t num,
+                                                   const T& value,
+                                                   int64_t* lower_bound,
+                                                   int64_t* upper_bound) {
+  *lower_bound = -1;
+  *upper_bound = -1;
+
+  auto* first = x;
+  int64_t count = static_cast<int64_t>(num);
+  while (count > 0) {
+    int64_t step = (count >> 1);
+    auto* it = first + step;
+    if (*it < value) {
+      first = ++it;
+      count -= (step + 1);
+    } else {
+      count = step;
+    }
+  }
+  auto idx = static_cast<int64_t>(first - x);
+  if ((idx > 0 && idx < num) || (idx == 0 && x[idx] == value)) {
+    *lower_bound = idx;
+  }
+
+  if (*lower_bound >= 0) {
+    first = x + idx;
+    count = static_cast<int64_t>(num - idx);
+    while (count > 0) {
+      auto step = (count >> 1);
+      auto* it = first + step;
+      if (value < *it) {
+        count = step;
+      } else {
+        first = ++it;
+        count -= (step + 1);
+      }
+    }
+    auto upper_idx = static_cast<int64_t>(first - x) - 1;
+    if ((upper_idx >= 0 && upper_idx < num - 1) ||
+        (upper_idx == num - 1 && x[upper_idx] == value)) {
+      *upper_bound = upper_idx;
+    }
+  }
+  return;
+}
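The device-side search above mirrors the standard library's lower/upper bound pair. The same contract in plain Python, for reference (an illustrative helper, not part of the patch; it normalizes all empty ranges to (-1, -1), whereas the kernel only needs empty ranges to be skippable):

```python
import bisect

def lower_upper_bound(xs, value):
    """Index range of `value` in sorted list `xs`, or (-1, -1) if absent."""
    lo = bisect.bisect_left(xs, value)       # first index with xs[i] >= value
    hi = bisect.bisect_right(xs, value) - 1  # last index with xs[i] == value
    if lo > hi:                              # value absent -> empty range
        return -1, -1
    return lo, hi

print(lower_upper_bound([0, 2, 2, 2, 5], 2))  # (1, 3)
print(lower_upper_bound([0, 2, 2, 2, 5], 3))  # (-1, -1)
```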
+
+template <typename T>
+class RangeFunctor {
+ private:
+  T* value_;
+
+ public:
+  explicit RangeFunctor(T* value) : value_(value) {}
+  inline HOSTDEVICE void operator()(size_t i) { value_[i] = static_cast<T>(i); }
+};
+
+class SparseMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+class SparseMomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "SparseMomentum");
+    OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "SparseMomentum");
+    OP_INOUT_CHECK(ctx->HasInput("Velocity"), "Input", "Velocity",
+                   "SparseMomentum");
+    OP_INOUT_CHECK(ctx->HasInput("Index"), "Input", "Index", "SparseMomentum");
+    OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
+                   "SparseMomentum");
+    OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut",
+                   "SparseMomentum");
+    OP_INOUT_CHECK(ctx->HasOutput("VelocityOut"), "Output", "VelocityOut",
+                   "SparseMomentum");
+
+    auto lr_dims = framework::product(ctx->GetInputDim("LearningRate"));
+    PADDLE_ENFORCE_EQ(lr_dims != 0 && lr_dims == 1, true,
+                      platform::errors::InvalidArgument(
+                          "LearningRate should be a scalar, but received "
+                          "LearningRate's dim [%s].",
+                          lr_dims));
+
+    auto param_dim = ctx->GetInputDim("Param");
+    PADDLE_ENFORCE_EQ(
+        param_dim, ctx->GetInputDim("Velocity"),
+        platform::errors::InvalidArgument(
+            "Param and Velocity of SparseMomentumOp should have the same "
+            "dimension. But received Param's dim [%s] and Velocity's [%s].",
+            param_dim, ctx->GetInputDim("Velocity")));
+
+    ctx->SetOutputDim("ParamOut", param_dim);
+    ctx->SetOutputDim("VelocityOut", param_dim);
+    if (ctx->HasOutput("MasterParamOut")) {
+      ctx->SetOutputDim("MasterParamOut", param_dim);
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
+template <typename T, typename MT, typename IndexT, typename UpdateMethod>
+class IndexMomentumFunctor {
+ private:
+  const T* param_;
+  const T* grad_;
+  const MT* velocity_;
+  const MultiPrecisionType<MT>* lr_;
+  const MT* master_param_;
+  const MT mu_;
+  const MT rescale_grad_;
+  const IndexT* sorted_index_;
+  const IndexT* grad_index_;
+  const int64_t num_index_;
+  const int axis_;
+  const int64_t param_row_numel_;
+  const int64_t grad_row_numel_;
+  T* param_out_;
+  MT* velocity_out_;
+  MT* master_param_out_;
+  const RegularizationType regularization_flag_;
+  const MT regularization_coeff_;
+  const UpdateMethod& update_method_;
+
+ public:
+  IndexMomentumFunctor(const T* param, const T* grad, const MT* velocity,
+                       const MultiPrecisionType<MT>* lr,
+                       const MT* master_param, const MT mu,
+                       const MT rescale_grad, const IndexT* sorted_index,
+                       const IndexT* grad_index, int64_t num_index, int axis,
+                       int64_t param_row_numel, int64_t grad_row_numel,
+                       const RegularizationType regularization_flag,
+                       const MT regularization_coeff,
+                       const UpdateMethod& update_method, T* param_out,
+                       MT* velocity_out, MT* master_param_out)
+      : param_(param),
+        grad_(grad),
+        velocity_(velocity),
+        lr_(lr),
+        master_param_(master_param),
+        mu_(mu),
+        rescale_grad_(rescale_grad),
+        sorted_index_(sorted_index),
+        grad_index_(grad_index),
+        num_index_(num_index),
+        axis_(axis),
+        param_row_numel_(param_row_numel),
+        grad_row_numel_(grad_row_numel),
+        param_out_(param_out),
+        velocity_out_(velocity_out),
+        master_param_out_(master_param_out),
+        regularization_flag_(regularization_flag),
+        regularization_coeff_(regularization_coeff),
+        update_method_(update_method) {}
+
+  inline HOSTDEVICE void operator()(size_t i) {
+    MT grad = static_cast<MT>(0);
+    size_t row = i / param_row_numel_;
+    size_t col = i % param_row_numel_;
+    if (axis_ == 0) {
+      int64_t row_idx0, row_idx1;
+      BinarySearchLowerUpperBound<IndexT>(sorted_index_, num_index_, row,
+                                          &row_idx0, &row_idx1);
+      if (row_idx0 >= 0 && row_idx1 >= 0) {
+        for (int64_t row_idx = row_idx0; row_idx <= row_idx1; row_idx++) {
+          size_t offset = grad_index_[row_idx] * param_row_numel_ + col;
+          grad += static_cast<MT>(grad_[offset]) * rescale_grad_;
+        }
+      }
+    } else if (axis_ == 1) {
+      int64_t col_idx0, col_idx1;
+      BinarySearchLowerUpperBound<IndexT>(sorted_index_, num_index_, col,
+                                          &col_idx0, &col_idx1);
+      if (col_idx0 >= 0 && col_idx1 >= 0) {
+        for (int64_t col_idx = col_idx0; col_idx <= col_idx1; col_idx++) {
+          size_t offset = row * grad_row_numel_ + grad_index_[col_idx];
+          grad += static_cast<MT>(grad_[offset]) * rescale_grad_;
+        }
+      }
+    }
+
+    // put memory accesses in registers
+    const MT param =
+        master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
+    const MT lr = static_cast<MT>(lr_[0]);
+    const MT velocity = velocity_[i];
+
+    grad = regularization_flag_ == RegularizationType::kL2DECAY
+               ? grad + regularization_coeff_ * param
+               : grad;
+
+    MT velocity_out = velocity * mu_ + grad;
+    MT velocity_tmp = update_method_(grad, velocity_out, mu_);
+    MT param_out = param - velocity_tmp * lr;
+    // write registers back to memory
+    velocity_out_[i] = velocity_out;
+    param_out_[i] = static_cast<T>(param_out);
+    if (master_param_out_) {
+      master_param_out_[i] = param_out;
+    }
+  }
+};
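Each invocation of `operator()` handles one element of `Param`: it sums every sparse-gradient row (axis=0) or column (axis=1) whose index matches that element, then applies the momentum step. The axis=0 gather, written naively in NumPy (the kernel replaces the linear scan with the binary search over `sorted_index_`; all names here are illustrative):

```python
import numpy as np

def gather_grad_axis0(sub_grad, index, row, col, rescale=1.0):
    """Sum of all sparse-gradient rows that update param[row, col]."""
    g = 0.0
    for k, idx in enumerate(index):  # the kernel binary-searches a sorted
        if idx == row:               # copy of `index` instead of scanning
            g += sub_grad[k, col] * rescale
    return g

index = np.array([3, 1, 3])          # duplicate indices are accumulated
sub_grad = np.ones((3, 4))
print(gather_grad_axis0(sub_grad, index, row=3, col=0))  # 2.0 (rows 0 and 2)
```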
+
+template <typename DeviceContext, typename T>
+class SparseMomentumOpKernel : public framework::OpKernel<T> {
+  using MPDType = MultiPrecisionType<T>;
+
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
+    auto index = ctx.Input<framework::Tensor>("Index");
+    const auto& index_type = index->type();
+    if (multi_precision) {
+      if (use_nesterov) {
+        auto update_method = UseNesterov<MPDType>();
+        if (index_type == framework::proto::VarType::INT32) {
+          InnerCompute<MPDType, int, UseNesterov<MPDType>>(
+              ctx, multi_precision, update_method);
+        } else {
+          InnerCompute<MPDType, int64_t, UseNesterov<MPDType>>(
+              ctx, multi_precision, update_method);
+        }
+      } else {
+        auto update_method = NoNesterov<MPDType>();
+        if (index_type == framework::proto::VarType::INT32) {
+          InnerCompute<MPDType, int, NoNesterov<MPDType>>(
+              ctx, multi_precision, update_method);
+        } else {
+          InnerCompute<MPDType, int64_t, NoNesterov<MPDType>>(
+              ctx, multi_precision, update_method);
+        }
+      }
+    } else {
+      if (use_nesterov) {
+        auto update_method = UseNesterov<T>();
+        if (index_type == framework::proto::VarType::INT32) {
+          InnerCompute<T, int, UseNesterov<T>>(ctx, multi_precision,
+                                               update_method);
+        } else {
+          InnerCompute<T, int64_t, UseNesterov<T>>(ctx, multi_precision,
+                                                   update_method);
+        }
+      } else {
+        auto update_method = NoNesterov<T>();
+        if (index_type == framework::proto::VarType::INT32) {
+          InnerCompute<T, int, NoNesterov<T>>(ctx, multi_precision,
+                                              update_method);
+        } else {
+          InnerCompute<T, int64_t, NoNesterov<T>>(ctx, multi_precision,
+                                                  update_method);
+        }
+      }
+    }
+  }
+
+ private:
+  template <typename MT, typename IndexT, typename UpdateMethod>
+  void InnerCompute(const framework::ExecutionContext& ctx,
+                    const bool multi_precision,
+                    const UpdateMethod& update_method) const {
+    std::string regularization_method =
+        ctx.Attr<std::string>("regularization_method");
+    MT regularization_coeff =
+        static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
+    RegularizationType regularization_flag{
+        RegularizationType::kNONE};  // disable regularization by default
+    if (regularization_method == "l2_decay") {
+      regularization_flag = RegularizationType::kL2DECAY;
+    }
+
+    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
+    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
+
+    int axis = ctx.Attr<int>("axis");
+    // the Axis input tensor, when given, overrides the axis attribute
+    if (ctx.HasInput("Axis")) {
+      Tensor cpu_axis;
+      const Tensor* axis_tensor = ctx.Input<Tensor>("Axis");
+      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
+      const auto& axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(cpu_axis.data<int>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
+      }
+    }
+    PADDLE_ENFORCE_EQ(
+        axis == 0 || axis == 1, true,
+        platform::errors::InvalidArgument("The axis of sparse_momentum_op "
+                                          "only supports axis=0 or axis=1 "
+                                          "now."));
+
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    auto index = ctx.Input<framework::Tensor>("Index");
+    int64_t num_index = index->numel();
+
+    // check that Index is 1-D (or 2-D with a trailing dimension of 1)
+    if (index->dims().size() == 1) {
+      PADDLE_ENFORCE_GT(
+          index->dims()[0], 0,
+          platform::errors::InvalidArgument(
+              "The index of sparse_momentum_op should not be empty "
+              "when the index's rank is 1."));
+    } else if (index->dims().size() == 2) {
+      PADDLE_ENFORCE_EQ(index->dims()[1], 1,
+                        platform::errors::InvalidArgument(
+                            "If the index's rank of sparse_momentum_op is 2,"
+                            " the second dimension should be 1."));
+    }
+
+    const framework::Tensor* master_param = nullptr;
+    framework::Tensor* master_param_out = nullptr;
+    if (multi_precision) {
+      bool has_master =
+          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
+      PADDLE_ENFORCE_EQ(has_master, true,
+                        platform::errors::InvalidArgument(
+                            "The Input(MasterParam) and "
+                            "Output(MasterParamOut) should not be null when "
+                            "the attr `multi_precision` is true"));
+      master_param = ctx.Input<framework::Tensor>("MasterParam");
+      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
+    }
+
+    param_out->mutable_data<T>(ctx.GetPlace());
+    velocity_out->mutable_data<MT>(ctx.GetPlace());
+    const MT* master_in_data =
+        multi_precision ? master_param->data<MT>() : nullptr;
+    MT* master_out_data =
+        multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
+                        : nullptr;
+
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+
+    platform::ForRange<DeviceContext> for_range(
+        static_cast<const DeviceContext&>(ctx.device_context()),
+        param->numel());
+
+    auto param_dims = param->dims();
+    auto grad_dims = grad->dims();
+
+    PADDLE_ENFORCE_EQ(param_dims.size(), 2,
+                      platform::errors::InvalidArgument(
+                          "The Param's rank of sparse_momentum_op"
+                          " must be 2 now."));
+    PADDLE_ENFORCE_EQ(grad_dims.size(), 2,
+                      platform::errors::InvalidArgument(
+                          "The Grad's rank of sparse_momentum_op"
+                          " must be 2 now."));
+
+    Tensor sorted_index, grad_index, sort_value;
+    auto sorted_index_ptr =
+        sorted_index.mutable_data<IndexT>({num_index}, ctx.GetPlace());
+    auto grad_index_ptr =
+        grad_index.mutable_data<IndexT>({num_index}, ctx.GetPlace());
+
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#if defined(__NVCC__) || defined(__HIPCC__)
+      auto sort_value_ptr =
+          sort_value.mutable_data<IndexT>({num_index}, ctx.GetPlace());
+
+      platform::ForRange<DeviceContext> for_range_index(
+          static_cast<const DeviceContext&>(ctx.device_context()), num_index);
+      RangeFunctor<IndexT> range_functor(sort_value_ptr);
+      for_range_index(range_functor);
+
+      size_t temp_storage_bytes = 0;
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          (cub::DeviceRadixSort::SortPairs<IndexT, IndexT>(
+              nullptr, temp_storage_bytes, nullptr, nullptr, nullptr, nullptr,
+              static_cast<int>(num_index))));
+      auto d_temp_storage = memory::Alloc(ctx.GetPlace(), temp_storage_bytes);
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          (cub::DeviceRadixSort::SortPairs<IndexT, IndexT>(
+              d_temp_storage->ptr(), temp_storage_bytes, index->data<IndexT>(),
+              sorted_index_ptr, sort_value_ptr, grad_index_ptr,
+              static_cast<int>(num_index), 0, sizeof(IndexT) * 8,
+              ctx.cuda_device_context().stream())));
+#endif
+    } else if (platform::is_cpu_place(ctx.GetPlace())) {
+      std::vector<std::pair<IndexT, IndexT>> vec_tosort;
+      auto index_ptr = index->data<IndexT>();
+      for (IndexT i = 0; i < num_index; i++) {
+        vec_tosort.push_back({index_ptr[i], i});
+      }
+      std::sort(vec_tosort.begin(), vec_tosort.end(),
+                [](const std::pair<IndexT, IndexT>& k1,
+                   const std::pair<IndexT, IndexT>& k2) {
+                  return k1.first < k2.first;
+                });
+      for (IndexT i = 0; i < num_index; i++) {
+        sorted_index_ptr[i] = vec_tosort[i].first;
+        grad_index_ptr[i] = vec_tosort[i].second;
+      }
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "sparse_momentum %s is not supported.", ctx.GetPlace()));
+    }
+
+    IndexMomentumFunctor<T, MT, IndexT, UpdateMethod> functor(
+        param->data<T>(), grad->data<T>(), velocity->data<MT>(),
+        learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
+        sorted_index_ptr, grad_index_ptr, num_index, axis, param_dims[1],
+        grad_dims[1], regularization_flag, regularization_coeff, update_method,
+        param_out->mutable_data<T>(ctx.GetPlace()),
+        velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
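Both sort paths above produce the same pair of arrays: `sorted_index` holds the gradient indices in ascending order, and `grad_index` maps each sorted entry back to its original row in `Grad` — in other words, an argsort. (CUB's radix sort is stable; `std::sort` is not, but equal keys only reorder gradient rows that are summed together anyway.) Roughly, in NumPy:

```python
import numpy as np

index = np.array([3, 1, 3, 0])       # which param rows each grad row touches
order = np.argsort(index, kind="stable")
sorted_index = index[order]          # [0, 1, 3, 3] -- searched by the
grad_index = order                   # [3, 1, 0, 2]    binary search above
print(sorted_index, grad_index)
```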
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index 4b610f3bccba0f1800a33fe4bf66de073c8749cc..a706cc49f5cf6edaab277d4be92d7bd5d7454bc7 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -64,6 +64,7 @@ std::map<std::string, std::vector<std::string>> op_ins_map = {
     {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
     {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
     {"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
+    {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
     {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
     {"run_program", {"X", "Params"}},
 };
@@ -97,6 +98,7 @@ std::map<std::string, std::vector<std::string>> op_outs_map = {
     {"multiclass_nms3", {"Out", "NmsRoisNum"}},
     {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
     {"momentum", {"ParamOut", "VelocityOut"}},
+    {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
     {"lamb",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
@@ -124,6 +126,7 @@ std::map<std::string, std::vector<std::string>> op_passing_outs_map = {
      {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
       "out_old_num_accumulates", "out_num_updates"}},
     {"momentum", {"ParamOut", "VelocityOut"}},
+    {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"batch_norm", {"MeanOut", "VarianceOut"}},
     {"sync_batch_norm", {"MeanOut", "VarianceOut"}},
     {"accuracy", {"Correct", "Total"}},
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py b/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e36cb72efc725347cf4d4d480e740768775e4052
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+
+def calculate_sparse_momentum_by_numpy(param,
+                                       grad,
+                                       mu,
+                                       velocity,
+                                       use_nesterov,
+                                       learning_rate,
+                                       index,
+                                       axis,
+                                       regularization_method=None,
+                                       regularization_coeff=1.0):
+    sub_grad = grad.copy()
+    grad = np.zeros_like(param)
+    if axis == 0:
+        unique_index = np.unique(index)
+        for idx in unique_index:
+            grad[idx, :] = np.sum(sub_grad[index == idx, :], axis=0)
+    else:
+        unique_index = np.unique(index)
+        for idx in unique_index:
+            grad[:, idx] = np.sum(sub_grad[:, index == idx], axis=1)
+    if regularization_method == "l2_decay":
+        grad = grad + regularization_coeff * param
+
+        velocity_out = mu * velocity + grad
+        if use_nesterov:
+            param_out = param - (grad + velocity_out * mu) * learning_rate
+        else:
+            param_out = param - learning_rate * velocity_out
+    else:
+        velocity_out = mu * velocity + grad
+        if use_nesterov:
+            param_out = param - grad * learning_rate - \
+                velocity_out * mu * learning_rate
+        else:
+            param_out = param - learning_rate * velocity_out
+
+    return param_out, velocity_out
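A standalone call to the reference implementation above, handy for seeing the expected shapes before reading the OpTest plumbing (arbitrary values; assumes `calculate_sparse_momentum_by_numpy` as defined above is in scope):

```python
import numpy as np

param = np.zeros((4, 3), dtype=np.float32)
velocity = np.zeros_like(param)
index = np.array([0, 2], dtype=np.int32)   # only rows 0 and 2 get gradients
grad = np.ones((2, 3), dtype=np.float32)   # one grad row per index entry

param_out, velocity_out = calculate_sparse_momentum_by_numpy(
    param=param, grad=grad, mu=0.9, velocity=velocity, use_nesterov=False,
    learning_rate=np.array([0.1], dtype=np.float32), index=index, axis=0)
print(param_out[0], param_out[1])  # row 0 updated, row 1 untouched
```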
+
+
+class TestSparseMomentumOp(OpTest):
+    def setUp(self):
+        self.op_type = "sparse_momentum"
+        self.dtype = np.float32
+        self.index_dtype = np.int32
+        self.axis = 0
+        self.multi_precision = False
+        self.use_nesterov = False
+        self.batch_size = 20
+        self.num_classes = 20
+        self.init_dtype()
+        self.init_axis()
+        self.init_multi_precision()
+        self.init_use_nesterov()
+
+        if self.multi_precision:
+            assert self.dtype == np.float16
+
+        param = np.random.random(
+            (self.batch_size, self.num_classes)).astype(self.dtype)
+        grad = np.random.random(
+            (self.batch_size, self.num_classes)).astype(self.dtype)
+        if self.axis == 0:
+            index = np.random.randint(
+                0,
+                self.batch_size,
+                size=(self.batch_size // 2, ),
+                dtype=self.index_dtype)
+            grad = grad[index]
+        else:
+            index = np.random.randint(
+                0,
+                self.num_classes,
+                size=(self.num_classes // 2, ),
+                dtype=self.index_dtype)
+            grad = grad[:, index]
+        velocity = np.random.random(
+            (self.batch_size, self.num_classes)).astype(self.dtype)
+        learning_rate = np.array([0.001]).astype(self.dtype)
+
+        mu = 0.9
+        regularization_method = "l2_decay"
+        regularization_coeff = 1.0
+
+        param_out, velocity_out = calculate_sparse_momentum_by_numpy(
+            param=param,
+            grad=grad,
+            mu=mu,
+            velocity=velocity,
+            use_nesterov=self.use_nesterov,
+            learning_rate=learning_rate,
+            regularization_method=regularization_method,
+            regularization_coeff=regularization_coeff,
+            index=index,
+            axis=self.axis)
+
+        self.attrs = {
+            'mu': mu,
+            'use_nesterov': self.use_nesterov,
+            'regularization_method': regularization_method,
+            'regularization_coeff': regularization_coeff,
+            'multi_precision': self.multi_precision,
+            'axis': self.axis,
+        }
+
+        self.inputs = {
+            'Param': param.astype("float16")
+            if self.multi_precision else param,
+            'Velocity': velocity.astype("float32")
+            if self.multi_precision else velocity,
+            'LearningRate': learning_rate.astype("float32")
+            if self.multi_precision else learning_rate,
+            'Grad': grad.astype("float16") if self.multi_precision else grad,
+            'Index': index,
+            'Axis': np.array(self.axis).astype(np.int32),
+        }
+        self.outputs = {
+            'ParamOut': param_out.astype("float16")
+            if self.multi_precision else param_out,
+            'VelocityOut': velocity_out.astype("float32")
+            if self.multi_precision else velocity_out,
+        }
+
+        if self.multi_precision:
+            self.inputs['MasterParam'] = param.astype("float32")
+            self.outputs['MasterParamOut'] = param_out.astype("float32")
+
+    def init_dtype(self):
+        pass
+
+    def init_axis(self):
+        pass
+
+    def init_multi_precision(self):
+        pass
+
+    def init_use_nesterov(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output(atol=5e-3 if self.multi_precision else 1e-5)
+
+
+class TestSparseMomentumOpDtype1(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float32
+        self.index_dtype = np.int64
+
+
+class TestSparseMomentumOpDtype2(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float64
+        self.index_dtype = np.int32
+
+
+class TestSparseMomentumOpDtype3(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float64
+        self.index_dtype = np.int64
+
+
+class TestSparseMomentumOpAxis(TestSparseMomentumOp):
+    def init_axis(self):
+        self.axis = 1
+
+
+class TestSparseMomentumOpNesterov(TestSparseMomentumOp):
+    def init_use_nesterov(self):
+        self.use_nesterov = True
+
+
+class TestSparseMomentumOpMultiPrecision(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.index_dtype = np.int32
+
+    def init_multi_precision(self):
+        self.multi_precision = True
+
+    def init_use_nesterov(self):
+        self.use_nesterov = True
+
+
+class TestSparseMomentumOpMultiPrecision1(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.index_dtype = np.int64
+
+    def init_multi_precision(self):
+        self.multi_precision = True
+
+    def init_use_nesterov(self):
+        self.use_nesterov = True
+
+
+class TestSparseMomentumOpMultiPrecision2(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.index_dtype = np.int32
+
+    def init_multi_precision(self):
+        self.multi_precision = True
+
+    def init_use_nesterov(self):
+        self.use_nesterov = False
+
+
+class TestSparseMomentumOpMultiPrecision3(TestSparseMomentumOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.index_dtype = np.int64
+
+    def init_multi_precision(self):
+        self.multi_precision = True
+
+    def init_use_nesterov(self):
+        self.use_nesterov = False
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 4255e1f4e440d0f5d29ae77a79ab5dc08dda6037..4442ff538cb0179ed3e050cd9832bf10e535a0c0 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -333,6 +333,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_mish_op',
     'test_modified_huber_loss_op',
     'test_momentum_op',
+    'test_sparse_momentum_op',
     'test_monitor',
     'test_mse_loss',
     'test_mul_op',