提交 eed2c1e1 编写于 作者: A Abhinav Arora 提交者: GitHub

Changing SGD inputs and outputs to conform to Operator naming convention (#4586)

上级 77e07833
......@@ -23,22 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("param"),
"Input(param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("grad"),
"Input(grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
"Input(learning_rate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
"Output(param_out) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of SGDOp should not be null.");
auto lr_dims = ctx->GetInputDim("learning_rate");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"),
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same.");
ctx->SetOutputDim("param_out", param_dim);
ctx->SetOutputDim("ParamOut", param_dim);
}
};
......@@ -46,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter");
AddInput("learning_rate", "learning rate of sgd");
AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter");
AddInput("Param", "Input parameter");
AddInput("LearningRate", "Learning rate of SGD");
AddInput("Grad", "Input gradient");
AddOutput("ParamOut", "output parameter");
AddComment(R"DOC(
Simplest sgd algorithm.
......
......@@ -28,10 +28,10 @@ template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.Input<Tensor>("learning_rate")->data<float>()[0];
auto param = ctx.Input<Tensor>("Param");
auto grad = ctx.Input<Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("ParamOut");
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
param_out->mutable_data<T>(ctx.GetPlace());
......
......@@ -10,8 +10,8 @@ class TestSGDOp(OpTest):
g = np.random.random((102, 105)).astype("float32")
lr = np.array([0.1]).astype("float32")
self.inputs = {'param': w, 'grad': g, 'learning_rate': lr}
self.outputs = {'param_out': w - lr * g}
self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
self.outputs = {'ParamOut': w - lr * g}
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册