提交 584c3f04 编写于 作者: S sneaxiy

fix sparse rmsprop

上级 23644940
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
row_numel_(row_numel),
row_count_(row_count) {}
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
int64_t beg = 0, end = row_count_ - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (rows_[mid] == row)
return mid;
else if (rows_[mid] < row)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
inline HOSTDEVICE void operator()(size_t i) const {
int64_t row = i / row_numel_;
auto row_idx = BinarySearchInRows(row);
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
// The following code is the same as dense
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstdint> // for int64_t
#include <numeric>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
int64_t beg = 0, end = num - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (x[mid] == val)
return mid;
else if (x[mid] < val)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -13,66 +13,254 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
struct DenseRmspropGradFunctor {
inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}
HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; }
const T *grad_;
};
template <typename T>
struct SparseRmspropGradFunctor {
inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows,
int64_t row_numel, int64_t row_count)
: grad_(grad),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
HOSTDEVICE inline T operator()(int64_t idx) const {
auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_);
return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0;
}
const T *grad_;
const int64_t *rows_;
int64_t row_numel_;
int64_t row_count_;
};
template <typename T, typename GradFunctor>
struct UncenteredRmspropFunctor {
UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho,
T epsilon, T momentum,
const GradFunctor &grad_functor)
: param_(param),
ms_(ms),
mom_(mom),
lr_(lr),
rho_(rho),
epsilon_(epsilon),
momentum_(momentum),
grad_functor_(grad_functor) {}
HOSTDEVICE inline void operator()(int64_t idx) const {
T g = grad_functor_(idx);
T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_);
param_[idx] -= mom_out;
ms_[idx] = ms_out;
mom_[idx] = mom_out;
}
T *param_;
T *ms_;
T *mom_;
const T *lr_;
T rho_;
T epsilon_;
T momentum_;
GradFunctor grad_functor_;
};
template <typename T, typename GradFunctor>
struct CenteredRmspropFunctor {
CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr,
T rho, T epsilon, T momentum,
const GradFunctor &grad_functor)
: param_(param),
ms_(ms),
mom_(mom),
mean_grad_(mean_grad),
lr_(lr),
rho_(rho),
epsilon_(epsilon),
momentum_(momentum),
grad_functor_(grad_functor) {}
HOSTDEVICE inline void operator()(int64_t idx) const {
T g = grad_functor_(idx);
T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
T mom_out = momentum_ * mom_[idx] +
lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
param_[idx] -= mom_out;
ms_[idx] = ms_out;
mom_[idx] = mom_out;
mean_grad_[idx] = mg_out;
}
T *param_;
T *ms_;
T *mom_;
T *mean_grad_;
const T *lr_;
T rho_;
T epsilon_;
T momentum_;
GradFunctor grad_functor_;
};
template <typename DeviceContext, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
void Compute(const framework::ExecutionContext &ctx) const override {
using Tensor = framework::LoDTensor;
auto *grad_var = ctx.InputVar("Grad");
auto *param_out = ctx.Output<Tensor>("ParamOut");
auto *moment_out = ctx.Output<Tensor>("MomentOut");
auto *mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto rho = static_cast<T>(ctx.Attr<float>("decay"));
auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
bool centered = ctx.Attr<bool>("centered");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
auto &p_tensor = *ctx.Input<Tensor>("Param");
auto &ms_tensor = *ctx.Input<Tensor>("MeanSquare");
auto &lr_tensor = *ctx.Input<Tensor>("LearningRate");
auto &mom_tensor = *ctx.Input<Tensor>("Moment");
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
bool centered = ctx.Attr<bool>("centered");
PADDLE_ENFORCE_EQ(&p_tensor, param_out,
"Param and ParamOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&mom_tensor, moment_out,
"Moment and MomentOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out,
"MeanSquare and MeanSquareOut must be the same Tensor");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
size_t limit = static_cast<size_t>(ms_tensor.numel());
if (grad_var->IsType<Tensor>()) {
auto &grad_tensor = grad_var->Get<Tensor>();
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
auto &place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto lr_value = lr_tensor.data<T>()[0];
auto p = EigenVector<T>::Flatten(p_tensor);
auto ms = EigenVector<T>::Flatten(ms_tensor);
auto g = EigenVector<T>::Flatten(grad_tensor);
auto mom = EigenVector<T>::Flatten(mom_tensor);
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
if (centered) {
auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g;
mom_out.device(place) =
momentum * mom +
lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
} else {
mom_out.device(place) =
momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
}
p_out.device(place) = p - mom_out;
} else {
DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
if (centered) {
auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()),
mean_grad_out->mutable_data<T>(ctx.GetPlace()),
lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
} else {
for_range(UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto &grad = grad_var->Get<framework::SelectedRows>();
auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(dev_ctx, grad, merged_grad);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
const int64_t *rows;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = merged_grad->rows().CUDAData(ctx.GetPlace());
} else {
#endif
rows = merged_grad->rows().data();
#ifdef PADDLE_WITH_CUDA
}
#endif
auto &merged_tensor = merged_grad->value();
int64_t row_count = merged_grad->rows().size();
int64_t row_numel = merged_tensor.numel() / row_count;
SparseRmspropGradFunctor<T> grad_func(merged_tensor.data<T>(), rows,
row_numel, row_count);
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
if (centered) {
auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
mean_grad_out->mutable_data<T>(ctx.GetPlace());
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g;
mom_out.device(place) = momentum * mom +
lr.broadcast(grad_dsize) * g /
(ms_out - mg_out.square() + epsilon).sqrt();
if (centered) {
auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()),
mean_grad_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
} else {
for_range(UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
}
} else {
mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient");
}
p_out.device(place) = p - mom_out;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册