Commit 42e7fe05 authored by Abhinav Arora, committed by GitHub

Changing learning rate from attribute to input(float) (#4568)

* Changing learning rate from attribute to input(float)
* Removing obsolete code
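
In short: the learning rate is no longer frozen into the operator definition as a float attribute; it is read at run time from a float variable in the scope, so it can change between iterations without rebuilding the operator. A minimal plain-numpy sketch of that semantics (hypothetical names, no Paddle dependency):

```python
import numpy as np

# Scope modeled as a dict; "learning_rate" is now just another variable.
scope = {
    "param": np.ones((2, 2), dtype=np.float32),
    "grad": np.full((2, 2), 0.5, dtype=np.float32),
    "learning_rate": 0.1,
}

def run_sgd(scope):
    # Mirrors SGDOpKernel: param_out = param - lr * grad
    scope["param"] -= scope["learning_rate"] * scope["grad"]

run_sgd(scope)
scope["learning_rate"] = 0.01  # e.g. a decay schedule; no op rebuild needed
run_sgd(scope)
```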
Parent: fddaf0c4
@@ -27,6 +27,8 @@ class SGDOp : public framework::OperatorWithKernel {
                    "Input(param) of SGDOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("grad"),
                    "Input(grad) of SGDOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
+                   "Input(learning_rate) of SGDOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("param_out"),
                    "Output(param_out) of SGDOp should not be null.");
@@ -42,9 +44,9 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
   SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("param", "input parameter");
+    AddInput("learning_rate", "learning rate of sgd");
     AddInput("grad", "input gradient");
     AddOutput("param_out", "output parameter");
-    AddAttr<float>("learning_rate", "learning rate of sgd");
     AddComment(R"DOC(
 Simplest sgd algorithm.
...
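With `learning_rate` declared via `AddInput` instead of `AddAttr<float>`, op construction binds it to a variable name rather than a literal. A hedged sketch, assuming the `Operator` helper from `paddle.v2.framework.op` of this era:

```python
from paddle.v2.framework.op import Operator

# learning_rate now names a scope variable instead of carrying a float.
sgd_op = Operator(
    "sgd",
    param="param",
    grad="grad",
    learning_rate="learning_rate",  # variable name, resolved at run time
    param_out="param_out")
```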
@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
     auto param = ctx.Input<Tensor>("param");
     auto grad = ctx.Input<Tensor>("grad");
     auto param_out = ctx.Output<Tensor>("param_out");
-    float lr = ctx.Attr<float>("learning_rate");
+    float lr = *ctx.Input<float>("learning_rate");
     param_out->mutable_data<T>(ctx.GetPlace());
...
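Note the dereference: `Input<float>` returns a pointer to the raw float held by the input variable (not a tensor), so the kernel reads the scalar with `*`. The computation itself is unchanged; in numpy terms:

```python
import numpy as np

def sgd_kernel(param, grad, lr):
    # numpy equivalent of SGDOpKernel::Compute: param_out = param - lr * grad
    return param - lr * grad

w = np.random.random((4, 3)).astype("float32")
g = np.random.random((4, 3)).astype("float32")
np.testing.assert_allclose(sgd_kernel(w, g, 0.1), w - 0.1 * g)
```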
@@ -143,6 +143,13 @@ All parameter, weight, gradient are variables in Paddle.
       .def("set_int",
            [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
       .def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
+      .def("is_float", [](const Variable &var) { return var.IsType<float>(); })
+      .def("set_float",
+           [](Variable &var, float val) -> void {
+             *var.GetMutable<float>() = val;
+           })
+      .def("get_float",
+           [](const Variable &var) -> float { return var.Get<float>(); })
       .def("get_tensor",
            [](Variable &self) -> LoDTensor * {
              return self.GetMutable<LoDTensor>();
...
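The new bindings let Python code store a bare float in a `Variable`, matching what the kernel now reads. A hedged usage sketch (the module path and `scope.new_var` are assumptions based on the tests of this era):

```python
import paddle.v2.framework.core as core

scope = core.Scope()
lr = scope.new_var("learning_rate")
lr.set_float(0.5)             # *var.GetMutable<float>() = 0.5
assert lr.is_float()          # var.IsType<float>()
assert lr.get_float() == 0.5  # var.Get<float>()
```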
@@ -46,12 +46,17 @@ def create_op(scope, op_type, inputs, outputs, attrs):
 def set_input(scope, op, inputs, place):
     def __set_input__(var_name, var):
+        if isinstance(var, tuple) or isinstance(var, np.ndarray):
             tensor = scope.find_var(var_name).get_tensor()
             if isinstance(var, tuple):
                 tensor.set_lod(var[1])
                 var = var[0]
             tensor.set_dims(var.shape)
             tensor.set(var, place)
+        elif isinstance(var, float):
+            scope.find_var(var_name).set_float(var)
+        elif isinstance(var, int):
+            scope.find_var(var_name).set_int(var)

     for in_name, in_dup in Operator.get_op_inputs(op.type()):
         if in_name in inputs:
...
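`__set_input__` now dispatches on the value's type: arrays and `(array, lod)` tuples still go to a tensor, while bare floats and ints are routed to the new scalar setters. A plain-Python illustration of the dispatch (no Paddle dependency):

```python
import numpy as np

def route(var):
    # Same ordering as __set_input__: tensor-like first, then scalars.
    if isinstance(var, (tuple, np.ndarray)):
        return "tensor"
    elif isinstance(var, float):
        return "set_float"
    elif isinstance(var, int):
        return "set_int"

assert route(np.zeros((2, 2))) == "tensor"
assert route(0.1) == "set_float"
assert route(3) == "set_int"
```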
@@ -10,8 +10,7 @@ class TestSGDOp(OpTest):
         g = np.random.random((102, 105)).astype("float32")
         lr = 0.1
-        self.inputs = {'param': w, 'grad': g}
-        self.attrs = {'learning_rate': lr}
+        self.inputs = {'param': w, 'grad': g, 'learning_rate': lr}
         self.outputs = {'param_out': w - lr * g}

     def test_check_output(self):
...