提交 eed2c1e1 编写于 作者: A Abhinav Arora 提交者: GitHub

Changing SGD inputs and outputs to conform to Operator naming convention (#4586)

上级 77e07833
...@@ -23,22 +23,22 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -23,22 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of SGDOp should not be null."); "Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of SGDOp should not be null."); "Input(Grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("learning_rate"), PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(learning_rate) of SGDOp should not be null."); "Input(LearningRate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("param_out"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of SGDOp should not be null."); "Output(ParamOut) of SGDOp should not be null.");
auto lr_dims = ctx->GetInputDim("learning_rate"); auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element"); "Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("param"); auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"), PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
ctx->SetOutputDim("param_out", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
}; };
...@@ -46,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -46,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("Param", "Input parameter");
AddInput("learning_rate", "learning rate of sgd"); AddInput("LearningRate", "Learning rate of SGD");
AddInput("grad", "input gradient"); AddInput("Grad", "Input gradient");
AddOutput("param_out", "output parameter"); AddOutput("ParamOut", "output parameter");
AddComment(R"DOC( AddComment(R"DOC(
Simplest sgd algorithm. Simplest sgd algorithm.
......
...@@ -28,10 +28,10 @@ template <typename Place, typename T> ...@@ -28,10 +28,10 @@ template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("Param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("param_out"); auto param_out = ctx.Output<Tensor>("ParamOut");
float lr = ctx.Input<Tensor>("learning_rate")->data<float>()[0]; float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
...@@ -10,8 +10,8 @@ class TestSGDOp(OpTest): ...@@ -10,8 +10,8 @@ class TestSGDOp(OpTest):
g = np.random.random((102, 105)).astype("float32") g = np.random.random((102, 105)).astype("float32")
lr = np.array([0.1]).astype("float32") lr = np.array([0.1]).astype("float32")
self.inputs = {'param': w, 'grad': g, 'learning_rate': lr} self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
self.outputs = {'param_out': w - lr * g} self.outputs = {'ParamOut': w - lr * g}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册