Commit 4cb5bd90 authored by Abhinav Arora, committed by GitHub

Implementing the Adamax optimizer operator (#4538)

* Implementing the Adamax optimizer step operator
* Adding unit tests for adamax_op

* Changing learning rate and time step to inputs from attributes

* Changing learning rate and time step to input(tensors)

* Making the Adamax operator conform to naming convention

* Removing Tensor<float> from comments

* Rectifying the Adamax implementation

* Changing Unit Test values and adding comments

* Changing Unit Test to test multiple steps
Parent f30a1f42
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/adamax_op.h"
namespace paddle {
namespace operators {
class AdamaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InfNorm"),
"Input(InfNorm) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(MomentOut) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("InfNormOut"),
"Output(InfNormOut) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"),
"Output(Beta1PowOut) of AdamaxOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamaxOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment"),
"Param and Moment input of AdamaxOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("InfNorm"),
"Param and InfNorm input of AdamaxOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
ctx->SetOutputDim("InfNormOut", param_dims);
ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
}
};
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
AddInput("Moment", "(Tensor) First moment");
AddInput("InfNorm",
"(Tensor) "
"Input exponentially weighted infinity norm");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("MomentOut", "(Tensor) Output first moment");
AddOutput("InfNormOut",
"(Tensor) "
"Output exponentially weighted infinity norm");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddAttr<float>("beta1",
"(float, default 0.9) "
"Exponential decay rate for the "
"1st moment estimates.")
.SetDefault(0.9f);
AddAttr<float>("beta2",
"(float, default 0.999) "
"exponential decay rate for the weighted "
"infinity norm estimates.")
.SetDefault(0.999f);
AddAttr<float>("epsilon",
"(float, default 1.0e-8) "
"Constant for numerical stability")
.SetDefault(1.0e-8f);
AddComment(R"DOC(
Adamax Updates Operator.
This implements the Adamax optimizer from Section 7 of the Adam
paper[1]. Adamax is a variant of the
Adam algorithm based on the infinity norm.
Adamax updates:
moment_out = beta1 * moment + (1 - beta1) * grad
inf_norm_out = max(beta2 * inf_norm + epsilon, abs(grad))
beta1_pow_out = beta1_pow * beta1
learning_rate_t = learning_rate/(1 - beta1_pow_out)
param_out = param - learning_rate_t * moment_out/inf_norm_out
The original paper does not have an epsilon attribute.
However, it is added here for numerical stability
to prevent division by zero.
References:
[1] Adam: A Method for Stochastic Optimization
(https://arxiv.org/abs/1412.6980)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
REGISTER_OP_CPU_KERNEL(adamax,
ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>);
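The update equations documented in AdamaxOpMaker above replace Adam's squared-gradient second moment with an exponentially weighted infinity norm, and fold epsilon into the max rather than adding it to the denominator. A minimal NumPy sketch of a single update following those equations (values and shapes are illustrative only; the unit test at the end of this commit uses an equivalent adamax_step helper):

```python
import numpy as np

# Illustrative state; in the operator these arrive as input tensors.
param = np.array([0.5, -0.3], dtype=np.float32)
grad = np.array([0.1, 0.2], dtype=np.float32)
moment = np.zeros_like(param)      # first moment estimate
inf_norm = np.zeros_like(param)    # exponentially weighted infinity norm
lr, beta1, beta2, epsilon = 0.002, 0.9, 0.999, 1e-8
beta1_pow = beta1                  # beta1^t, accumulated across steps

moment_out = beta1 * moment + (1 - beta1) * grad
# Adam would use beta2 * v + (1 - beta2) * grad**2 here; Adamax keeps the
# infinity norm instead, with epsilon folded into the max.
inf_norm_out = np.maximum(beta2 * inf_norm + epsilon, np.abs(grad))
beta1_pow_out = beta1_pow * beta1
lr_t = lr / (1 - beta1_pow_out)
param_out = param - lr_t * moment_out / inf_norm_out
```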
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/adamax_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adamax,
ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");
auto beta1_pow_out_tensor = ctx.Output<framework::Tensor>("Beta1PowOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
inf_norm_out_tensor->mutable_data<T>(ctx.GetPlace());
beta1_pow_out_tensor->mutable_data<T>(ctx.GetPlace());
float beta1 = ctx.Attr<float>("beta1");
float beta2 = ctx.Attr<float>("beta2");
float epsilon = ctx.Attr<float>("epsilon");
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto inf_norm = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("InfNorm"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto beta1_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta1Pow"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto inf_norm_out =
framework::EigenVector<T>::Flatten(*inf_norm_out_tensor);
auto beta1_pow_out =
framework::EigenVector<T>::Flatten(*beta1_pow_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
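    // Update the biased first moment estimate from the new gradient.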
moment_out.device(place) = beta1 * moment + (1 - beta1) * grad;
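    // Update the exponentially weighted infinity norm; folding epsilon into
    // the max keeps it strictly positive, so the division below is safe.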
inf_norm_out.device(place) =
grad.abs().cwiseMax((beta2 * inf_norm) + epsilon);
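    // Accumulate beta1^t across steps for the bias correction below.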
beta1_pow_out.device(place) = beta1_pow * beta1;
auto lr_t = lr / (1 - beta1_pow_out);
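    // The learning rate tensor has a single element, so the bias-corrected
    // rate must be broadcast explicitly over every parameter element.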
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out);
}
};
} // namespace operators
} // namespace paddle
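The kernel flattens every tensor to a 1-D Eigen view, so the 1-element bias-corrected learning rate has to be broadcast explicitly to the parameter size via Eigen::DSizes. In NumPy the same shape arithmetic happens implicitly; a small sketch of that detail (shapes match the unit tests below and are otherwise arbitrary):

```python
import numpy as np

# LearningRate and Beta1Pow arrive as 1-element tensors (see the tests below).
lr = np.array([0.002], dtype=np.float32)
beta1_pow_out = np.array([0.9 ** 3], dtype=np.float32)
lr_t = lr / (1 - beta1_pow_out)                      # shape (1,)

moment_out = np.random.uniform(-1, 1, (102, 105)).astype("float32")
inf_norm_out = np.random.random((102, 105)).astype("float32") + 1e-8

# NumPy broadcasts the (1,)-shaped lr_t over the (102, 105) update on its
# own; the Eigen kernel requests the same thing with lr_t.broadcast(m_dsize),
# where m_dsize is MomentOut's element count.
update = lr_t * (moment_out / inf_norm_out)
assert update.shape == (102, 105)
```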
import unittest
import numpy as np
from op_test import OpTest
class TestAdamaxOp1(OpTest):
def setUp(self):
'''Test Adamax Operator with supplied attributes
'''
self.op_type = "adamax"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The infinity norm is positive
inf_norm = np.random.random((102, 105)).astype("float32")
learning_rate = 0.002
beta1 = 0.78
beta2 = 0.899
epsilon = 1e-5
beta1_pow = beta1**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'InfNorm': inf_norm,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32")
}
self.attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon}
param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step(
self.inputs, self.attrs)
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'InfNormOut': inf_norm_out,
'Beta1PowOut': beta1_pow_out
}
def test_check_output(self):
self.check_output()
class TestAdamaxOp2(OpTest):
'''Test Adamax Operator with default attributes
'''
def setUp(self):
self.op_type = "adamax"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The infinity norm is positive
inf_norm = np.random.random((102, 105)).astype("float32")
learning_rate = 0.002
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
beta1_pow = beta1**8
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'InfNorm': inf_norm,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32")
}
attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon}
param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step(
self.inputs, attrs)
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'InfNormOut': inf_norm_out,
'Beta1PowOut': beta1_pow_out
}
def test_check_output(self):
self.check_output()
class TestAdamaxOpMultipleSteps(OpTest):
def setUp(self):
'''Test Adamax Operator over multiple consecutive steps
'''
self.op_type = "adamax"
self.num_steps = 10
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The infinity norm is positive
inf_norm = np.random.random((102, 105)).astype("float32")
learning_rate = 0.002
beta1 = 0.8
beta2 = 0.99
epsilon = 1e-5
beta1_pow = 1
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'InfNorm': inf_norm,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32")
}
self.attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon}
param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step(
self.inputs, self.attrs)
def test_check_output(self):
for _ in range(self.num_steps):
param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step(
self.inputs, self.attrs)
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'InfNormOut': inf_norm_out,
'Beta1PowOut': beta1_pow_out
}
# Verify output for this step
self.check_output()
# Output of this step becomes input for next step
self.inputs['Param'] = param_out
self.inputs['Moment'] = moment_out
self.inputs['InfNorm'] = inf_norm_out
self.inputs['Beta1Pow'] = beta1_pow_out
# Randomize gradient for next step
self.inputs['Grad'] = np.random.uniform(
-1, 1, (102, 105)).astype("float32")
def adamax_step(inputs, attributes):
'''
Simulate one step of the adamax optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment, inf_norm and
beta1 power accumulator
'''
param = inputs['Param']
grad = inputs['Grad']
moment = inputs['Moment']
inf_norm = inputs['InfNorm']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta1 = attributes['beta1']
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
moment_out = beta1 * moment + (1 - beta1) * grad
inf_norm_out = np.maximum(beta2 * inf_norm + epsilon, np.abs(grad))
beta1_pow_out = beta1_pow * beta1
lr_t = (lr / (1 - beta1_pow_out))
param_out = param - lr_t * np.divide(moment_out, inf_norm_out)
return param_out, moment_out, inf_norm_out, beta1_pow_out
if __name__ == "__main__":
unittest.main()
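A standalone sketch (not part of the test file, reusing numpy and the adamax_step helper above) of why Beta1Pow is threaded through the operator as an input/output pair: chaining the step for a few iterations shows the bias-corrected rate learning_rate / (1 - beta1^t) decaying toward the raw learning rate as the accumulated power shrinks. All values below are illustrative:

```python
inputs = {
    'Param': np.zeros((4, ), dtype=np.float32),
    'Grad': np.full((4, ), 0.5, dtype=np.float32),
    'Moment': np.zeros((4, ), dtype=np.float32),
    'InfNorm': np.zeros((4, ), dtype=np.float32),
    'LearningRate': np.array([0.002], dtype=np.float32),
    'Beta1Pow': np.array([1.0], dtype=np.float32)
}
attrs = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': 1e-8}

for step in range(5):
    param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step(inputs,
                                                                     attrs)
    # Outputs become the next step's inputs, exactly as in
    # TestAdamaxOpMultipleSteps above.
    inputs['Param'] = param_out
    inputs['Moment'] = moment_out
    inputs['InfNorm'] = inf_norm_out
    inputs['Beta1Pow'] = beta1_pow_out
    print(step, float(inputs['LearningRate'] / (1 - beta1_pow_out)))
```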