提交 d4b10eef 编写于 作者: M minqiyang

Polish code

上级 bc12c2c6
...@@ -533,6 +533,10 @@ class Operator(object): ...@@ -533,6 +533,10 @@ class Operator(object):
in_arg_names.append(arg.name) in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type): elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode()) in_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
self.desc.set_input(in_proto.name, in_arg_names) self.desc.set_input(in_proto.name, in_arg_names)
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
...@@ -566,7 +570,9 @@ class Operator(object): ...@@ -566,7 +570,9 @@ class Operator(object):
elif isinstance(arg.name, six.binary_type): elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode()) out_arg_names.append(arg.name.decode())
else: else:
out_arg_names.append(six.u(arg.name)) raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
......
...@@ -401,6 +401,8 @@ class LayerHelper(object): ...@@ -401,6 +401,8 @@ class LayerHelper(object):
return input_var return input_var
if isinstance(act, six.string_types): if isinstance(act, six.string_types):
act = {'type': act} act = {'type': act}
else:
raise TypeError(str(act) + " should be unicode or str")
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'): if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
act['use_cudnn'] = self.kwargs.get('use_cudnn') act['use_cudnn'] = self.kwargs.get('use_cudnn')
......
...@@ -70,6 +70,10 @@ def switch(new_generator=None): ...@@ -70,6 +70,10 @@ def switch(new_generator=None):
def guard(new_generator=None): def guard(new_generator=None):
if isinstance(new_generator, six.string_types): if isinstance(new_generator, six.string_types):
new_generator = UniqueNameGenerator(new_generator) new_generator = UniqueNameGenerator(new_generator)
elif isinstance(new_generator, six.binary_type):
new_generator = UniqueNameGenerator(new_generator.decode())
else:
raise TypeError(str(new_generator) + " should be unicode or str")
old = switch(new_generator) old = switch(new_generator)
yield yield
switch(old) switch(old)
...@@ -73,6 +73,8 @@ def recordio(paths, buf_size=100): ...@@ -73,6 +73,8 @@ def recordio(paths, buf_size=100):
def reader(): def reader():
if isinstance(paths, six.string_types): if isinstance(paths, six.string_types):
path = paths path = paths
elif isinstance(paths, six.binary_type):
path = paths.decode()
else: else:
path = ",".join(paths) path = ",".join(paths)
f = rec.reader(path) f = rec.reader(path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册