Commit 229d4bbb authored by Zeng Jinle, committed by chengduo

cherry-pick sparse rmsprop to release/1.0.0 (#13907)

* test=release/1.0.0

* fix sparse rmsprop

* test=develop

* add check for opt op

* test=release/1.0.0
Parent b97257b1
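
For readers skimming the diff: the new `UncenteredRmspropFunctor` and `CenteredRmspropFunctor` in `rmsprop_op.h` implement the RMSProp updates summarized below. The notation is read directly off the functor bodies in this commit (rho is the `decay` attribute, mu the `momentum`, eta the learning rate), so treat it as a summary of the diff rather than separate documentation:

```latex
% RMSProp updates as coded in UncenteredRmspropFunctor / CenteredRmspropFunctor.
% s = MeanSquare, m = MeanGrad, v = Moment, \theta = Param, g = gradient element.
\begin{aligned}
\text{uncentered:}\quad & s_t = \rho\, s_{t-1} + (1-\rho)\, g_t^2, \\
                        & v_t = \mu\, v_{t-1} + \frac{\eta\, g_t}{\sqrt{s_t + \epsilon}}, \qquad
                          \theta_t = \theta_{t-1} - v_t \\
\text{centered:}\quad   & s_t = \rho\, s_{t-1} + (1-\rho)\, g_t^2, \qquad
                          m_t = \rho\, m_{t-1} + (1-\rho)\, g_t, \\
                        & v_t = \mu\, v_{t-1} + \frac{\eta\, g_t}{\sqrt{s_t - m_t^2 + \epsilon}}, \qquad
                          \theta_t = \theta_{t-1} - v_t
\end{aligned}
```

For a `SelectedRows` gradient, rows missing from the merged sparse gradient contribute g_t = 0, which is exactly what `SparseRmspropGradFunctor` returns when its binary search fails.
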
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
"Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null.");
......@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
......
......@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut");
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -21,25 +22,31 @@ namespace operators {
template <typename DeviceContext, typename T>
struct SparseAdagradFunctor {
void operator()(const DeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param);
void operator()(const DeviceContext &context,
const framework::SelectedRows &grad,
const framework::Tensor &learning_rate, T epsilon,
framework::Tensor *moment, framework::Tensor *param);
};
template <typename DeviceContext, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto* moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* grad_var = ctx.InputVar("Grad");
auto *grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
......@@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
auto *place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
if (platform::is_cpu_place(ctx.GetPlace())) {
auto* lr = learning_rate->data<T>();
auto *lr = learning_rate->data<T>();
param_out.device(*place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
......@@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param_tensor = ctx.Input<framework::Tensor>("Param");
auto *param_tensor = ctx.Input<framework::Tensor>("Param");
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
auto *moment_tensor = ctx.Input<framework::Tensor>("Moment");
PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
SparseAdagradFunctor<DeviceContext, T> functor;
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
row_numel_(row_numel),
row_count_(row_count) {}
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
int64_t beg = 0, end = row_count_ - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (rows_[mid] == row)
return mid;
else if (rows_[mid] < row)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
inline HOSTDEVICE void operator()(size_t i) const {
int64_t row = i / row_numel_;
auto row_idx = BinarySearchInRows(row);
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
// The following code is the same as dense
......@@ -244,6 +231,12 @@ template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
......
......@@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
"Input(LearningRate) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamaxOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamaxOp should not be null.");
......
......@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");
......
......@@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of DecayedAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of DecayedAdagradOp should not be null.");
......
......@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class DecayedAdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
......
......@@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel {
"Input(Grad) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of FTRL should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of FTRL should not be null.");
......
......@@ -28,6 +28,17 @@ template <typename DeviceContext, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstdint> // for int64_t
#include <numeric>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
int64_t beg = 0, end = num - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (x[mid] == val)
return mid;
else if (x[mid] < val)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
} // namespace math
} // namespace operators
} // namespace paddle
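
The `BinarySearch` helper above replaces the hand-rolled `BinarySearchInRows` in `SparseAdamFunctor` and is reused by `SparseRmspropGradFunctor` to map a dense row index onto the sorted `rows()` array of a `SelectedRows` gradient. Below is a minimal host-only sketch of the same routine; the function is re-declared locally so the snippet compiles outside the Paddle tree (the real helper is additionally marked `HOSTDEVICE` so it can run inside CUDA kernels), and the row data is a made-up toy example:

```cpp
// Standalone illustration of paddle::operators::math::BinarySearch:
// find the position of `val` in a sorted array, or -1 if it is absent.
#include <cstdint>
#include <iostream>

template <typename T>
int64_t BinarySearch(const T *x, int64_t num, const T &val) {
  int64_t beg = 0, end = num - 1;
  while (beg <= end) {
    int64_t mid = (beg + end) >> 1;
    if (x[mid] == val) return mid;
    if (x[mid] < val) {
      beg = mid + 1;
    } else {
      end = mid - 1;
    }
  }
  return -1;  // val is not present in x
}

int main() {
  // Toy rows array, sorted ascending as the binary search requires.
  const int64_t rows[] = {1, 4, 7, 9};
  std::cout << BinarySearch(rows, 4, int64_t{7}) << "\n";  // 2: row 7 is present
  std::cout << BinarySearch(rows, 4, int64_t{5}) << "\n";  // -1: row 5 is missing
  return 0;
}
```
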
......@@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel {
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
......
......@@ -46,6 +46,17 @@ template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
......
......@@ -23,6 +23,12 @@ template <typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
......
......@@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel {
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
......
......@@ -13,66 +13,259 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
struct DenseRmspropGradFunctor {
inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}
HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; }
const T *grad_;
};
template <typename T>
struct SparseRmspropGradFunctor {
inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows,
int64_t row_numel, int64_t row_count)
: grad_(grad),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
HOSTDEVICE inline T operator()(int64_t idx) const {
auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_);
return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0;
}
const T *grad_;
const int64_t *rows_;
int64_t row_numel_;
int64_t row_count_;
};
template <typename T, typename GradFunctor>
struct UncenteredRmspropFunctor {
UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho,
T epsilon, T momentum,
const GradFunctor &grad_functor)
: param_(param),
ms_(ms),
mom_(mom),
lr_(lr),
rho_(rho),
epsilon_(epsilon),
momentum_(momentum),
grad_functor_(grad_functor) {}
HOSTDEVICE inline void operator()(int64_t idx) const {
T g = grad_functor_(idx);
T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_);
param_[idx] -= mom_out;
ms_[idx] = ms_out;
mom_[idx] = mom_out;
}
T *param_;
T *ms_;
T *mom_;
const T *lr_;
T rho_;
T epsilon_;
T momentum_;
GradFunctor grad_functor_;
};
template <typename T, typename GradFunctor>
struct CenteredRmspropFunctor {
CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr,
T rho, T epsilon, T momentum,
const GradFunctor &grad_functor)
: param_(param),
ms_(ms),
mom_(mom),
mean_grad_(mean_grad),
lr_(lr),
rho_(rho),
epsilon_(epsilon),
momentum_(momentum),
grad_functor_(grad_functor) {}
HOSTDEVICE inline void operator()(int64_t idx) const {
T g = grad_functor_(idx);
T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
T mom_out = momentum_ * mom_[idx] +
lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
param_[idx] -= mom_out;
ms_[idx] = ms_out;
mom_[idx] = mom_out;
mean_grad_[idx] = mg_out;
}
T *param_;
T *ms_;
T *mom_;
T *mean_grad_;
const T *lr_;
T rho_;
T epsilon_;
T momentum_;
GradFunctor grad_functor_;
};
template <typename DeviceContext, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
void Compute(const framework::ExecutionContext &ctx) const override {
using LoDTensor = framework::LoDTensor;
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto *grad_var = ctx.InputVar("Grad");
auto *param_out = ctx.Output<LoDTensor>("ParamOut");
auto *moment_out = ctx.Output<LoDTensor>("MomentOut");
auto *mean_square_out = ctx.Output<LoDTensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto rho = static_cast<T>(ctx.Attr<float>("decay"));
auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
bool centered = ctx.Attr<bool>("centered");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
auto &p_tensor = *ctx.Input<LoDTensor>("Param");
auto &ms_tensor = *ctx.Input<LoDTensor>("MeanSquare");
auto &lr_tensor = *ctx.Input<LoDTensor>("LearningRate");
auto &mom_tensor = *ctx.Input<LoDTensor>("Moment");
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
bool centered = ctx.Attr<bool>("centered");
PADDLE_ENFORCE_EQ(&p_tensor, param_out,
"Param and ParamOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&mom_tensor, moment_out,
"Moment and MomentOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out,
"MeanSquare and MeanSquareOut must be the same Tensor");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
size_t limit = static_cast<size_t>(ms_tensor.numel());
if (grad_var->IsType<LoDTensor>()) {
auto &grad_tensor = grad_var->Get<LoDTensor>();
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
auto &place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto lr_value = lr_tensor.data<T>()[0];
auto p = EigenVector<T>::Flatten(p_tensor);
auto ms = EigenVector<T>::Flatten(ms_tensor);
auto g = EigenVector<T>::Flatten(grad_tensor);
auto mom = EigenVector<T>::Flatten(mom_tensor);
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g;
mom_out.device(place) =
momentum * mom +
lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
} else {
mom_out.device(place) =
momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
}
p_out.device(place) = p - mom_out;
} else {
DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()),
mean_grad_out->mutable_data<T>(ctx.GetPlace()),
lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
} else {
for_range(UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto &grad = grad_var->Get<framework::SelectedRows>();
auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(dev_ctx, grad, merged_grad);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
const int64_t *rows;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = merged_grad->rows().CUDAData(ctx.GetPlace());
} else {
#endif
rows = merged_grad->rows().data();
#ifdef PADDLE_WITH_CUDA
}
#endif
auto &merged_tensor = merged_grad->value();
int64_t row_count = merged_grad->rows().size();
int64_t row_numel = merged_tensor.numel() / row_count;
SparseRmspropGradFunctor<T> grad_func(merged_tensor.data<T>(), rows,
row_numel, row_count);
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
if (centered) {
auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
mean_grad_out->mutable_data<T>(ctx.GetPlace());
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g;
mom_out.device(place) = momentum * mom +
lr.broadcast(grad_dsize) * g /
(ms_out - mg_out.square() + epsilon).sqrt();
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()),
mean_grad_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
} else {
for_range(UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
}
} else {
mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient");
}
p_out.device(place) = p - mom_out;
}
};
......
......@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......@@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(data_type, ctx.device_context());
}
......@@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto input_var_n = op_desc.Input("Param")[0];
auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
input_var_n, in_var_type);
for (auto &out_var_n : op_desc.Output("ParamOut")) {
auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
if (out_var.GetType() != in_var_type) {
out_var.SetType(in_var_type);
}
}
}
......
......@@ -57,6 +57,12 @@ template <typename T>
class SGDOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
......
......@@ -659,6 +659,9 @@ class AdamaxOptimizer(Optimizer):
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, AdamaxOptimizer doesn't support sparse gradient.
"""
_moment_acc_str = "moment"
_inf_norm_acc_str = "inf_norm"
......@@ -778,6 +781,9 @@ class DecayedAdagradOptimizer(Optimizer):
optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, DecayedAdagradOptimizer doesn't support sparse gradient.
"""
_moment_acc_str = "moment"
......@@ -858,6 +864,9 @@ class AdadeltaOptimizer(Optimizer):
optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, AdadeltaOptimizer doesn't support sparse gradient.
"""
_avg_squared_grad_acc_str = "_avg_squared_grad"
......@@ -1126,6 +1135,9 @@ class FtrlOptimizer(Optimizer):
optimizer = fluid.optimizer.Ftrl(0.0001)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, FtrlOptimizer doesn't support sparse gradient.
"""
_squared_acc_str = "squared"
......
......@@ -19,33 +19,76 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
def create_selected_rows_and_tensor(scope, place, height, row_num,
embedding_size):
sr = scope.var("@selected_rows@").get_selected_rows()
tensor = scope.var("grad").get_tensor()
rows = np.random.random_integers(
low=0, high=height - 1, size=[row_num, ]).astype('int64')
sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32')
sr.set_height(height)
sr.set_rows(rows)
sr.get_tensor().set(sr_val, place)
tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
for i in range(row_num):
row = rows[i]
tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :]
tensor.set(tensor_val, place)
return tensor_val, sr_val
class TestBase(unittest.TestCase):
def setup(self, centered, epsilon=1e-6):
def setup(self,
place,
is_sparse,
centered,
size,
row_num=None,
epsilon=1e-6):
np.random.seed(5) # fix seed
self.scope = fluid.global_scope()
self.place = place
self.param_name = "param"
self.param = np.random.random((123, 321)).astype("float32")
self.param = np.random.random(size).astype("float32")
self.mean_square_name = "mean_square"
self.mean_square = np.random.random((123, 321)).astype("float32")
self.mean_square = np.random.uniform(
low=1, high=2, size=size).astype("float32")
self.mean_grad_name = "mean_grad"
self.mean_grad = np.random.random((123, 321)).astype("float32")
self.mean_grad = np.random.random(size).astype("float32")
self.lr_name = "lr"
self.learning_rate = np.array([0.01]).astype("float32")
self.grad_name = "grad"
self.grad = np.random.random((123, 321)).astype("float32")
self.is_sparse = is_sparse
if self.is_sparse:
self.grad_sr_name = "@selected_rows@"
self.grad, self.grad_sr = create_selected_rows_and_tensor(
self.scope, place, size[0], row_num, size[1])
else:
self.grad = np.random.random(size).astype("float32")
grad_tensor = self.scope.var(self.grad_name).get_tensor()
grad_tensor.set(self.grad, place)
self.moment_name = "moment"
self.moment = np.zeros((123, 321)).astype("float32")
self.moment = np.random.uniform(
low=0, high=1, size=size).astype("float32")
self.epsilon = epsilon
self.decay = 0.9
self.momentum = 0.0
self.momentum = 0.1
self.centered = centered
self.ms_out = self.decay * self.mean_square + (1 - self.decay
......@@ -61,118 +104,122 @@ class TestBase(unittest.TestCase):
self.param_out = self.param - self.moment_out
def check(self,
actual_t,
expect_t,
place,
out_name,
atol=1e-5,
equal_nan=False):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
class TestRmspropOp(TestBase):
def check_with_place(self, place, centered, epsilon):
self.setup(centered, epsilon)
scope = core.Scope()
# create and initialize Param Variable
param = scope.var(self.param_name).get_tensor()
param.set(self.param, place)
self.param_tensor = self.scope.var(self.param_name).get_tensor()
self.param_tensor.set(self.param, place)
mean_square = scope.var(self.mean_square_name).get_tensor()
mean_square.set(self.mean_square, place)
self.mean_square_tensor = self.scope.var(
self.mean_square_name).get_tensor()
self.mean_square_tensor.set(self.mean_square, place)
lr = scope.var(self.lr_name).get_tensor()
lr = self.scope.var(self.lr_name).get_tensor()
lr.set(self.learning_rate, place)
grad = scope.var(self.grad_name).get_tensor()
grad.set(self.grad, place)
self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
self.moment_tensor.set(self.moment, place)
moment = scope.var(self.moment_name).get_tensor()
moment.set(self.moment, place)
if self.centered:
self.mean_grad_tensor = self.scope.var(
self.mean_grad_name).get_tensor()
self.mean_grad_tensor.set(self.mean_grad, place)
# create and run sgd operator
def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
if self.centered:
mean_grad = scope.var(self.mean_grad_name).get_tensor()
mean_grad.set(self.mean_grad, place)
rmsprop_op = Operator(
"rmsprop",
Param=self.param_name,
Grad=self.grad_name,
MeanSquare=self.mean_square_name,
MeanGrad=self.mean_grad_name,
Moment=self.moment_name,
LearningRate=self.lr_name,
ParamOut=self.param_name,
MeanSquareOut=self.mean_square_name,
MomentOut=self.moment_name,
MeanGradOut=self.mean_grad_name,
epsilon=self.epsilon,
decay=self.decay,
momentum=self.momentum,
centered=True)
else:
rmsprop_op = Operator(
"rmsprop",
Param=self.param_name,
Grad=self.grad_name,
MeanSquare=self.mean_square_name,
Moment=self.moment_name,
LearningRate=self.lr_name,
ParamOut=self.param_name,
MeanSquareOut=self.mean_square_name,
MomentOut=self.moment_name,
epsilon=self.epsilon,
decay=self.decay,
momentum=self.momentum,
centered=False)
rmsprop_op.run(scope, place)
atol = 1e-5
equal_nan = False
class TestRmspropOp(TestBase):
def check_with_place(self,
place,
is_sparse,
centered,
size,
row_num=None,
epsilon=1e-6):
self.setup(place, is_sparse, centered, size, row_num, epsilon)
self.run_and_check()
def run_and_check(self):
grad_name = self.grad_sr_name if self.is_sparse else self.grad_name
kwargs = {
'Param': self.param_name,
'Grad': grad_name,
'MeanSquare': self.mean_square_name,
'Moment': self.moment_name,
'LearningRate': self.lr_name,
'ParamOut': self.param_name,
'MeanSquareOut': self.mean_square_name,
'MomentOut': self.moment_name,
'epsilon': self.epsilon,
'decay': self.decay,
'momentum': self.momentum,
'centered': self.centered
}
if self.centered:
atol = 1e-3
equal_nan = True
kwargs['MeanGrad'] = self.mean_grad_name
kwargs['MeanGradOut'] = self.mean_grad_name
rmsprop_op = Operator('rmsprop', **kwargs)
atol = 1e-6
rmsprop_op.run(self.scope, self.place)
self.check(
np.array(mean_square), self.ms_out, place, self.mean_square_name)
np.array(self.mean_square_tensor),
self.ms_out,
self.place,
self.mean_square_name,
atol=atol)
self.check(
np.array(moment),
np.array(self.moment_tensor),
self.moment_out,
place,
self.place,
self.moment_name,
atol=atol,
equal_nan=equal_nan)
atol=atol)
self.check(
np.array(param),
np.array(self.param_tensor),
self.param_out,
place,
self.place,
self.param_name,
atol=atol,
equal_nan=equal_nan)
atol=atol)
if self.centered:
self.check(
np.array(mean_grad), self.mg_out, place, self.mean_grad_name)
np.array(self.mean_grad_tensor), self.mg_out, self.place,
self.mean_grad_name)
def test_rmsprop(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
size = (128, 320)
for place in places:
self.check_with_place(place, False, 1e-6)
self.check_with_place(place, False, 1e-10)
self.check_with_place(place, True, 1e-6)
self.check_with_place(place, True, 1e-10)
for centered in [False, True]:
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place, is_sparse=False, centered=centered, size=size)
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place,
is_sparse=True,
centered=centered,
row_num=512,
size=size)
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place,
is_sparse=True,
centered=centered,
row_num=60,
size=size)
if __name__ == "__main__":
......