提交 3ec6d60c 作者: minqiyang

Fix write bytes in dataset download

上级 e6ae1e4f
...@@ -59,6 +59,11 @@ def reader_creator(filename, sub_name, cycle=False): ...@@ -59,6 +59,11 @@ def reader_creator(filename, sub_name, cycle=False):
while True: while True:
for name in names: for name in names:
import sys
print(name)
sys.stdout.flush()
print(f.extractfile(name))
sys.stdout.flush()
batch = pickle.load(f.extractfile(name)) batch = pickle.load(f.extractfile(name))
for item in read_batch(batch): for item in read_batch(batch):
yield item yield item
......
...@@ -86,15 +86,21 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -86,15 +86,21 @@ def download(url, module_name, md5sum, save_name=None):
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
if total_length is None: if total_length is None:
with open(filename, 'w') as f: with open(filename, 'wb') as f:
shutil.copyfileobj(r.raw, f) import sys
print("write follow block")
sys.stdout.flush()
shutil.copyfileobj(cpt.to_bytes(r.raw), f)
else: else:
with open(filename, 'w') as f: with open(filename, 'wb') as f:
import sys
print("write follow length")
sys.stdout.flush()
dl = 0 dl = 0
total_length = int(total_length) total_length = int(total_length)
for data in r.iter_content(chunk_size=4096): for data in r.iter_content(chunk_size=4096):
dl += len(data) dl += len(data)
f.write(cpt.to_literal_str(data)) f.write(cpt.to_bytes(data))
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
sys.stdout.write("\r[%s%s]" % ('=' * done, sys.stdout.write("\r[%s%s]" % ('=' * done,
' ' * (50 - done))) ' ' * (50 - done)))
......
...@@ -24,19 +24,20 @@ import tarfile ...@@ -24,19 +24,20 @@ import tarfile
import gzip import gzip
import itertools import itertools
import paddle.dataset.common import paddle.dataset.common
import paddle.fluid.compat as cpt
from six.moves import zip, range from six.moves import zip, range
__all__ = ['test, get_dict', 'get_embedding', 'convert'] __all__ = ['test, get_dict', 'get_embedding', 'convert']
DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz' DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc' DATA_MD5 = '387719152ae52d60422c016e92a742fc'
WORDDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/wordDict.txt' WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/wordDict.txt'
WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa' WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
VERBDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/verbDict.txt' VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/verbDict.txt'
VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c' VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
TRGDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/targetDict.txt' TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/targetDict.txt'
TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751' TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
EMB_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/emb' EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/emb'
EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7' EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
UNK_IDX = 0 UNK_IDX = 0
...@@ -89,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name): ...@@ -89,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name):
labels = [] labels = []
one_seg = [] one_seg = []
for word, label in zip(words_file, props_file): for word, label in zip(words_file, props_file):
word = word.strip() word = cpt.to_literal_str(word.strip())
label = label.strip().split() label = cpt.to_literal_str(label.strip().split())
if len(label) == 0: # end of sentence if len(label) == 0: # end of sentence
for i in range(len(one_seg[0])): for i in range(len(one_seg[0])):
......
...@@ -320,7 +320,7 @@ class Executor(object): ...@@ -320,7 +320,7 @@ class Executor(object):
# append fetch_operators # append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name): if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list): for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), ( assert isinstance(var, Variable) or isinstance(var, six.text_type), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var))) "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op( global_block.append_op(
type='fetch', type='fetch',
......
...@@ -55,7 +55,7 @@ def resnet_cifar10(input, depth=32): ...@@ -55,7 +55,7 @@ def resnet_cifar10(input, depth=32):
return tmp return tmp
assert (depth - 2) % 6 == 0 assert (depth - 2) % 6 == 0
n = (depth - 2) / 6 n = (depth - 2) // 6
conv1 = conv_bn_layer( conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1) input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册