Commit 324876bb authored by Abhinav Arora, committed by GitHub

Changing learning rate from type Input(float) to Input(tensor) (#4578)

Parent b884bc33
@@ -32,6 +32,9 @@ class SGDOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("param_out"),
                    "Output(param_out) of SGDOp should not be null.");
+    auto lr_dims = ctx->GetInputDim("learning_rate");
+    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
+                      "Learning rate should have 1 element");
     auto param_dim = ctx->GetInputDim("param");
     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"),
                       "Two input of SGD Op's dimension must be same.");
......
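The added check only constrains the number of elements in the learning_rate input, not its exact shape: any tensor whose dimensions multiply to 1 is accepted. A rough Python sketch of the same rule (the helper name and the numpy-based check are illustrative only, not part of the operator):

import numpy as np

def check_learning_rate(lr):
    # Mirrors PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, ...):
    # the tensor may have any rank as long as it holds exactly one element.
    if np.prod(lr.shape) != 1:
        raise ValueError("Learning rate should have 1 element")

check_learning_rate(np.array([0.1], dtype="float32"))    # shape (1,)  -> ok
check_learning_rate(np.array([[0.1]], dtype="float32"))  # shape (1,1) -> ok
# check_learning_rate(np.zeros((2,), dtype="float32"))   # would raise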
@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
     auto param = ctx.Input<Tensor>("param");
     auto grad = ctx.Input<Tensor>("grad");
     auto param_out = ctx.Output<Tensor>("param_out");
-    float lr = *ctx.Input<float>("learning_rate");
+    float lr = ctx.Input<Tensor>("learning_rate")->data<float>()[0];
     param_out->mutable_data<T>(ctx.GetPlace());
......
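On the kernel side, the scalar is now read out of the one-element tensor rather than taken directly from a float input. A minimal NumPy sketch of the resulting update, matching what the test below verifies (the function name here is illustrative, not the actual kernel):

import numpy as np

def sgd_update(param, grad, lr_tensor):
    # Equivalent to data<float>()[0]: pull the single float out of the
    # 1-element learning-rate tensor, then apply the plain SGD rule.
    lr = float(lr_tensor.reshape(-1)[0])
    return param - lr * grad

w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
lr = np.array([0.1]).astype("float32")
param_out = sgd_update(w, g, lr)
assert np.allclose(param_out, w - lr * g)  # same result as broadcasting the 1-element array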
@@ -8,7 +8,7 @@ class TestSGDOp(OpTest):
         self.op_type = "sgd"
         w = np.random.random((102, 105)).astype("float32")
         g = np.random.random((102, 105)).astype("float32")
-        lr = 0.1
+        lr = np.array([0.1]).astype("float32")
         self.inputs = {'param': w, 'grad': g, 'learning_rate': lr}
         self.outputs = {'param_out': w - lr * g}
......