提交 d1275e63 编写于 作者: S ShawnXuan

rm usless comments

上级 3aec0ec1
......@@ -155,19 +155,17 @@ def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size):
dtype=flow.float,
initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
)
hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)#, no_duplicates_in_indices=True)
hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)
lf_ids = lf_ids - hf_vocab_size_constant
with flow.scope.placement('cpu', '0:0'):
lf_embedding_table = flow.get_variable(
name=f'lf_{name}',
shape=(vocab_size - hf_vocab_size, embedding_size),
#shape=(vocab_size, embedding_size),
dtype=flow.float,
initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
)
lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)#, no_duplicates_in_indices=True)
lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)
unique_embedding = flow.reshape(flow.zeros_like(unique_ids, dtype=flow.float), (-1, 1)) * flow.constant(0.0, dtype=flow.float, shape=(1,embedding_size))
# unique_embedding = flow.constant(0.0, dtype=flow.float, shape=(b*s, embedding_size))
unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=hf_embedding, indices=hf_indices)
unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=lf_embedding, indices=lf_indices)
unique_embedding = flow.gather(params=unique_embedding, indices=unique_ids_idx)
......@@ -309,8 +307,6 @@ def print_args(args):
for arg in vars(args):
print("{} = {}".format(arg, getattr(args, arg)))
print("-".ljust(66, "-"))
#print("Time stamp: {}".format(
# str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
def main():
......@@ -320,8 +316,6 @@ def main():
flow.config.enable_model_io_v2(True)
flow.config.enable_debug_mode(True)
flow.config.collective_boxing.nccl_enable_all_to_all(True)
#flow.config.enable_numa_aware_cuda_malloc_host(True)
#flow.config.collective_boxing.enable_fusion(False)
check_point = flow.train.CheckPoint()
check_point.init()
for i in range(FLAGS.max_iter):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册