提交 11680037 编写于 作者: A Abhinav Arora 提交者: GitHub

Adding the Adam Optimizer operator (#4733)

* add adam op

moment1_out = beta1 * moment1 + (1 − beta1) * grad
moment2_out = beta2 * moment2 + (1 − beta2) * grad * grad
moment1_hat =  moment1_out / (1 - beta1^t)
moment2_hat =  moment2_out / (1 - beta2^t)
param_out = param - learning_rate * moment1_hat / (sqrt(moment2_hat) +
epsilon)

* fix moment 2

* Adding the Adam optimization operator

* Adding more tests for Adam op
上级 74609581
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/adam_op.h"
namespace paddle {
namespace operators {
class AdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment1"),
"Input(Moment1) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment2"),
"Input(Moment2) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
"Input(Beta2Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
"Output(Moment1Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
"Output(Moment2Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"),
"Output(Beta1PowOut) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Beta2PowOut"),
"Output(Beta2PowOut) of AdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment input of AdamOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"),
"Param and InfNorm input of AdamOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
}
};
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
AddInput("Moment1", "(Tensor) Input first moment");
AddInput("Moment2", "(Tensor) Input second moment");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddAttr<float>("beta1",
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates.")
.SetDefault(0.9f);
AddAttr<float>("beta2",
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates.")
.SetDefault(0.999f);
AddAttr<float>("epsilon",
"(float, default 1.0e-8) "
"Constant for numerical stability")
.SetDefault(1.0e-8f);
AddComment(R"DOC(
Adam Updates Operator.
This implements the Adam optimizer from Section 2 of the Adam
paper[1]. Adam is a first-order gradient-based optimization
method based on adaptive estimates of lower-order moments.
Adam updates:
moment1_out = beta1 * moment1 + (1 − beta1) * grad
moment2_out = beta2 * moment2 + (1 − beta2) * grad * grad
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
learning_rate_t = learning_rate_t *
sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out)
param_out = param - learning_rate_t * moment1/ (sqrt(moment2) + epsilon)
References:
[1] Adam: A Method for Stochastic Optimization
(https://arxiv.org/abs/1412.6980)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(adam,
ops::AdamOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/adam_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adam,
ops::AdamOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment1_out_tensor = ctx.Output<framework::Tensor>("Moment1Out");
auto moment2_out_tensor = ctx.Output<framework::Tensor>("Moment2Out");
auto beta1_pow_out_tensor = ctx.Output<framework::Tensor>("Beta1PowOut");
auto beta2_pow_out_tensor = ctx.Output<framework::Tensor>("Beta2PowOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
beta1_pow_out_tensor->mutable_data<T>(ctx.GetPlace());
beta2_pow_out_tensor->mutable_data<T>(ctx.GetPlace());
float beta1 = ctx.Attr<float>("beta1");
float beta2 = ctx.Attr<float>("beta2");
float epsilon = ctx.Attr<float>("epsilon");
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment1 = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment1"));
auto moment2 = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment2"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto beta1_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta1Pow"));
auto beta2_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta2Pow"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
auto beta1_pow_out =
framework::EigenVector<T>::Flatten(*beta1_pow_out_tensor);
auto beta2_pow_out =
framework::EigenVector<T>::Flatten(*beta2_pow_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad;
moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square();
beta1_pow_out.device(place) = beta1_pow * beta1;
beta2_pow_out.device(place) = beta2_pow * beta2;
// All of these are tensors of 1 element
auto lr_t = lr * (1 - beta2_pow_out).sqrt() / (1 - beta1_pow_out);
// Eigen does not support automatic broadcast
// Get dimensions of moment vector to broadcast lr_t
Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
param_out.device(place) =
param -
lr_t.broadcast(m_dsize) *
(moment1_out / (moment2_out.sqrt() + epsilon));
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestAdamOp1(OpTest):
def setUp(self):
'''Test Adam Op with supplied attributes
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.004
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32")
}
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
param_out, moment1_out, moment2_out, beta1_pow_out, \
beta2_pow_out = adam_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out,
'ParamOut': param_out
}
def test_check_output(self):
self.check_output()
class TestAdamOp2(OpTest):
def setUp(self):
'''Test Adam Op with supplied attributes
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32")
}
attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
param_out, moment1_out, moment2_out, beta1_pow_out, \
beta2_pow_out = adam_step(self.inputs, attributes)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out,
'ParamOut': param_out
}
def test_check_output(self):
self.check_output()
class TestAdamOpMultipleSteps(OpTest):
def setUp(self):
'''Test Adam Operator with supplied attributes
'''
self.op_type = "adam"
self.num_steps = 10
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32")
}
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
def test_check_output(self):
for _ in range(self.num_steps):
param_out, moment1_out, moment2_out, beta1_pow_out, \
beta2_pow_out = adam_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out,
'ParamOut': param_out
}
# Verify output for this step
self.check_output()
# Output of this step becomes input for next step
self.inputs['Param'] = param_out
self.inputs['Moment1'] = moment1_out
self.inputs['Moment2'] = moment2_out
self.inputs['Beta1Pow'] = beta1_pow_out
self.inputs['Beta2Pow'] = beta2_pow_out
# Randomize gradient for next step
self.inputs['Grad'] = np.random.uniform(
-1, 1, (102, 105)).astype("float32")
def adam_step(inputs, attributes):
'''
Simulate one step of the adam optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment1, moment2,
beta1 power accumulator and beta2 power accumulator
'''
param = inputs['Param']
grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
beta1 = attributes['beta1']
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
lr_t = lr * np.sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out)
param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册