提交 d4b10eef 编写于 作者: M minqiyang

Polish code

上级 bc12c2c6
......@@ -533,6 +533,10 @@ class Operator(object):
in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
......@@ -566,7 +570,9 @@ class Operator(object):
elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode())
else:
out_arg_names.append(six.u(arg.name))
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
......
......@@ -401,6 +401,8 @@ class LayerHelper(object):
return input_var
if isinstance(act, six.string_types):
act = {'type': act}
else:
raise TypeError(str(act) + " should be unicode or str")
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
act['use_cudnn'] = self.kwargs.get('use_cudnn')
......
......@@ -70,6 +70,10 @@ def switch(new_generator=None):
def guard(new_generator=None):
if isinstance(new_generator, six.string_types):
new_generator = UniqueNameGenerator(new_generator)
elif isinstance(new_generator, six.binary_type):
new_generator = UniqueNameGenerator(new_generator.decode())
else:
raise TypeError(str(new_generator) + " should be unicode or str")
old = switch(new_generator)
yield
switch(old)
......@@ -73,6 +73,8 @@ def recordio(paths, buf_size=100):
def reader():
if isinstance(paths, six.string_types):
path = paths
elif isinstance(paths, six.binary_type):
path = paths.decode()
else:
path = ",".join(paths)
f = rec.reader(path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册