From 5b4526fabcb1108ee18829eea2352ec0794c166a Mon Sep 17 00:00:00 2001 From: caoying03 Date: Thu, 7 Sep 2017 11:20:48 +0800 Subject: [PATCH] rename input and output of softmax_op. --- paddle/operators/identity_op.cc | 4 +- paddle/operators/scale_op.cc | 3 +- paddle/operators/softmax_op.cc | 31 ++++++------ paddle/operators/softmax_op.h | 20 ++++---- python/paddle/v2/framework/op.py | 48 ++++++++++--------- .../framework/tests/test_gradient_checker.py | 4 +- .../v2/framework/tests/test_softmax_op.py | 8 ++-- 7 files changed, 60 insertions(+), 58 deletions(-) diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc index 7d7a53baedd..7d9d4fa519d 100644 --- a/paddle/operators/identity_op.cc +++ b/paddle/operators/identity_op.cc @@ -19,7 +19,7 @@ namespace paddle { namespace operators { // The identity operator is an alias of the scale operator. This is also an -// example for creating the alias for an existing operator. +// example for creating an alias for an existing operator. template class IdentityOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -30,7 +30,7 @@ class IdentityOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of identity operator."); AddComment(R"DOC( The identity operator is an alias of the scale operator -with the attribute scale fixed to 1.0 +with the attribute scale fixed to 1.0. )DOC"); } }; diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 0377f05b2c5..841e38d6514 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -49,7 +49,8 @@ The equation is: Out = scale*X } }; -// The gradients of a scale operator is just the scale operator itself. +// The operator to calculate gradients of a scale operator is just the scale +// operator itself. // Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out)) template class ScaleGradOp : public NetOp { diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 7edf1c3460c..7166b2f60be 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -23,9 +23,9 @@ class SoftmaxOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.Input("Logits")->dims().size() == 2UL, + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, "The input of softmax op must be a matrix."); - ctx.Output("Out")->Resize(ctx.Input("Logits")->dims()); + ctx.Output("Y")->Resize(ctx.Input("X")->dims()); } }; @@ -34,10 +34,10 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { SoftmaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Logits", + AddInput("X", "The input tensor of softmax. " "2-D with shape [batch_size, input_feature_dimensions]."); - AddOutput("Out", "The normalized values with the same shape as the input."); + AddOutput("Y", "The normalized values with the same shape as X."); AddComment(R"DOC( The input of softmax operator is a 2-D tensor with shape N x K (N is the batch_size, K is the dimension of input feature). The output tensor has the @@ -51,8 +51,8 @@ the other dimensions in the K-dimensional vector input. Then the ratio of the exponential of the given dimension and the sum of exponential values of all the other dimensions is the output of the softmax operator. -For each row `i` and each column `j` in the input: Logits, we have: - Out[i, j] = exp(Logits[i, j]) / sum_j(exp(Logits[i, j])) +For each row `i` and each column `j` in input X, we have: + Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j])) )DOC"); } @@ -64,16 +64,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"), - "Input(Out) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should be not null."); - PADDLE_ENFORCE_EQ(ctx.Input("Out")->dims(), - ctx.Input(framework::GradVarName("Out"))->dims(), - "Input(Out) and its gradients should have a same shape."); - - ctx.Output(framework::GradVarName("Logits")) - ->Resize(ctx.Input("Logits")->dims()); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE_EQ(ctx.Input("Y")->dims(), + ctx.Input(framework::GradVarName("Y"))->dims(), + "Input(Y) and its gradients should have a same shape."); + + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("X")->dims()); } }; diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 2ef5915239f..8a3a5ab927c 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -28,12 +28,12 @@ template class SoftmaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto X = context.Input("Logits"); - auto Y = context.Output("Out"); + auto X = context.Input("X"); + auto Y = context.Output("Y"); Y->mutable_data(context.GetPlace()); auto logits = EigenMatrix::From(*X); - auto out = EigenMatrix::From(*Y); + auto softmax = EigenMatrix::From(*Y); const int kBatchDim = 0; const int kClassDim = 1; @@ -51,11 +51,11 @@ class SoftmaxKernel : public framework::OpKernel { .reshape(batch_by_one) .broadcast(one_by_class)); - out.device(context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - out.device(context.GetEigenDevice()) = - (out * - out.sum(along_class) + softmax.device(context.GetEigenDevice()) = + (softmax * + softmax.sum(along_class) .inverse() .eval() .reshape(batch_by_one) @@ -69,9 +69,9 @@ class SoftmaxGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { std::shared_ptr scale_ = std::make_shared(); - auto Y = context.Input("Out"); - auto dY = context.Input(framework::GradVarName("Out")); - auto dX = context.Output(framework::GradVarName("Logits")); + auto Y = context.Input("Y"); + auto dY = context.Input(framework::GradVarName("Y")); + auto dX = context.Output(framework::GradVarName("X")); dX->mutable_data(context.GetPlace()); const int batch_size = Y->dims()[0]; diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index db07bd329d0..c1585bcffcc 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -5,7 +5,7 @@ import paddle.v2.framework.proto.framework_pb2 as framework_pb2 def get_all_op_protos(): """ Get all registered op proto from PaddlePaddle C++ end. - :return: list of OpProto + :return: A list of registered OpProto. """ protostrs = core.get_all_op_protos() ret_values = [] @@ -21,8 +21,8 @@ def is_str(s): class OpDescCreationMethod(object): """ - A Functor object converting the user's input(only keyword arguments are - supported) to OpDesc based on the OpProto. + Convert the user's input(only keyword arguments are supported) to OpDesc + based on the OpProto. :param op_proto: The OpProto object. :type op_proto: op_proto_pb2.OpProto @@ -37,7 +37,7 @@ class OpDescCreationMethod(object): def __call__(self, *args, **kwargs): """ Convert user's input to OpDesc. Only keyword arguments are supported. - :return: OpDesc based on user input + :return: The OpDesc based on user input. :rtype: op_desc_pb2.OpDesc """ if len(args) != 0: @@ -54,7 +54,7 @@ class OpDescCreationMethod(object): "Input %s expects only one input, but %d are given." % (input_parameter.name, len(input_arguments))) - ipt = op_desc.inputs.add() + ipt = op_desc.inputs.add() ipt.parameter = input_parameter.name ipt.arguments.extend(input_arguments) @@ -68,7 +68,7 @@ class OpDescCreationMethod(object): "Output %s expects only one output, but %d are given." % (output_parameter.name, len(output_arguments))) - out = op_desc.outputs.add() + out = op_desc.outputs.add() out.parameter = output_parameter.name out.arguments.extend(output_arguments) @@ -106,12 +106,13 @@ class OpDescCreationMethod(object): "A not supported attribute type: %s." % ( str(attr.type))) - return op_desc + return op_desc @staticmethod def any_is_true(generator): """ - Reduce a bool array to one. If any of them is True, then return True. + Reduce a boolean array to a single boolean parameter. If any element in + the array is True, this function will return True, otherwise False. """ for flag in generator: if flag: @@ -130,7 +131,7 @@ class OpInfo(object): def create_op_creation_method(op_proto): """ - Generate op creation method for an OpProto + Generate op creation method for an OpProto. """ method = OpDescCreationMethod(op_proto) @@ -145,27 +146,28 @@ def create_op_creation_method(op_proto): outputs=[var.name for var in op_proto.outputs], attrs=[attr.name for attr in op_proto.attrs]) - class OperatorFactory(object): - def __init__(self): - self.op_methods = dict() + +class OperatorFactory(object): + def __init__(self): + self.op_methods = dict() for op_proto in get_all_op_protos(): method = create_op_creation_method(op_proto) self.op_methods[method.name] = method def __call__(self, *args, **kwargs): - if 'type' in kwargs: + if "type" in kwargs: if len(args) != 0: raise ValueError( - ("All PaddlePaddle arguments should be keyword " - "arguments except the argument \"type\".")) - t = kwargs.pop('type') + "Except the argument \"type\"," + "all of the other arguments should be keyword arguments.") + t = kwargs.pop("type") else: if len(args) != 1: raise ValueError( - ("All PaddlePaddle arguments should be keyword " - "arguments except the argument \"type\".")) - t = args[0] + "Except the argument \"type\"," + "all of the other arguments should be keyword arguments.") + t = args[0] return self.get_op_info(t).method(**kwargs) @@ -189,7 +191,7 @@ def create_op_creation_method(op_proto): class __RecurrentOp__(object): __proto__ = None - type = 'recurrent' + type = "recurrent" def __init__(self): # cache recurrent_op's proto @@ -199,8 +201,8 @@ class __RecurrentOp__(object): self.__proto__ = op_proto def __call__(self, *args, **kwargs): - if self.type not in args and 'type' not in kwargs: - kwargs['type'] = self.type + if self.type not in args and "type" not in kwargs: + kwargs["type"] = self.type # create proto create_method = OpDescCreationMethod(self.__proto__) proto = create_method(*args, **kwargs) @@ -208,5 +210,5 @@ class __RecurrentOp__(object): return core.RecurrentOp.create(proto.SerializeToString()) -Operator = OperatorFactory() # Default global factory +Operator = OperatorFactory() # The default global factory RecurrentOp = __RecurrentOp__() diff --git a/python/paddle/v2/framework/tests/test_gradient_checker.py b/python/paddle/v2/framework/tests/test_gradient_checker.py index e6307bc2ecd..e8a7f848dff 100644 --- a/python/paddle/v2/framework/tests/test_gradient_checker.py +++ b/python/paddle/v2/framework/tests/test_gradient_checker.py @@ -28,14 +28,14 @@ class GetNumericGradientTest(unittest.TestCase): dX[i, :] = Y[i, :] * (dY[i, :] - d) return dX - softmax_op = Operator("softmax", Logits="Logits", Out="Out") + softmax_op = Operator("softmax", X="X", Y="Y") X = numpy.random.random((2, 2)).astype("float32") Y = numpy.apply_along_axis(stable_softmax, 1, X) dY = numpy.ones(Y.shape) dX = label_softmax_grad(Y, dY) - arr = get_numeric_gradient(softmax_op, {"Logits": X}, "Out", "Logits") + arr = get_numeric_gradient(softmax_op, {"X": X}, "Y", "X") numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py index 63042e9bfdd..0d590fa7065 100644 --- a/python/paddle/v2/framework/tests/test_softmax_op.py +++ b/python/paddle/v2/framework/tests/test_softmax_op.py @@ -18,9 +18,9 @@ class TestSoftmaxOp(unittest.TestCase): def setUp(self): self.type = "softmax" - self.inputs = {"Logits": np.random.random((10, 10)).astype("float32")} + self.inputs = {"X": np.random.random((10, 10)).astype("float32")} self.outputs = { - "Out": np.apply_along_axis(stable_softmax, 1, self.inputs["Logits"]) + "Y": np.apply_along_axis(stable_softmax, 1, self.inputs["X"]) } @@ -28,11 +28,11 @@ class TestSoftmaxGradOp(GradientChecker): def setUp(self): self.op = create_op("softmax") self.inputs = { - "Logits": np.random.uniform(0.1, 1, [10, 10]).astype("float32") + "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32") } def test_softmax_grad(self): - self.check_grad(self.op, self.inputs, ["Logits"], "Out") + self.check_grad(self.op, self.inputs, ["X"], "Y") if __name__ == "__main__": -- GitLab