提交 c61e82bc 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_backward_for_op_desc

...@@ -149,7 +149,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -149,7 +149,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1; for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
++output_idx) { ++output_idx) {
auto insert_add_x = dup_outputs[output_idx]; auto insert_add_x = dup_outputs[output_idx];
auto insert_add_y = dup_outputs[output_idx]; auto insert_add_y = dup_outputs[output_idx + 1];
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx); auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
// first add op inserted // first add op inserted
if (output_idx == dup_outputs.size() - 2) { if (output_idx == dup_outputs.size() - 2) {
...@@ -160,9 +160,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -160,9 +160,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
} }
insert_position.push_back( insert_position.push_back(
{dup_op.back(), {dup_op.back(),
OpRegistry::CreateOp( OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
"sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}}, {{"Out", {insert_add_out}}}, {})});
{{"Out", {insert_add_out}}}, {})});
} }
} }
...@@ -202,7 +201,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -202,7 +201,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// process recurrent gradient op as a special operator. // process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") { if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or // NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or
// this will result in infinite loop. // this will result in infinite loop.
const auto& rnnop = const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp); *static_cast<const operators::RecurrentOp*>(&forwardOp);
......
...@@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of SGDOp should not be null."); "Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of SGDOp should not be null."); "Input(Grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("learning_rate"), PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(learning_rate) of SGDOp should not be null."); "Input(LearningRate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("param_out"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of SGDOp should not be null."); "Output(ParamOut) of SGDOp should not be null.");
auto param_dim = ctx->GetInputDim("param"); auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"), PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
ctx->SetOutputDim("param_out", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
}; };
...@@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("Param", "Input parameter");
AddInput("learning_rate", "learning rate of sgd"); AddInput("LearningRate", "Learning rate of SGD");
AddInput("grad", "input gradient"); AddInput("Grad", "Input gradient");
AddOutput("param_out", "output parameter"); AddOutput("ParamOut", "output parameter");
AddComment(R"DOC( AddComment(R"DOC(
Simplest sgd algorithm. Simplest sgd algorithm.
......
...@@ -28,10 +28,10 @@ template <typename Place, typename T> ...@@ -28,10 +28,10 @@ template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("Param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("param_out"); auto param_out = ctx.Output<Tensor>("ParamOut");
float lr = *ctx.Input<float>("learning_rate"); float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
...@@ -8,10 +8,10 @@ class TestSGDOp(OpTest): ...@@ -8,10 +8,10 @@ class TestSGDOp(OpTest):
self.op_type = "sgd" self.op_type = "sgd"
w = np.random.random((102, 105)).astype("float32") w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32") g = np.random.random((102, 105)).astype("float32")
lr = 0.1 lr = np.array([0.1]).astype("float32")
self.inputs = {'param': w, 'grad': g, 'learning_rate': lr} self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
self.outputs = {'param_out': w - lr * g} self.outputs = {'ParamOut': w - lr * g}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册