提交 f8cddad5 编写于 作者: T tangwei12

rewrite pyreader

上级 93ca734c
...@@ -96,31 +96,33 @@ def skip_gram_word2vec(dict_size, ...@@ -96,31 +96,33 @@ def skip_gram_word2vec(dict_size,
data_lod_levels.append(1) data_lod_levels.append(1)
data_types.append('int64') data_types.append('int64')
py_reader = fluid.layers.py_reader(capacity=64, datas = []
shapes=data_shapes,
lod_levels=data_lod_levels, input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
dtypes=data_types, predict_word = fluid.layers.data(name='predict_word', shape=[1], dtype='int64')
name='py_reader',
use_double_buffer=True)
word_and_label = fluid.layers.read_file(py_reader) datas.append(input_word, predict_word)
cost = None cost = None
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
input=word_and_label[0], input=input_word,
is_sparse=is_sparse, is_sparse=is_sparse,
size=[dict_size, embedding_size], size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal( param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(dict_size)))) scale=1 / math.sqrt(dict_size))))
if with_nce: if with_nce:
cost = nce_layer(emb, word_and_label[1], embedding_size, dict_size, 5, cost = nce_layer(emb, predict_word, embedding_size, dict_size, 5, "uniform",
"uniform", word_frequencys, None) word_frequencys, None)
if with_hsigmoid: if with_hsigmoid:
cost = hsigmoid_layer(emb, word_and_label[1], dict_size, max_code_length, cost = hsigmoid_layer(emb, predict_word, dict_size, max_code_length, datas)
None)
avg_cost = fluid.layers.reduce_mean(cost) avg_cost = fluid.layers.reduce_mean(cost)
py_reader = fluid.layers.create_py_reader_by_data(capacity=64,
feed_list=datas,
name='py_reader',
use_double_buffer=True)
return avg_cost, py_reader return avg_cost, py_reader
...@@ -107,7 +107,10 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): ...@@ -107,7 +107,10 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
start = time.clock() start = time.clock()
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
if os.getenv("NUM_THREADS", ""):
exec_strategy.num_threads = int(os.getenv("NUM_THREADS")) exec_strategy.num_threads = int(os.getenv("NUM_THREADS"))
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册