diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index a84262e0075aa4e29197acb4b10fbf3c5712af0d..c4ede7d2fba38036cad488b7235d57fb950ec755 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -149,7 +149,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
          ++output_idx) {
       auto insert_add_x = dup_outputs[output_idx];
-      auto insert_add_y = dup_outputs[output_idx];
+      auto insert_add_y = dup_outputs[output_idx + 1];
       auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
       // first add op inserted
       if (output_idx == dup_outputs.size() - 2) {
@@ -160,9 +160,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       }
       insert_position.push_back(
           {dup_op.back(),
-           OpRegistry::CreateOp(
-               "sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
-               {{"Out", {insert_add_out}}}, {})});
+           OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
+                                {{"Out", {insert_add_out}}}, {})});
     }
   }
 
@@ -202,7 +201,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
 
   // process recurrent gradient op as a special operator.
   if (forwardOp.Type() == "recurrent") {
-    // NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
+    // NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
+    // or
     // this will result in infinite loop.
     const auto& rnnop = *static_cast<const operators::RecurrentOp*>(&forwardOp);
diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
index 8f9eae4186ad848fcecd74b4ab22711f8bb99e2a..31d491f130e36289f7dd3dc18710d35d5ba6ab34 100644
--- a/paddle/operators/sgd_op.cc
+++ b/paddle/operators/sgd_op.cc
@@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContextBase *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("param"),
-                   "Input(param) of SGDOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("grad"),
-                   "Input(grad) of SGDOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
-                   "Input(learning_rate) of SGDOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("param_out"),
-                   "Output(param_out) of SGDOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Param"),
+                   "Input(Param) of SGDOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grad"),
+                   "Input(Grad) of SGDOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of SGDOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
+                   "Output(ParamOut) of SGDOp should not be null.");
 
-    auto param_dim = ctx->GetInputDim("param");
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"),
+    auto lr_dims = ctx->GetInputDim("LearningRate");
+    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
+                      "Learning rate should have 1 element");
+    auto param_dim = ctx->GetInputDim("Param");
+    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
                       "Two input of SGD Op's dimension must be same.");
-    ctx->SetOutputDim("param_out", param_dim);
+    ctx->SetOutputDim("ParamOut", param_dim);
   }
 };
 
@@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("param", "input parameter");
-    AddInput("learning_rate", "learning rate of sgd");
-    AddInput("grad", "input gradient");
-    AddOutput("param_out", "output parameter");
+    AddInput("Param", "Input parameter");
+    AddInput("LearningRate", "Learning rate of SGD");
+    AddInput("Grad", "Input gradient");
+    AddOutput("ParamOut", "output parameter");
 
     AddComment(R"DOC(
 Simplest sgd algorithm.
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index 977d201ced31c498c2ab41cf6d412756cabb3aee..d72d333a9a03df17b34356090ca420cbe75fa58f 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -28,10 +28,10 @@ template <typename Place, typename T>
 class SGDOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param = ctx.Input<Tensor>("param");
-    auto grad = ctx.Input<Tensor>("grad");
-    auto param_out = ctx.Output<Tensor>("param_out");
-    float lr = *ctx.Input<float>("learning_rate");
+    auto param = ctx.Input<Tensor>("Param");
+    auto grad = ctx.Input<Tensor>("Grad");
+    auto param_out = ctx.Output<Tensor>("ParamOut");
+    float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
 
     param_out->mutable_data<T>(ctx.GetPlace());
 
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index f1125f4edb5248abb2a0128a7a8b8b3647ed3317..2dd881e5e107249277a91bd8e3a72567269e1cd4 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -8,10 +8,10 @@ class TestSGDOp(OpTest):
         self.op_type = "sgd"
         w = np.random.random((102, 105)).astype("float32")
         g = np.random.random((102, 105)).astype("float32")
-        lr = 0.1
+        lr = np.array([0.1]).astype("float32")
 
-        self.inputs = {'param': w, 'grad': g, 'learning_rate': lr}
-        self.outputs = {'param_out': w - lr * g}
+        self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
+        self.outputs = {'ParamOut': w - lr * g}
 
     def test_check_output(self):
         self.check_output()