Commit 5b4526fa authored by caoying03

rename input and output of softmax_op.

Parent e61485e0
@@ -19,7 +19,7 @@ namespace paddle {
 namespace operators {
 
 // The identity operator is an alias of the scale operator. This is also an
-// example for creating the alias for an existing operator.
+// example for creating an alias for an existing operator.
 template <typename AttrType>
 class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
@@ -30,7 +30,7 @@ class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The output tensor of identity operator.");
     AddComment(R"DOC(
 The identity operator is an alias of the scale operator
-with the attribute scale fixed to 1.0
+with the attribute scale fixed to 1.0.
 )DOC");
   }
 };
...
@@ -49,7 +49,8 @@ The equation is: Out = scale*X
   }
 };
 
-// The gradients of a scale operator is just the scale operator itself.
+// The operator to calculate gradients of a scale operator is just the scale
+// operator itself.
 // Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
 template <typename AttrType>
 class ScaleGradOp : public NetOp {
...
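A quick illustration of the gradient identity stated in the comment above: if Out = scale * X, the chain rule gives Grad(X) = scale * Grad(Out), so the backward pass can reuse the scale operator itself. A minimal NumPy sketch of that check (illustrative only; none of these names are part of the patch):

    import numpy as np

    scale = 2.5
    X = np.random.random((3, 4)).astype("float32")
    dOut = np.ones_like(X)   # upstream gradient of Out = scale * X
    dX = scale * dOut        # Grad(X) = scale(Grad(Out)), as the comment states

    # finite-difference check of d(sum(Out))/dX for a single element
    eps = 1e-3
    Xp = X.copy()
    Xp[0, 0] += eps
    numeric = ((scale * Xp).sum() - (scale * X).sum()) / eps
    assert abs(numeric - dX[0, 0]) < 1e-2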
@@ -23,9 +23,9 @@ class SoftmaxOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(ctx.Input<Tensor>("Logits")->dims().size() == 2UL,
+    PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
                    "The input of softmax op must be a matrix.");
-    ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("Logits")->dims());
+    ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
   }
 };
 
@@ -34,10 +34,10 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftmaxOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Logits",
+    AddInput("X",
              "The input tensor of softmax. "
             "2-D with shape [batch_size, input_feature_dimensions].");
-    AddOutput("Out", "The normalized values with the same shape as the input.");
+    AddOutput("Y", "The normalized values with the same shape as X.");
     AddComment(R"DOC(
 The input of softmax operator is a 2-D tensor with shape N x K (N is the
 batch_size, K is the dimension of input feature). The output tensor has the
@@ -51,8 +51,8 @@ the other dimensions in the K-dimensional vector input. Then the ratio of the
 exponential of the given dimension and the sum of exponential values of all
 the other dimensions is the output of the softmax operator.
 
-For each row `i` and each column `j` in the input: Logits, we have:
-    Out[i, j] = exp(Logits[i, j]) / sum_j(exp(Logits[i, j]))
+For each row `i` and each column `j` in input X, we have:
+    Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
 
 )DOC");
   }
@@ -64,16 +64,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
-                            "Input(Out) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should be not null.");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Out")->dims(),
-                      ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(),
-                      "Input(Out) and its gradients should have a same shape.");
-
-    ctx.Output<Tensor>(framework::GradVarName("Logits"))
-        ->Resize(ctx.Input<Tensor>("Logits")->dims());
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
+                            "Input(Y@GRAD) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Y")->dims(),
+                      ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
+                      "Input(Y) and its gradients should have a same shape.");
+
+    ctx.Output<Tensor>(framework::GradVarName("X"))
+        ->Resize(ctx.Input<Tensor>("X")->dims());
   }
 };
...
@@ -28,12 +28,12 @@ template <typename Place, typename T>
 class SoftmaxKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto X = context.Input<Tensor>("Logits");
-    auto Y = context.Output<Tensor>("Out");
+    auto X = context.Input<Tensor>("X");
+    auto Y = context.Output<Tensor>("Y");
     Y->mutable_data<T>(context.GetPlace());
 
     auto logits = EigenMatrix<T>::From(*X);
-    auto out = EigenMatrix<T>::From(*Y);
+    auto softmax = EigenMatrix<T>::From(*Y);
 
     const int kBatchDim = 0;
     const int kClassDim = 1;
@@ -51,11 +51,11 @@ class SoftmaxKernel : public framework::OpKernel {
             .reshape(batch_by_one)
             .broadcast(one_by_class));
 
-    out.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
+    softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
 
-    out.device(context.GetEigenDevice<Place>()) =
-        (out *
-         out.sum(along_class)
+    softmax.device(context.GetEigenDevice<Place>()) =
+        (softmax *
+         softmax.sum(along_class)
              .inverse()
              .eval()
              .reshape(batch_by_one)
@@ -69,9 +69,9 @@ class SoftmaxGradKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
 
-    auto Y = context.Input<Tensor>("Out");
-    auto dY = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto dX = context.Output<Tensor>(framework::GradVarName("Logits"));
+    auto Y = context.Input<Tensor>("Y");
+    auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
+    auto dX = context.Output<Tensor>(framework::GradVarName("X"));
     dX->mutable_data<T>(context.GetPlace());
 
     const int batch_size = Y->dims()[0];
...
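For context, the kernel above computes softmax in the usual numerically stable way: it subtracts the per-row maximum (the shifted_logits expression) before exponentiating, and then normalizes each row. A small NumPy sketch of the same computation, in the spirit of the stable_softmax helper the Python tests below rely on (illustrative only, not part of the patch):

    import numpy as np

    def stable_softmax(x):
        # x is a 1-D row of logits; subtracting the max avoids overflow in exp
        shifted = x - np.max(x)
        exps = np.exp(shifted)
        return exps / np.sum(exps)

    logits = np.random.random((2, 2)).astype("float32")
    probs = np.apply_along_axis(stable_softmax, 1, logits)  # row-wise softmax
    assert np.allclose(probs.sum(axis=1), 1.0)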
@@ -5,7 +5,7 @@ import paddle.v2.framework.proto.framework_pb2 as framework_pb2
 def get_all_op_protos():
     """
     Get all registered op proto from PaddlePaddle C++ end.
-    :return: list of OpProto
+    :return: A list of registered OpProto.
     """
     protostrs = core.get_all_op_protos()
     ret_values = []
@@ -21,8 +21,8 @@ def is_str(s):
 
 class OpDescCreationMethod(object):
     """
-    A Functor object converting the user's input(only keyword arguments are
-    supported) to OpDesc based on the OpProto.
+    Convert the user's input(only keyword arguments are supported) to OpDesc
+    based on the OpProto.
 
     :param op_proto: The OpProto object.
     :type op_proto: op_proto_pb2.OpProto
@@ -37,7 +37,7 @@ class OpDescCreationMethod(object):
     def __call__(self, *args, **kwargs):
         """
         Convert user's input to OpDesc. Only keyword arguments are supported.
-        :return: OpDesc based on user input
+        :return: The OpDesc based on user input.
         :rtype: op_desc_pb2.OpDesc
         """
         if len(args) != 0:
@@ -54,7 +54,7 @@ class OpDescCreationMethod(object):
                     "Input %s expects only one input, but %d are given." %
                     (input_parameter.name, len(input_arguments)))
 
             ipt = op_desc.inputs.add()
             ipt.parameter = input_parameter.name
             ipt.arguments.extend(input_arguments)
@@ -68,7 +68,7 @@ class OpDescCreationMethod(object):
                     "Output %s expects only one output, but %d are given." %
                     (output_parameter.name, len(output_arguments)))
 
             out = op_desc.outputs.add()
             out.parameter = output_parameter.name
             out.arguments.extend(output_arguments)
@@ -106,12 +106,13 @@ class OpDescCreationMethod(object):
                     "A not supported attribute type: %s." % (
                         str(attr.type)))
 
         return op_desc
 
     @staticmethod
     def any_is_true(generator):
         """
-        Reduce a bool array to one. If any of them is True, then return True.
+        Reduce a boolean array to a single boolean parameter. If any element in
+        the array is True, this function will return True, otherwise False.
         """
         for flag in generator:
             if flag:
@@ -130,7 +131,7 @@ class OpInfo(object):
 
 def create_op_creation_method(op_proto):
     """
-    Generate op creation method for an OpProto
+    Generate op creation method for an OpProto.
     """
     method = OpDescCreationMethod(op_proto)
 
@@ -145,27 +146,28 @@ def create_op_creation_method(op_proto):
         outputs=[var.name for var in op_proto.outputs],
         attrs=[attr.name for attr in op_proto.attrs])
 
+
 class OperatorFactory(object):
     def __init__(self):
         self.op_methods = dict()
 
         for op_proto in get_all_op_protos():
             method = create_op_creation_method(op_proto)
             self.op_methods[method.name] = method
 
     def __call__(self, *args, **kwargs):
-        if 'type' in kwargs:
+        if "type" in kwargs:
             if len(args) != 0:
                 raise ValueError(
-                    ("All PaddlePaddle arguments should be keyword "
-                     "arguments except the argument \"type\"."))
-            t = kwargs.pop('type')
+                    "Except the argument \"type\","
+                    "all of the other arguments should be keyword arguments.")
+            t = kwargs.pop("type")
         else:
             if len(args) != 1:
                 raise ValueError(
-                    ("All PaddlePaddle arguments should be keyword "
-                     "arguments except the argument \"type\"."))
+                    "Except the argument \"type\","
+                    "all of the other arguments should be keyword arguments.")
             t = args[0]
 
         return self.get_op_info(t).method(**kwargs)
@@ -189,7 +191,7 @@ def create_op_creation_method(op_proto):
 class __RecurrentOp__(object):
     __proto__ = None
-    type = 'recurrent'
+    type = "recurrent"
 
     def __init__(self):
         # cache recurrent_op's proto
@@ -199,8 +201,8 @@ class __RecurrentOp__(object):
             self.__proto__ = op_proto
 
     def __call__(self, *args, **kwargs):
-        if self.type not in args and 'type' not in kwargs:
-            kwargs['type'] = self.type
+        if self.type not in args and "type" not in kwargs:
+            kwargs["type"] = self.type
         # create proto
         create_method = OpDescCreationMethod(self.__proto__)
         proto = create_method(*args, **kwargs)
@@ -208,5 +210,5 @@ class __RecurrentOp__(object):
     return core.RecurrentOp.create(proto.SerializeToString())
 
 
-Operator = OperatorFactory()  # Default global factory
+Operator = OperatorFactory()  # The default global factory
 RecurrentOp = __RecurrentOp__()
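As the updated error message in OperatorFactory.__call__ spells out, everything except the operator type must be passed as a keyword argument, and the type itself may be given either positionally or as type=... . A short usage sketch, assuming the factory is importable as paddle.v2.framework.op.Operator (the operator and variable names below are simply the ones used in the tests that follow):

    from paddle.v2.framework.op import Operator

    # operator type passed positionally ...
    softmax_op = Operator("softmax", X="X", Y="Y")

    # ... or as the "type" keyword; mixing extra positional arguments with it
    # raises the ValueError shown above
    same_op = Operator(type="softmax", X="X", Y="Y")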
...
@@ -28,14 +28,14 @@ class GetNumericGradientTest(unittest.TestCase):
             dX[i, :] = Y[i, :] * (dY[i, :] - d)
             return dX
 
-        softmax_op = Operator("softmax", Logits="Logits", Out="Out")
+        softmax_op = Operator("softmax", X="X", Y="Y")
 
         X = numpy.random.random((2, 2)).astype("float32")
         Y = numpy.apply_along_axis(stable_softmax, 1, X)
         dY = numpy.ones(Y.shape)
         dX = label_softmax_grad(Y, dY)
 
-        arr = get_numeric_gradient(softmax_op, {"Logits": X}, "Out", "Logits")
+        arr = get_numeric_gradient(softmax_op, {"X": X}, "Y", "X")
         numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)
...
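The label_softmax_grad helper in the hunk above encodes the standard softmax backward rule: with Y = softmax(X) and upstream gradient dY, dX[i, :] = Y[i, :] * (dY[i, :] - sum_j(dY[i, j] * Y[i, j])). A vectorised NumPy version of the same identity (illustrative only, not part of the patch):

    import numpy as np

    def softmax_grad(Y, dY):
        # dX[i, :] = Y[i, :] * (dY[i, :] - sum_j(dY[i, j] * Y[i, j]))
        d = np.sum(dY * Y, axis=1, keepdims=True)
        return Y * (dY - d)

    Y = np.array([[0.3, 0.7], [0.5, 0.5]])
    dY = np.ones_like(Y)
    # with an all-ones upstream gradient the result is all zeros,
    # because each row of Y sums to one
    print(softmax_grad(Y, dY))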
@@ -18,9 +18,9 @@ class TestSoftmaxOp(unittest.TestCase):
     def setUp(self):
         self.type = "softmax"
-        self.inputs = {"Logits": np.random.random((10, 10)).astype("float32")}
+        self.inputs = {"X": np.random.random((10, 10)).astype("float32")}
         self.outputs = {
-            "Out": np.apply_along_axis(stable_softmax, 1, self.inputs["Logits"])
+            "Y": np.apply_along_axis(stable_softmax, 1, self.inputs["X"])
         }
 
@@ -28,11 +28,11 @@ class TestSoftmaxGradOp(GradientChecker):
     def setUp(self):
         self.op = create_op("softmax")
         self.inputs = {
-            "Logits": np.random.uniform(0.1, 1, [10, 10]).astype("float32")
+            "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")
         }
 
     def test_softmax_grad(self):
-        self.check_grad(self.op, self.inputs, ["Logits"], "Out")
+        self.check_grad(self.op, self.inputs, ["X"], "Y")
 
 if __name__ == "__main__":
...