提交 db7d8136 编写于 作者: M minqiyang

Fix CI issue

上级 ee1d08ab
...@@ -25,7 +25,7 @@ import collections ...@@ -25,7 +25,7 @@ import collections
import tarfile import tarfile
import re import re
import string import string
from six.moves import range import six
__all__ = ['build_dict', 'train', 'test', 'convert'] __all__ = ['build_dict', 'train', 'test', 'convert']
...@@ -43,13 +43,13 @@ def tokenize(pattern): ...@@ -43,13 +43,13 @@ def tokenize(pattern):
# sequential access of member files, other than # sequential access of member files, other than
# tarfile.extractfile, which does random access and might # tarfile.extractfile, which does random access and might
# destroy hard disks. # destroy hard disks.
tf = next(tarf) tf = tarf.next()
while tf != None: while tf != None:
if bool(pattern.match(tf.name)): if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization. # newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip("\n\r").translate( yield tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(
None, string.punctuation).lower().split() None, six.b(string.punctuation)).lower().split()
tf = next(tarf) tf = tarf.next()
def build_dict(pattern, cutoff): def build_dict(pattern, cutoff):
...@@ -67,7 +67,7 @@ def build_dict(pattern, cutoff): ...@@ -67,7 +67,7 @@ def build_dict(pattern, cutoff):
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary)) words, _ = list(zip(*dictionary))
word_idx = dict(list(zip(words, range(len(words))))) word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words) word_idx['<unk>'] = len(words)
return word_idx return word_idx
......
...@@ -905,10 +905,9 @@ class Block(object): ...@@ -905,10 +905,9 @@ class Block(object):
Variable: the Variable with the giving name. Variable: the Variable with the giving name.
""" """
if not isinstance(name, six.string_types): if not isinstance(name, six.string_types):
if not isinstance(name, six.binary_type): raise TypeError(
raise TypeError( "var require string as parameter, but get %s instead." %
"var require string as parameter, but get %s instead." % (type(name)))
(type(name)))
v = self.vars.get(name, None) v = self.vars.get(name, None)
if v is None: if v is None:
raise ValueError("var %s not in this block" % name) raise ValueError("var %s not in this block" % name)
......
...@@ -56,7 +56,7 @@ def resnet_cifar10(input, depth=32): ...@@ -56,7 +56,7 @@ def resnet_cifar10(input, depth=32):
return tmp return tmp
assert (depth - 2) % 6 == 0 assert (depth - 2) % 6 == 0
n = (depth - 2) / 6 n = (depth - 2) // 6
conv1 = conv_bn_layer( conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1) input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from collections import defaultdict from collections import defaultdict
from .. import core from .. import core
from .. import compat
from ..framework import Program, default_main_program, Parameter from ..framework import Program, default_main_program, Parameter
from ..backward import _rename_arg_ from ..backward import _rename_arg_
from functools import reduce from functools import reduce
...@@ -125,15 +126,15 @@ class ControlFlowGraph(object): ...@@ -125,15 +126,15 @@ class ControlFlowGraph(object):
def _has_var(self, block_desc, var_name, is_forward): def _has_var(self, block_desc, var_name, is_forward):
if is_forward: if is_forward:
return block_desc.has_var(str(var_name)) return block_desc.has_var(cpt.to_bytes(var_name))
else: else:
return block_desc.has_var_recursive(str(var_name)) return block_desc.has_var_recursive(cpt.to_bytes(var_name))
def _find_var(self, block_desc, var_name, is_forward): def _find_var(self, block_desc, var_name, is_forward):
if is_forward: if is_forward:
return block_desc.find_var(str(var_name)) return block_desc.find_var(cpt.to_bytes(var_name))
else: else:
return block_desc.find_var_recursive(str(var_name)) return block_desc.find_var_recursive(cpt.to_bytes(var_name))
def _check_var_validity(self, block_desc, x, is_forward): def _check_var_validity(self, block_desc, x, is_forward):
if str(x) == "@EMPTY@": if str(x) == "@EMPTY@":
...@@ -258,7 +259,7 @@ class ControlFlowGraph(object): ...@@ -258,7 +259,7 @@ class ControlFlowGraph(object):
# Rename the var to the cache var already with # Rename the var to the cache var already with
# memory allocated in order to reuse the memory. # memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i) _rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(str( self._program.block(block_desc.id).var(cpt.to_literal_str(
x)).desc = self._find_var(block_desc, cache_var, x)).desc = self._find_var(block_desc, cache_var,
is_forward) is_forward)
self._update_graph(x, cache_var, begin_idx=i) self._update_graph(x, cache_var, begin_idx=i)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册