提交 c173e64b 编写于 作者: T tangwei12

add is_sparse and args optimize

上级 873b5475
......@@ -28,12 +28,10 @@ def skip_gram_word2vec(dict_size,
embedding_size,
max_code_length=None,
with_hsigmoid=False,
with_nce=True):
with_nce=True,
is_sparse=False):
def nce_layer(input, label, embedding_size, num_total_classes,
num_neg_samples, sampler, custom_dist, sample_weight):
# convert word_frequencys to tensor
nid_freq_arr = np.array(word_frequencys).astype('float32')
nid_freq_var = fluid.layers.assign(input=nid_freq_arr)
num_neg_samples, sampler, word_frequencys, sample_weight):
w_param_name = "nce_w"
b_param_name = "nce_b"
......@@ -48,11 +46,11 @@ def skip_gram_word2vec(dict_size,
label=label,
num_total_classes=num_total_classes,
sampler=sampler,
custom_dist=nid_freq_var,
custom_dist=word_frequencys,
sample_weight=sample_weight,
param_attr=fluid.ParamAttr(name=w_param_name),
bias_attr=fluid.ParamAttr(name=b_param_name),
num_neg_samples=num_neg_samples)
num_neg_samples=num_neg_samples, is_sparse=is_sparse)
return cost
......@@ -76,8 +74,8 @@ def skip_gram_word2vec(dict_size,
non_leaf_num = dict_size
cost = fluid.layers.hsigmoid(
input=emb,
label=predict_word,
input=input,
label=label,
non_leaf_num=non_leaf_num,
ptable=ptable,
pcode=pcode,
......@@ -86,13 +84,13 @@ def skip_gram_word2vec(dict_size,
return cost
input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
predict_word = fluid.layers.data(
name='predict_word', shape=[1], dtype='int64')
predict_word = fluid.layers.data(name='predict_word', shape=[1], dtype='int64')
cost = None
data_list = [input_word, predict_word]
emb = fluid.layers.embedding(
input=input_word,
is_sparse=is_sparse,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(dict_size))))
......
......@@ -5,7 +5,7 @@ import logging
import os
import time
# disable gpu training for this example
# disable gpu training for this example
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
......@@ -57,6 +57,31 @@ def parse_args():
default=64,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--with_hs',
action='store_true',
required=False,
default=False,
help='using hierarchical sigmoid, (default: False)')
parser.add_argument(
'--with_nce',
action='store_true',
required=False,
default=True,
help='using negtive sampling, (default: True)')
parser.add_argument(
'--max_code_length',
type=int,
default=40,
help='max code length used by hierarchical sigmoid, (default: 40)')
parser.add_argument(
'--is_sparse',
type=bool,
default=False,
help='embedding and nce will use sparse or not, (default: False)')
parser.add_argument(
'--is_local',
type=int,
......@@ -88,21 +113,6 @@ def parse_args():
type=int,
default=1,
help='The num of trianers, (default: 1)')
parser.add_argument(
'--with_hs',
type=int,
default=0,
help='using hierarchical sigmoid, (default: 0)')
parser.add_argument(
'--with_nce',
type=int,
default=1,
help='using negtive sampling, (default: 1)')
parser.add_argument(
'--max_code_length',
type=int,
default=40,
help='max code length used by hierarchical sigmoid, (default: 40)')
return parser.parse_args()
......@@ -142,8 +152,7 @@ def train_loop(args, train_program, reader, data_list, loss, trainer_num,
[loss], exe)
model_dir = args.model_output_dir + '/pass-' + str(pass_id)
if args.trainer_id == 0:
fluid.io.save_inference_model(model_dir, data_name_list, [loss],
exe)
fluid.io.save_inference_model(model_dir, data_name_list, [loss], exe)
def train():
......@@ -156,12 +165,12 @@ def train():
args.train_data_path)
logger.info("dict_size: {}".format(word2vec_reader.dict_size))
logger.info("word_frequencys length: {}".format(
len(word2vec_reader.word_frequencys)))
loss, data_list = skip_gram_word2vec(
word2vec_reader.dict_size, word2vec_reader.word_frequencys,
args.embedding_size, args.max_code_length, args.with_hs, args.with_nce)
args.embedding_size, args.max_code_length,
args.with_hs, args.with_nce, is_sparse=args.is_sparse)
optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
optimizer.minimize(loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册