diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 8f9eae4186ad848fcecd74b4ab22711f8bb99e2a..1a4d3fb8c57a5c1871b7ce51360509fb07da90fe 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -32,6 +32,9 @@ class SGDOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("param_out"), "Output(param_out) of SGDOp should not be null."); + auto lr_dims = ctx->GetInputDim("learning_rate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 element"); auto param_dim = ctx->GetInputDim("param"); PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"), "Two input of SGD Op's dimension must be same."); diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 977d201ced31c498c2ab41cf6d412756cabb3aee..e2ae65beb0910a366136bc5f17c92992ae753a01 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel { auto param = ctx.Input("param"); auto grad = ctx.Input("grad"); auto param_out = ctx.Output("param_out"); - float lr = *ctx.Input("learning_rate"); + float lr = ctx.Input("learning_rate")->data()[0]; param_out->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py index f1125f4edb5248abb2a0128a7a8b8b3647ed3317..c05364490f0de545f2bf144784d1bc1c9aaa94d2 100644 --- a/python/paddle/v2/framework/tests/test_sgd_op.py +++ b/python/paddle/v2/framework/tests/test_sgd_op.py @@ -8,7 +8,7 @@ class TestSGDOp(OpTest): self.op_type = "sgd" w = np.random.random((102, 105)).astype("float32") g = np.random.random((102, 105)).astype("float32") - lr = 0.1 + lr = np.array([0.1]).astype("float32") self.inputs = {'param': w, 'grad': g, 'learning_rate': lr} self.outputs = {'param_out': w - lr * g}