Unverified commit 234ce932, authored by Guoxia Wang, committed by GitHub

sparse_momentum_op is used to save w@GRAD memory for gather_op (#34942)

* sparse_momentum_op is used to save w@GRAD memory for gather_op when gathering from a large parameter
Parent 1533d7e2
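For context: when rows are gathered from a large parameter, the backward pass of gather_op normally materializes a dense w@GRAD buffer the size of the whole parameter, even though only the gathered rows are nonzero. sparse_momentum instead consumes the small gathered gradient together with the index and applies the momentum update directly. Below is a minimal NumPy sketch of the dense computation being avoided; the shapes and variable names are illustrative, not from the patch:

import numpy as np

N, D, K = 10000, 512, 128                  # large parameter, small gathered slice
param = np.random.rand(N, D).astype(np.float32)
velocity = np.zeros((N, D), np.float32)
index = np.random.randint(0, N, size=(K,))
sub_grad = np.random.rand(K, D).astype(np.float32)  # gradient of the gathered rows only
mu, lr = 0.9, 0.001

# Dense equivalent: scatter-add sub_grad into an (N, D) buffer -- this buffer
# is exactly the w@GRAD memory that sparse_momentum avoids allocating.
grad = np.zeros_like(param)
np.add.at(grad, index, sub_grad)           # duplicate indices accumulate
velocity = mu * velocity + grad            # same update as the op with axis=0
param = param - lr * velocity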
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/sparse_momentum_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class SparseMomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto in_var_type = ctx->GetInputType("Param");
PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::LOD_TENSOR,
true,
platform::errors::InvalidArgument(
"Only support LodTensor, Unexpected Input Type."));
ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
}
};
void SparseMomentumOpMaker::Make() {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("Index",
"(Tensor, default Tensor<int>) "
"Input index of Param to do update operation");
AddInput("Axis",
"The Tensor which contains the axis that we do update operation.")
.AsDispensable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
AddAttr<std::string>(
"regularization_method",
"(string) regularization_method, right now only support l2decay or none")
.SetDefault("");
AddAttr<float>("regularization_coeff", "(float) regularization_coeff")
.SetDefault(0.0f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddAttr<int>("axis",
"(int, default 0) The integer which specific the axis that we "
"do update operation.")
.SetDefault(0);
AddComment(R"DOC(
Sparse Momentum Optimizer.
This optimizer has a flag for Nesterov Momentum.
The update equations are as follows:
$$
velocity = mu * velocity + gradient \\
if (use\_nesterov): \\
param = param - (gradient + mu * velocity) * learning\_rate \\
else: \\
param = param - learning\_rate * velocity. \\
$$
)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
sparse_momentum, ops::SparseMomentumOp, ops::SparseMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::SparseMomentumOpInferVarType);
REGISTER_OP_CPU_KERNEL(
sparse_momentum,
ops::SparseMomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SparseMomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/sparse_momentum_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sparse_momentum,
ops::SparseMomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SparseMomentumOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SparseMomentumOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
enum class RegularizationType {
kNONE = 0,
kL1DECAY = 1, // do not need support right now
kL2DECAY = 2,
};
template <typename T>
struct NoNesterov {
HOSTDEVICE inline T operator()(const T& grad, const T& velocity,
const T& mu) const {
return velocity;
}
};
template <typename T>
struct UseNesterov {
HOSTDEVICE inline T operator()(const T& grad, const T& velocity,
const T& mu) const {
return grad + velocity * mu;
}
};
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/lower_bound
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
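// Given a sorted array x of length num, the routine below computes the
// inclusive index range [*lower_bound, *upper_bound] of the elements equal to
// `value`. If `value` is absent, at least one output stays -1 or the range is
// empty (*lower_bound > *upper_bound), so callers can skip the element.
// Example: x = {1, 2, 2, 5}, value = 2 -> *lower_bound = 1, *upper_bound = 2.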
template <typename T>
HOSTDEVICE inline void BinarySearchLowerUpperBound(const T* x, int64_t num,
const T& value,
int64_t* lower_bound,
int64_t* upper_bound) {
*lower_bound = -1;
*upper_bound = -1;
auto* first = x;
int64_t count = static_cast<int64_t>(num);
while (count > 0) {
int64_t step = (count >> 1);
auto* it = first + step;
if (*it < value) {
first = ++it;
count -= (step + 1);
} else {
count = step;
}
}
auto idx = static_cast<int64_t>(first - x);
if ((idx > 0 && idx < num) || (idx == 0 && x[idx] == value)) {
*lower_bound = idx;
}
if (*lower_bound >= 0) {
first = x + idx;
count = static_cast<int64_t>(num - idx);
while (count > 0) {
auto step = (count >> 1);
auto* it = first + step;
if (value < *it) {
count = step;
} else {
first = ++it;
count -= (step + 1);
}
}
auto upper_idx = static_cast<int64_t>(first - x) - 1;
if ((upper_idx >= 0 && upper_idx < num - 1) ||
(upper_idx == num - 1 && x[upper_idx] == value)) {
*upper_bound = upper_idx;
}
}
return;
}
template <typename T>
class RangeFunctor {
private:
T* value_;
public:
explicit RangeFunctor(T* value) : value_(value) {}
inline HOSTDEVICE void operator()(size_t i) { value_[i] = static_cast<T>(i); }
};
class SparseMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class SparseMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "SparseMomentum");
OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "SparseMomentum");
OP_INOUT_CHECK(ctx->HasInput("Velocity"), "Input", "Velocity",
"SparseMomentum");
OP_INOUT_CHECK(ctx->HasInput("Index"), "Input", "Index", "SparseMomentum");
OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
"SparseMomentum");
OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut",
"SparseMomentum");
OP_INOUT_CHECK(ctx->HasOutput("VelocityOut"), "Output", "VelocityOut",
"SparseMomentum");
auto lr_dims = framework::product(ctx->GetInputDim("LearningRate"));
PADDLE_ENFORCE_EQ(lr_dims != 0 && lr_dims == 1, true,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
lr_dims));
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
platform::errors::InvalidArgument(
"Param and Velocity of SparseMomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dim, ctx->GetInputDim("Velocity")));
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T, typename MT, typename IndexT, typename UpdateMethod>
class IndexMomentumFunctor {
private:
const T* param_;
const T* grad_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
const MT mu_;
const MT rescale_grad_;
const IndexT* sorted_index_;
const IndexT* grad_index_;
const int64_t num_index_;
const int axis_;
const int64_t param_row_numel_;
const int64_t grad_row_numel_;
T* param_out_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
const UpdateMethod& update_method_;
public:
IndexMomentumFunctor(const T* param, const T* grad, const MT* velocity,
const MultiPrecisionType<MT>* lr, const MT* master_param,
const MT mu, const MT rescale_grad,
const IndexT* sorted_index, const IndexT* grad_index,
int64_t num_index, int axis, int64_t param_row_numel,
int64_t grad_row_numel,
const RegularizationType regularization_flag,
const MT regularization_coeff,
const UpdateMethod& update_method, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
grad_(grad),
velocity_(velocity),
lr_(lr),
master_param_(master_param),
mu_(mu),
rescale_grad_(rescale_grad),
sorted_index_(sorted_index),
grad_index_(grad_index),
num_index_(num_index),
axis_(axis),
param_row_numel_(param_row_numel),
grad_row_numel_(grad_row_numel),
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff),
update_method_(update_method) {}
inline HOSTDEVICE void operator()(size_t i) {
MT grad = static_cast<MT>(0);
size_t row = i / param_row_numel_;
size_t col = i % param_row_numel_;
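// Each i addresses one element of Param. Binary-search the sorted sparse
// index for this row (axis 0) or column (axis 1); every match is one
// occurrence in the compact Grad, and duplicates accumulate, which is the
// scatter-add that a dense w@GRAD would otherwise have to materialize.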
if (axis_ == 0) {
int64_t row_idx0, row_idx1;
BinarySearchLowerUpperBound<IndexT>(sorted_index_, num_index_, row,
&row_idx0, &row_idx1);
if (row_idx0 >= 0 && row_idx1 >= 0) {
for (int64_t row_idx = row_idx0; row_idx <= row_idx1; row_idx++) {
size_t offset = grad_index_[row_idx] * param_row_numel_ + col;
grad += static_cast<MT>(grad_[offset]) * rescale_grad_;
}
}
} else if (axis_ == 1) {
int64_t col_idx0, col_idx1;
BinarySearchLowerUpperBound<IndexT>(sorted_index_, num_index_, col,
&col_idx0, &col_idx1);
if (col_idx0 >= 0 && col_idx1 >= 0) {
for (int64_t col_idx = col_idx0; col_idx <= col_idx1; col_idx++) {
size_t offset = row * grad_row_numel_ + grad_index_[col_idx];
grad += static_cast<MT>(grad_[offset]) * rescale_grad_;
}
}
}
// load values from memory into registers
const MT param =
master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
MT velocity_out = velocity * mu_ + grad;
MT velocity_tmp = update_method_(grad, velocity_out, mu_);
MT param_out = param - velocity_tmp * lr;
// write registers back to memory
velocity_out_[i] = velocity_out;
param_out_[i] = static_cast<T>(param_out);
if (master_param_out_) {
master_param_out_[i] = param_out;
}
}
};
template <typename DeviceContext, typename T>
class SparseMomentumOpKernel : public framework::OpKernel<T> {
using MPDType = MultiPrecisionType<T>;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto index = ctx.Input<framework::Tensor>("Index");
const auto& index_type = index->type();
if (multi_precision) {
if (use_nesterov) {
auto update_method = UseNesterov<MPDType>();
if (index_type == framework::proto::VarType::INT32) {
InnerCompute<MPDType, int, UseNesterov<MPDType>>(ctx, multi_precision,
update_method);
} else {
InnerCompute<MPDType, int64_t, UseNesterov<MPDType>>(
ctx, multi_precision, update_method);
}
} else {
auto update_method = NoNesterov<MPDType>();
if (index_type == framework::proto::VarType::INT32) {
InnerCompute<MPDType, int, NoNesterov<MPDType>>(ctx, multi_precision,
update_method);
} else {
InnerCompute<MPDType, int64_t, NoNesterov<MPDType>>(
ctx, multi_precision, update_method);
}
}
} else {
if (use_nesterov) {
auto update_method = UseNesterov<T>();
if (index_type == framework::proto::VarType::INT32) {
InnerCompute<T, int, UseNesterov<T>>(ctx, multi_precision,
update_method);
} else {
InnerCompute<T, int64_t, UseNesterov<T>>(ctx, multi_precision,
update_method);
}
} else {
auto update_method = NoNesterov<T>();
if (index_type == framework::proto::VarType::INT32) {
InnerCompute<T, int, NoNesterov<T>>(ctx, multi_precision,
update_method);
} else {
InnerCompute<T, int64_t, NoNesterov<T>>(ctx, multi_precision,
update_method);
}
}
}
}
private:
template <typename MT, typename IndexT, typename UpdateMethod>
void InnerCompute(const framework::ExecutionContext& ctx,
const bool multi_precision,
const UpdateMethod& update_method) const {
std::string regularization_method =
ctx.Attr<std::string>("regularization_method");
MT regularization_coeff =
static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
RegularizationType regularization_flag{
RegularizationType::kNONE}; // disable regularization
if (regularization_method == "l2_decay") {
regularization_flag = RegularizationType::kL2DECAY;
}
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
int axis = ctx.Attr<int>("axis");
// get axis from tensor
if (ctx.HasInput("Axis")) {
Tensor cpu_axis;
const Tensor* axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto& axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
}
PADDLE_ENFORCE_EQ(
axis == 0 || axis == 1, true,
platform::errors::InvalidArgument("The axis of sparse_momentum_op only "
"support axis=0 or axis=1 now."));
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto index = ctx.Input<framework::Tensor>("Index");
int64_t num_index = index->numel();
// check that Index is 1-D (or 2-D with the second dimension equal to 1)
if (index->dims().size() == 1) {
PADDLE_ENFORCE_GT(
index->dims()[0], 0,
platform::errors::InvalidArgument(
"The index of sparse_momentum_op should not be empty"
"when the index's rank is 1."));
} else if (index->dims().size() == 2) {
PADDLE_ENFORCE_EQ(index->dims()[1], 1,
platform::errors::InvalidArgument(
"If the index's rank of sparse_momentum_op is 2,"
" the second dimension should be 1."));
}
const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<MT>(ctx.GetPlace());
const MT* master_in_data =
multi_precision ? master_param->data<MT>() : nullptr;
MT* master_out_data =
multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
: nullptr;
auto grad = ctx.Input<framework::Tensor>("Grad");
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
auto param_dims = param->dims();
auto grad_dims = grad->dims();
PADDLE_ENFORCE_EQ(param_dims.size(), 2,
platform::errors::InvalidArgument(
"The Param's rank of sparse_momentum_op"
" must be 2 now."));
PADDLE_ENFORCE_EQ(grad_dims.size(), 2,
platform::errors::InvalidArgument(
"The Grad's rank of sparse_momentum_op"
" must be 2 now."));
Tensor sorted_index, grad_index, sort_value;
auto sorted_index_ptr =
sorted_index.mutable_data<IndexT>({num_index}, ctx.GetPlace());
auto grad_index_ptr =
grad_index.mutable_data<IndexT>({num_index}, ctx.GetPlace());
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto sort_value_ptr =
sort_value.mutable_data<IndexT>({num_index}, ctx.GetPlace());
platform::ForRange<DeviceContext> for_range_index(
static_cast<const DeviceContext&>(ctx.device_context()), num_index);
RangeFunctor<IndexT> range_functor(sort_value_ptr);
for_range_index(range_functor);
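// cub radix sort uses a two-pass API: the first call with a null workspace
// only queries temp_storage_bytes; the second call sorts the pairs
// (index, 0..num_index-1) so that sorted_index is ordered and grad_index
// maps each sorted entry back to its original row in Grad.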
size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
(cub::DeviceRadixSort::SortPairs<IndexT, IndexT>(
nullptr, temp_storage_bytes, nullptr, nullptr, nullptr, nullptr,
static_cast<int>(num_index))));
auto d_temp_storage = memory::Alloc(ctx.GetPlace(), temp_storage_bytes);
PADDLE_ENFORCE_CUDA_SUCCESS(
(cub::DeviceRadixSort::SortPairs<IndexT, IndexT>(
d_temp_storage->ptr(), temp_storage_bytes, index->data<IndexT>(),
sorted_index_ptr, sort_value_ptr, grad_index_ptr,
static_cast<int>(num_index), 0, sizeof(IndexT) * 8,
ctx.cuda_device_context().stream())));
#endif
} else if (platform::is_cpu_place(ctx.GetPlace())) {
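// CPU fallback: sort (index value, original position) pairs to produce the
// same sorted_index / grad_index mapping as the GPU path above.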
std::vector<std::pair<IndexT, IndexT>> vec_tosort;
auto index_ptr = index->data<IndexT>();
for (IndexT i = 0; i < num_index; i++) {
vec_tosort.push_back({index_ptr[i], i});
}
std::sort(vec_tosort.begin(), vec_tosort.end(),
[](const std::pair<IndexT, IndexT>& k1,
const std::pair<IndexT, IndexT>& k2) {
return k1.first < k2.first;
});
for (IndexT i = 0; i < num_index; i++) {
sorted_index_ptr[i] = vec_tosort[i].first;
grad_index_ptr[i] = vec_tosort[i].second;
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"sparse_momentum %s is not supported.", ctx.GetPlace()));
}
IndexMomentumFunctor<T, MT, IndexT, UpdateMethod> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
sorted_index_ptr, grad_index_ptr, num_index, axis, param_dims[1],
grad_dims[1], regularization_flag, regularization_coeff, update_method,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
......
@@ -64,6 +64,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
{"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
{"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
{"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
{"run_program", {"X", "Params"}},
};
......
@@ -97,6 +98,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
......
@@ -124,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
"out_old_num_accumulates", "out_num_updates"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
{"batch_norm", {"MeanOut", "VarianceOut"}},
{"sync_batch_norm", {"MeanOut", "VarianceOut"}},
{"accuracy", {"Correct", "Total"}},
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
import paddle
import paddle.fluid as fluid
def calculate_sparse_momentum_by_numpy(param,
grad,
mu,
velocity,
use_nesterov,
learning_rate,
index,
axis,
regularization_method=None,
regularization_coeff=1.0):
sub_grad = grad.copy()
grad = np.zeros_like(param)
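# Scatter-add the gathered (sub) gradient back into a dense buffer: rows for
# axis=0, columns for axis=1; repeated indices accumulate.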
if axis == 0:
unique_index = np.unique(index)
for idx in unique_index:
grad[idx, :] = np.sum(sub_grad[index == idx, :], axis=0)
else:
unique_index = np.unique(index)
for idx in unique_index:
grad[:, idx] = np.sum(sub_grad[:, index == idx], axis=1)
if regularization_method == "l2_decay":
grad = grad + regularization_coeff * param
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - (grad + velocity_out * mu) * learning_rate
else:
param_out = param - learning_rate * velocity_out
else:
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate
else:
param_out = param - learning_rate * velocity_out
return param_out, velocity_out
class TestSparseMomentumOp(OpTest):
def setUp(self):
self.op_type = "sparse_momentum"
self.dtype = np.float32
self.index_dtype = np.int32
self.axis = 0
self.multi_precision = False
self.use_nesterov = False
self.batch_size = 20
self.num_classes = 20
self.init_dtype()
self.init_axis()
self.init_multi_precision()
self.init_use_nesterov()
if self.multi_precision:
assert self.dtype == np.float16
param = np.random.random(
(self.batch_size, self.num_classes)).astype(self.dtype)
grad = np.random.random(
(self.batch_size, self.num_classes)).astype(self.dtype)
if self.axis == 0:
index = np.random.randint(
0,
self.batch_size,
size=(self.batch_size // 2, ),
dtype=self.index_dtype)
grad = grad[index]
else:
index = np.random.randint(
0,
self.num_classes,
size=(self.num_classes // 2, ),
dtype=self.index_dtype)
grad = grad[:, index]
velocity = np.random.random(
(self.batch_size, self.num_classes)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(self.dtype)
mu = 0.9
regularization_method = "l2_decay"
regularization_coeff = 1.0
param_out, velocity_out = calculate_sparse_momentum_by_numpy(
param=param,
grad=grad,
mu=mu,
velocity=velocity,
use_nesterov=self.use_nesterov,
learning_rate=learning_rate,
regularization_method=regularization_method,
regularization_coeff=regularization_coeff,
index=index,
axis=self.axis)
self.attrs = {
'mu': mu,
'use_nesterov': self.use_nesterov,
'regularization_method': regularization_method,
'regularization_coeff': regularization_coeff,
'multi_precision': self.multi_precision,
'axis': self.axis,
}
self.inputs = {
'Param': param.astype("float16") if self.multi_precision else param,
'Velocity': velocity.astype("float32")
if self.multi_precision else velocity,
'LearningRate': learning_rate.astype("float32")
if self.multi_precision else learning_rate,
'Grad': grad.astype("float16") if self.multi_precision else grad,
'Index': index,
'Axis': np.array(self.axis).astype(np.int32),
}
self.outputs = {
'ParamOut': param_out.astype("float16")
if self.multi_precision else param_out,
'VelocityOut': velocity_out.astype("float32")
if self.multi_precision else velocity_out,
}
if self.multi_precision:
self.inputs['MasterParam'] = param.astype("float32")
self.outputs['MasterParamOut'] = param_out.astype("float32")
def init_dtype(self):
pass
def init_axis(self):
pass
def init_multi_precision(self):
pass
def init_use_nesterov(self):
pass
def test_check_output(self):
self.check_output(atol=5e-3 if self.multi_precision else 1e-5)
class TestSparseMomentumOpDtype1(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float32
self.index_dtype = np.int64
class TestSparseMomentumOpDtype2(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float64
self.index_dtype = np.int32
class TestSparseMomentumOpDtype3(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float64
self.index_dtype = np.int64
class TestSparseMomentumOpAxis(TestSparseMomentumOp):
def init_axis(self):
self.axis = 1
class TestSparseMomentumOpNesterov(TestSparseMomentumOp):
def init_use_nesterov(self):
self.use_nesterov = True
class TestSparseMomentumOpMultiPrecision(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float16
self.index_dtype = np.int32
def init_multi_precision(self):
self.multi_precision = True
def init_use_nesterov(self):
self.use_nesterov = True
class TestSparseMomentumOpMultiPrecision1(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float16
self.index_dtype = np.int64
def init_multi_precision(self):
self.multi_precision = True
def init_use_nesterov(self):
self.use_nesterov = True
class TestSparseMomentumOpMultiPrecision2(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float16
self.index_dtype = np.int32
def init_multi_precision(self):
self.multi_precision = True
def init_use_nesterov(self):
self.use_nesterov = False
class TestSparseMomentumOpMultiPrecision3(TestSparseMomentumOp):
def init_dtype(self):
self.dtype = np.float16
self.index_dtype = np.int64
def init_multi_precision(self):
self.multi_precision = True
def init_use_nesterov(self):
self.use_nesterov = False
......
@@ -333,6 +333,7 @@ STATIC_MODE_TESTING_LIST = [
'test_mish_op',
'test_modified_huber_loss_op',
'test_momentum_op',
'test_sparse_momentum_op',
'test_monitor',
'test_mse_loss',
'test_mul_op',
......