提交 9e43b4f4 作者: tangwei12

add is_sparse

上级 16e350da
......@@ -64,21 +64,11 @@ def skip_gram_word2vec(dict_size,
non_leaf_num=non_leaf_num,
ptable=ptable,
pcode=pcode,
is_costum=True)
is_costum=True,
is_sparse=is_sparse)
return cost
def get_loss(loss1, loss2):
    """Return the mean of ``max(0, margin - cos_q_pt + cos_q_nt)``.

    ``loss1`` is used only as the ``input`` reference when building the
    constant ``margin`` tensor; ``loss2`` is never read.

    NOTE(review): ``margin``, ``cos_q_pt`` and ``cos_q_nt`` are not
    parameters of this function -- they are resolved from the enclosing
    scope when the function is called. Confirm they are defined there.
    """
    layers = fluid.layers

    # Constant tensor filled with ``margin`` (batch dimension presumably
    # copied from ``loss1`` -- confirm against the Paddle API docs).
    margin_tensor = layers.fill_constant_batch_size_like(
        input=loss1, shape=[-1, 1], value=margin, dtype='float32')
    margin_minus_pos = layers.elementwise_sub(margin_tensor, cos_q_pt)
    hinge_input = layers.elementwise_add(margin_minus_pos, cos_q_nt)

    # Clamp against an all-zero tensor so negative values become 0.
    zero_tensor = layers.fill_constant_batch_size_like(
        input=hinge_input, shape=[-1, 1], value=0.0, dtype='float32')
    clamped = layers.elementwise_max(zero_tensor, hinge_input)

    return layers.mean(clamped)
datas = []
input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
......@@ -87,14 +77,12 @@ def skip_gram_word2vec(dict_size,
datas.append(predict_word)
if with_hsigmoid:
if max_code_length:
ptable = fluid.layers.data(
name='ptable', shape=[max_code_length], dtype='int64')
pcode = fluid.layers.data(
name='pcode', shape=[max_code_length], dtype='int64')
else:
ptable = fluid.layers.data(name='ptable', shape=[40], dtype='int64')
pcode = fluid.layers.data(name='pcode', shape=[40], dtype='int64')
ptable = fluid.layers.data(name='ptable',
shape=[max_code_length if max_code_length else 40],
dtype='int64')
pcode = fluid.layers.data(name='pcode',
shape=[max_code_length if max_code_length else 40],
dtype='int64')
datas.append(ptable)
datas.append(pcode)
......@@ -116,13 +104,13 @@ def skip_gram_word2vec(dict_size,
if with_nce:
cost_nce = nce_layer(emb, words[1], embedding_size, dict_size, 5, "uniform",
word_frequencys, None)
word_frequencys, None)
cost = cost_nce
if with_hsigmoid:
cost_hs = hsigmoid_layer(emb, words[1], words[2], words[3], dict_size)
cost = cost_hs
if with_nce and with_hsigmoid:
cost = fluid.layers.elementwise_add(cost_nce, cost)
cost = fluid.layers.elementwise_add(cost_nce, cost_hs)
avg_cost = fluid.layers.reduce_mean(cost)
......
......@@ -87,7 +87,7 @@ def parse_args():
'--with_nce',
action='store_true',
required=False,
default=True,
default=False,
help='using negtive sampling, (default: True)')
parser.add_argument(
......@@ -188,6 +188,9 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
def train():
args = parse_args()
if not args.with_nce and not args.with_hs:
logger.error("with_nce or with_hs must choose one")
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册