提交 e7b3a5f1 编写于 作者: Y Yu Yang

Follow comments

上级 a125ef1a
...@@ -21,7 +21,7 @@ import data_type ...@@ -21,7 +21,7 @@ import data_type
import topology import topology
import data_feeder import data_feeder
import networks import networks
import evaluators import evaluator
from . import dataset from . import dataset
from . import reader from . import reader
from . import plot from . import plot
...@@ -36,7 +36,7 @@ import plot ...@@ -36,7 +36,7 @@ import plot
__all__ = [ __all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader', 'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
'topology', 'networks', 'infer', 'plot', 'evaluators' 'topology', 'networks', 'infer', 'plot', 'evaluator'
] ]
......
...@@ -20,21 +20,28 @@ __all__ = [] ...@@ -20,21 +20,28 @@ __all__ = []
def initialize(): def initialize():
def convert_to_new_name(nm):
return nm[:-len("_evaluator")]
for __ev_name__ in filter(lambda x: x.endswith('_evaluator'), evs.__all__): for __ev_name__ in filter(lambda x: x.endswith('_evaluator'), evs.__all__):
__ev__ = getattr(evs, __ev_name__) __ev__ = getattr(evs, __ev_name__)
if hasattr(__ev__, 'argspec'): if hasattr(__ev__, 'argspec'):
argspec = __ev__.argspec argspec = __ev__.argspec
else: else:
argspec = inspect.getargspec(__ev__) argspec = inspect.getargspec(__ev__)
parent_names = filter(lambda x: x in ['input', 'label'], argspec.args) parent_names = filter(lambda x: x in ['input', 'label', 'weight'],
argspec.args)
v2_ev = __convert_to_v2__( v2_ev = __convert_to_v2__(
__ev_name__, __ev_name__,
parent_names=parent_names, parent_names=parent_names,
is_default_name='name' in argspec.args, is_default_name='name' in argspec.args,
attach_parent=True) attach_parent=True)
globals()[__ev_name__] = v2_ev
globals()[__ev_name__].__name__ = __ev_name__ __new_name__ = convert_to_new_name(__ev_name__)
__all__.append(__ev_name__)
globals()[__new_name__] = v2_ev
globals()[__new_name__].__name__ = __new_name__
__all__.append(__new_name__)
initialize() initialize()
...@@ -19,7 +19,7 @@ import paddle.v2.data_type as data_type ...@@ -19,7 +19,7 @@ import paddle.v2.data_type as data_type
import paddle.v2.layer as layer import paddle.v2.layer as layer
import paddle.v2.pooling as pooling import paddle.v2.pooling as pooling
import paddle.v2.networks as networks import paddle.v2.networks as networks
import paddle.v2.evaluators as evaluators import paddle.v2.evaluator as evaluator
pixel = layer.data(name='pixel', type=data_type.dense_vector(128)) pixel = layer.data(name='pixel', type=data_type.dense_vector(128))
label = layer.data(name='label', type=data_type.integer_value(10)) label = layer.data(name='label', type=data_type.integer_value(10))
...@@ -273,7 +273,7 @@ class EvaluatorTest(unittest.TestCase): ...@@ -273,7 +273,7 @@ class EvaluatorTest(unittest.TestCase):
lbl = layer.data(name='label', type=data_type.integer_value(10)) lbl = layer.data(name='label', type=data_type.integer_value(10))
cost = layer.cross_entropy_cost(input=output, label=lbl) cost = layer.cross_entropy_cost(input=output, label=lbl)
evaluators.classification_error_evaluator(input=output, label=lbl) evaluator.classification_error(input=output, label=lbl)
print layer.parse_network(cost) print layer.parse_network(cost)
print layer.parse_network(output) print layer.parse_network(output)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册