Commit d1de7ec6 authored by Kexin Zhao

Change learning rate from attribute to input tensor

Parent 05cbd4da
@@ -29,12 +29,17 @@ class AdagradOp : public framework::OperatorWithKernel {
                    "Input(grad) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("moment"),
                    "Input(moment) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
+                   "Input(learning_rate) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("param_out"),
                    "Output(param_out) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("moment_out"),
                    "Output(moment_out) of AdagradOp should not be null.");
+    auto lr_dims = ctx->GetInputDim("learning_rate");
+    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
+                      "learning_rate should have one element");
     auto param_dim = ctx->GetInputDim("param");
     PADDLE_ENFORCE_EQ(
         param_dim, ctx->GetInputDim("grad"),
@@ -56,11 +61,11 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("param", "Input parameter");
     AddInput("grad", "Input gradient");
     AddInput("moment", "Second moment");
+    AddInput("learning_rate", "learning rate of adagrad");
     AddOutput("param_out", "Output parameter");
     AddOutput("moment_out", "Output second moment");
-    AddAttr<float>("learning_rate", "Learning rate");
     AddAttr<float>("epsilon", "Constant for numerical stability");
     AddComment(R"DOC(
...
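With these changes, AdagradOp takes learning_rate as a fourth input tensor that must hold exactly one element (the new InferShape check enforces framework::product(lr_dims) == 1), while epsilon remains a float attribute. A minimal NumPy analogue of that shape check, with illustrative names only:

import numpy as np

def lr_shape_ok(lr):
    # Mirrors the new InferShape rule: the learning_rate tensor
    # must contain exactly one element, whatever its rank.
    return int(np.prod(lr.shape)) == 1

print(lr_shape_ok(np.array([0.01], dtype="float32")))        # True, shape (1,)
print(lr_shape_ok(np.full((1, 1), 0.01, dtype="float32")))   # True, still one element
print(lr_shape_ok(np.array([0.01, 0.02], dtype="float32")))  # False, two elements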
@@ -20,6 +20,11 @@ namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -34,12 +39,14 @@ class AdagradOpKernel : public framework::OpKernel<T> {
     param_out->mutable_data<T>(ctx.GetPlace());
     moment_out->mutable_data<T>(ctx.GetPlace());
-    float lr = ctx.Attr<float>("learning_rate");
+    float lr = ctx.Input<Tensor>("learning_rate")->data<float>()[0];
     float epsilon = ctx.Attr<float>("epsilon");
     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("param"));
     auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("grad"));
     auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("moment"));
+    auto lr = EigenScalar<T>::From(*ctx.Input<Tensor>("learning_rate"));
     auto p_out = EigenVector<T>::Flatten(*param_out);
     auto m_out = EigenVector<T>::Flatten(*moment_out);
     auto place = ctx.GetEigenDevice<Place>();
...
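In the kernel, the scalar learning rate is now read from the learning_rate input tensor rather than a float attribute. For reference, the update the kernel computes matches the expected values built in the Python test below; a minimal NumPy sketch of that update (illustrative names, not the operator's C++ API):

import numpy as np

def adagrad_update(param, grad, moment, lr, epsilon):
    # Accumulate squared gradients, then scale the step by the
    # one-element learning_rate array and the running accumulator.
    moment_out = moment + grad * grad
    param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
    return param_out, moment_out

param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
p_out, m_out = adagrad_update(param, grad, moment,
                              np.array([0.01], dtype="float32"), 1e-6)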
@@ -11,7 +11,7 @@ class TestAdagradOp(OpTest):
         grad = np.random.random((123, 321)).astype("float32")
         moment = np.zeros((123, 321)).astype("float32")
-        learning_rate = 0.01
+        lr = np.array([0.01]).astype("float32")
         epsilon = 1e-6
         self.inputs = {'param': param, 'grad': grad, 'moment': moment}
@@ -19,8 +19,7 @@ class TestAdagradOp(OpTest):
         self.attrs = {'learning_rate': learning_rate, 'epsilon': epsilon}
         moment_out = moment + grad * grad
-        param_out = param - learning_rate * grad / (np.sqrt(moment_out) +
-                                                    epsilon)
+        param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
         self.outputs = {'param_out': param_out, 'moment_out': moment_out}
...
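The visible test hunks only rename the Python variable and reflow the expected-value expression; for the op to actually consume learning_rate as an input, the test's inputs/attrs dictionaries would also need to move it out of attrs. A plausible setUp along those lines (a sketch under that assumption, not part of this diff):

import numpy as np

param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
lr = np.array([0.01]).astype("float32")   # one-element learning_rate tensor
epsilon = 1e-6

# learning_rate joins the inputs; only epsilon remains an attribute.
inputs = {'param': param, 'grad': grad, 'moment': moment, 'learning_rate': lr}
attrs = {'epsilon': epsilon}

moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
outputs = {'param_out': param_out, 'moment_out': moment_out}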