提交 a239418b 编写于 作者: Y Yu Yang

Fix unittest for operator.py

Rename operator.py to op.py because it is conflict with protobuf
上级 53f85df1
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
import cStringIO import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
def get_all_op_protos(): def get_all_op_protos():
...@@ -146,66 +145,6 @@ class OpDescCreationMethod(object): ...@@ -146,66 +145,6 @@ class OpDescCreationMethod(object):
return False return False
def get_docstring_from_op_proto(op_proto):
"""
Generate docstring from a OpProto
:param op_proto: a OpProto instance.
:type op_proto: op_proto_pb2.OpProto
:return: docstring
"""
if not isinstance(op_proto, op_proto_pb2.OpProto):
raise TypeError("Input must be OpProto")
f = cStringIO.StringIO()
f.write(op_proto.comment)
f.write("\n")
def __append_param__(name, comment, type):
# Maybe replace the following line with template engine is better.
f.write(":param ")
f.write(name)
f.write(": ")
f.write(comment)
f.write("\n")
f.write(":type ")
f.write(name)
f.write(": ")
f.write(type)
f.write("\n")
for ipt in op_proto.inputs:
__append_param__(ipt.name, ipt.comment, "list | basestr"
if ipt.multiple else "basestr")
temp_var_prefix = \
"This is a temporary variable. It does not have to set by user. "
for opt in op_proto.outputs:
__append_param__(opt.name, opt.comment if not opt.temporary else
temp_var_prefix + opt.comment, "list | basestr"
if opt.multiple else "basestr")
for attr in op_proto.attrs:
attr_type = None
if attr.type == attr_type_pb2.INT:
attr_type = "int"
elif attr.type == attr_type_pb2.FLOAT:
attr_type = "float"
elif attr.type == attr_type_pb2.STRING:
attr_type = "basestr"
elif attr.type == attr_type_pb2.INTS:
attr_type = "list of int"
elif attr.type == attr_type_pb2.FLOATS:
attr_type = "list of float"
elif attr.type == attr_type_pb2.STRINGS:
attr_type = "list of basestr"
if attr_type is None:
raise RuntimeError("Not supported attribute type " + attr.type)
__append_param__(attr.name, attr.comment, attr_type)
return f.getvalue()
def create_op_creation_method(op_proto): def create_op_creation_method(op_proto):
""" """
Generate op creation method for an OpProto Generate op creation method for an OpProto
...@@ -232,7 +171,7 @@ class OperatorFactory(object): ...@@ -232,7 +171,7 @@ class OperatorFactory(object):
self.op_methods = dict() self.op_methods = dict()
for op_proto in get_all_op_protos(): for op_proto in get_all_op_protos():
method = create_op_creation_method(op_proto) method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method self.op_methods[method['name']] = method
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if 'type' in kwargs: if 'type' in kwargs:
......
import unittest import unittest
import paddle.v2.framework.create_op_creation_methods as creation import paddle.v2.framework.op as op
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
...@@ -8,7 +8,7 @@ import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 ...@@ -8,7 +8,7 @@ import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
class TestGetAllProtos(unittest.TestCase): class TestGetAllProtos(unittest.TestCase):
def test_all(self): def test_all(self):
all_protos = creation.get_all_op_protos() all_protos = op.get_all_op_protos()
self.assertNotEqual(0, len(all_protos)) self.assertNotEqual(0, len(all_protos))
for each in all_protos: for each in all_protos:
...@@ -17,25 +17,25 @@ class TestGetAllProtos(unittest.TestCase): ...@@ -17,25 +17,25 @@ class TestGetAllProtos(unittest.TestCase):
class TestOpDescCreationMethod(unittest.TestCase): class TestOpDescCreationMethod(unittest.TestCase):
def test_plain_input_output(self): def test_plain_input_output(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "test" op_proto.type = "test"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "X" ipt.name = "X"
ipt.comment = "not matter" ipt.comment = "not matter"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "Y" ipt.name = "Y"
ipt.comment = "not matter" ipt.comment = "not matter"
opt = op.outputs.add() opt = op_proto.outputs.add()
opt.name = "Z" opt.name = "Z"
opt.comment = "not matter" opt.comment = "not matter"
op.comment = "not matter" op_proto.comment = "not matter"
self.assertTrue(op.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
output = method(X="a", Y="b", Z="c") output = method(X="a", Y="b", Z="c")
expected = op_desc_pb2.OpDesc() expected = op_desc_pb2.OpDesc()
...@@ -45,29 +45,29 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -45,29 +45,29 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(expected, output) self.assertEqual(expected, output)
def test_multiple_input_plain_output(self): def test_multiple_input_plain_output(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "fc" op_proto.type = "fc"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "X" ipt.name = "X"
ipt.comment = "" ipt.comment = ""
ipt.multiple = True ipt.multiple = True
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "W" ipt.name = "W"
ipt.comment = "" ipt.comment = ""
ipt.multiple = True ipt.multiple = True
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "b" ipt.name = "b"
ipt.comment = "" ipt.comment = ""
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "Y" out.name = "Y"
out.comment = "" out.comment = ""
op.comment = "" op_proto.comment = ""
self.assertTrue(op.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
generated1 = method(X="x", W="w", b="b", Y="y") generated1 = method(X="x", W="w", b="b", Y="y")
expected1 = op_desc_pb2.OpDesc() expected1 = op_desc_pb2.OpDesc()
...@@ -93,14 +93,14 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -93,14 +93,14 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(expected2, generated2) self.assertEqual(expected2, generated2)
def test_attrs(self): def test_attrs(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "test" op_proto.type = "test"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = 'X' ipt.name = 'X'
ipt.comment = "" ipt.comment = ""
def __add_attr__(name, type): def __add_attr__(name, type):
attr = op.attrs.add() attr = op_proto.attrs.add()
attr.name = name attr.name = name
attr.comment = "" attr.comment = ""
attr.type = type attr.type = type
...@@ -112,10 +112,10 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -112,10 +112,10 @@ class TestOpDescCreationMethod(unittest.TestCase):
__add_attr__("floats_attr", attr_type_pb2.FLOATS) __add_attr__("floats_attr", attr_type_pb2.FLOATS)
__add_attr__("strings_attr", attr_type_pb2.STRINGS) __add_attr__("strings_attr", attr_type_pb2.STRINGS)
op.comment = "" op_proto.comment = ""
self.assertTrue(op.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
generated = method( generated = method(
X="a", X="a",
...@@ -162,23 +162,23 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -162,23 +162,23 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(expected, generated) self.assertEqual(expected, generated)
def test_input_temporary_output(self): def test_input_temporary_output(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "test" op_proto.type = "test"
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "OUT" out.name = "OUT"
out.comment = "" out.comment = ""
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "TMP" out.name = "TMP"
out.comment = "" out.comment = ""
out.temporary = True out.temporary = True
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "OUT2" out.name = "OUT2"
out.comment = "" out.comment = ""
op.comment = "" op_proto.comment = ""
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
generated = method(OUT="a", OUT2="b") generated = method(OUT="a", OUT2="b")
desc = op_desc_pb2.OpDesc() desc = op_desc_pb2.OpDesc()
desc.outputs.extend(["a", core.var_names.temp(), "b"]) desc.outputs.extend(["a", core.var_names.temp(), "b"])
...@@ -190,60 +190,9 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -190,60 +190,9 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(generated, desc) self.assertEqual(generated, desc)
class TestOpCreationDocStr(unittest.TestCase):
def test_all(self):
op = op_proto_pb2.OpProto()
op.type = "test"
op.comment = """Test Op.
This op is used for unit test, not a real op.
"""
a = op.inputs.add()
a.name = "a"
a.comment = "Input a for test op"
a.multiple = True
b = op.inputs.add()
b.name = "b"
b.comment = "Input b for test op"
self.assertTrue(op.IsInitialized())
o1 = op.outputs.add()
o1.name = "output"
o1.comment = "The output of test op"
o2 = op.outputs.add()
o2.name = "temp output"
o2.comment = "The temporary output of test op"
o2.temporary = True
test_str = op.attrs.add()
test_str.name = "str_attr"
test_str.type = attr_type_pb2.STRING
test_str.comment = "A string attribute for test op"
actual = creation.get_docstring_from_op_proto(op)
expected_docstring = '''Test Op.
This op is used for unit test, not a real op.
:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
self.assertEqual(expected_docstring, actual)
class TestOpCreations(unittest.TestCase): class TestOpCreations(unittest.TestCase):
def test_all(self): def test_all(self):
add_op = creation.op_creations.add_two(X="a", Y="b", Out="z") add_op = op.Operator("add_two", X="a", Y="b", Out="z")
self.assertIsNotNone(add_op) self.assertIsNotNone(add_op)
# Invoke C++ DebugString() # Invoke C++ DebugString()
self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).', self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册