diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc
index 03e22cc600bc1805c0851a1ad195dde76957aa52..56a5fbcb86524174a8b5fc993e61bb9a5219969e 100644
--- a/paddle/operators/adagrad_op.cc
+++ b/paddle/operators/adagrad_op.cc
@@ -29,12 +29,17 @@ class AdagradOp : public framework::OperatorWithKernel {
                    "Input(grad) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("moment"),
                    "Input(moment) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
+                   "Input(learning_rate) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("param_out"),
                    "Output(param_out) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("moment_out"),
                    "Output(moment_out) of AdagradOp should not be null.");
 
+    auto lr_dims = ctx->GetInputDim("learning_rate");
+    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
+                      "learning_rate should have one element");
     auto param_dim = ctx->GetInputDim("param");
     PADDLE_ENFORCE_EQ(
         param_dim, ctx->GetInputDim("grad"),
         "Param and grad input of AdagradOp should have the same dimension.");
@@ -56,11 +61,11 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("param", "Input parameter");
     AddInput("grad", "Input gradient");
     AddInput("moment", "Second moment");
+    AddInput("learning_rate", "learning rate of adagrad");
     AddOutput("param_out", "Output parameter");
     AddOutput("moment_out", "Output second moment");
 
-    AddAttr<float>("learning_rate", "Learning rate");
     AddAttr<float>("epsilon", "Constant for numerical stability");
     AddComment(R"DOC(
 
 Adaptive Gradient Algorithm (Adagrad).
diff --git a/paddle/operators/adagrad_op.h b/paddle/operators/adagrad_op.h
index ca1836c3faf077ad26976b4c472863cdb0935af0..73833d4a3fd05d2aa7f4329481bc098218472795 100644
--- a/paddle/operators/adagrad_op.h
+++ b/paddle/operators/adagrad_op.h
@@ -20,6 +20,11 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
+
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -34,12 +39,13 @@ class AdagradOpKernel : public framework::OpKernel {
     param_out->mutable_data<T>(ctx.GetPlace());
     moment_out->mutable_data<T>(ctx.GetPlace());
 
-    float lr = ctx.Attr<float>("learning_rate");
     float epsilon = ctx.Attr<float>("epsilon");
 
     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("param"));
     auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("grad"));
     auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("moment"));
+    auto lr = EigenScalar<T>::From(*ctx.Input<Tensor>("learning_rate"));
+
     auto p_out = EigenVector<T>::Flatten(*param_out);
     auto m_out = EigenVector<T>::Flatten(*moment_out);
     auto place = ctx.GetEigenDevice<Place>();
diff --git a/python/paddle/v2/framework/tests/test_adagrad_op.py b/python/paddle/v2/framework/tests/test_adagrad_op.py
index b3f8b812e1db81e2d46530625d65d12e8c1af84a..2ee38ea37c9df61c34ea9e7fe23b47ca60577b09 100644
--- a/python/paddle/v2/framework/tests/test_adagrad_op.py
+++ b/python/paddle/v2/framework/tests/test_adagrad_op.py
@@ -11,7 +11,8 @@ class TestAdagradOp(OpTest):
         grad = np.random.random((123, 321)).astype("float32")
         moment = np.zeros((123, 321)).astype("float32")
 
-        learning_rate = 0.01
+        lr = np.array([0.01]).astype("float32")
         epsilon = 1e-6
 
-        self.inputs = {'param': param, 'grad': grad, 'moment': moment}
+        self.inputs = {'param': param, 'grad': grad, 'moment': moment,
+                       'learning_rate': lr}
@@ -19,8 +20,7 @@ class TestAdagradOp(OpTest):
-        self.attrs = {'learning_rate': learning_rate, 'epsilon': epsilon}
+        self.attrs = {'epsilon': epsilon}
 
         moment_out = moment + grad * grad
-        param_out = param - learning_rate * grad / (np.sqrt(moment_out) +
-                                                    epsilon)
+        param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
 
         self.outputs = {'param_out': param_out, 'moment_out': moment_out}
 
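
Note: for reference, here is the update rule this patch exercises, as a standalone NumPy sketch. It mirrors the math in test_adagrad_op.py; the variable names follow the operator's input names, and the sketch is illustrative, not part of the patch itself.

    import numpy as np

    # Shapes and dtypes match those used in test_adagrad_op.py.
    param = np.random.random((123, 321)).astype("float32")
    grad = np.random.random((123, 321)).astype("float32")
    moment = np.zeros((123, 321)).astype("float32")
    # The learning rate is now a one-element tensor input, not a float attribute.
    lr = np.array([0.01]).astype("float32")
    epsilon = 1e-6  # epsilon stays an attribute, for numerical stability

    # Adagrad: accumulate squared gradients, then scale each step by
    # the inverse square root of the accumulator.
    moment_out = moment + grad * grad
    param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)

Because lr has shape (1,), NumPy broadcasting applies the same scalar step size to every parameter element, which corresponds to the EigenScalar-backed learning rate read by the C++ kernel.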