提交 ee5d44d5 编写于 作者: M malin10

Merge branch 'infer_dssm_w2v' into 'develop'

add infer, word2vec, dssm

See merge request !9
...@@ -16,7 +16,10 @@ class Model(object): ...@@ -16,7 +16,10 @@ class Model(object):
self._cost = None self._cost = None
self._metrics = {} self._metrics = {}
self._data_var = [] self._data_var = []
self._infer_data_var = []
self._infer_results = {}
self._data_loader = None self._data_loader = None
self._infer_data_loader = None
self._fetch_interval = 20 self._fetch_interval = 20
self._namespace = "train.model" self._namespace = "train.model"
self._platform = envs.get_platform() self._platform = envs.get_platform()
...@@ -24,6 +27,12 @@ class Model(object): ...@@ -24,6 +27,12 @@ class Model(object):
def get_inputs(self): def get_inputs(self):
return self._data_var return self._data_var
def get_infer_inputs(self):
return self._infer_data_var
def get_infer_results(self):
return self._infer_results
def get_cost_op(self): def get_cost_op(self):
"""R """R
""" """
......
...@@ -62,7 +62,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -62,7 +62,7 @@ class SingleTrainer(TranspileTrainer):
context['status'] = 'train_pass' context['status'] = 'train_pass'
def dataloader_train(self, context): def dataloader_train(self, context):
reader = self._get_dataloader() reader = self._get_dataloader("TRAIN")
epochs = envs.get_global_env("train.epochs") epochs = envs.get_global_env("train.epochs")
program = fluid.compiler.CompiledProgram( program = fluid.compiler.CompiledProgram(
...@@ -98,11 +98,12 @@ class SingleTrainer(TranspileTrainer): ...@@ -98,11 +98,12 @@ class SingleTrainer(TranspileTrainer):
batch_id += 1 batch_id += 1
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
self.save(epoch, "train", is_fleet=False)
context['status'] = 'infer_pass' context['status'] = 'infer_pass'
def dataset_train(self, context): def dataset_train(self, context):
dataset = self._get_dataset() dataset = self._get_dataset("TRAIN")
epochs = envs.get_global_env("train.epochs") epochs = envs.get_global_env("train.epochs")
for i in range(epochs): for i in range(epochs):
...@@ -115,6 +116,52 @@ class SingleTrainer(TranspileTrainer): ...@@ -115,6 +116,52 @@ class SingleTrainer(TranspileTrainer):
context['status'] = 'infer_pass' context['status'] = 'infer_pass'
def infer(self, context): def infer(self, context):
infer_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(infer_program, startup_program):
self.model.infer_net()
if self.model._infer_data_loader is None:
context['status'] = 'terminal_pass'
return
reader = self._get_dataloader("Evaluate")
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_infer_results().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
self._exe.run(startup_program)
for (epoch, model_dir) in self.increment_models:
print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program)
reader.start()
batch_id = 0
try:
while True:
metrics_rets = self._exe.run(
program=program,
fetch_list=metrics_varnames)
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % 2 == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
context['status'] = 'terminal_pass' context['status'] = 'terminal_pass'
def terminal(self, context): def terminal(self, context):
......
...@@ -36,15 +36,22 @@ class TranspileTrainer(Trainer): ...@@ -36,15 +36,22 @@ class TranspileTrainer(Trainer):
def processor_register(self): def processor_register(self):
print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first") print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")
def _get_dataloader(self): def _get_dataloader(self, state):
namespace = "train.reader" if state == "TRAIN":
dataloader = self.model._data_loader dataloader = self.model._data_loader
namespace = "train.reader"
class_name = "TrainReader"
else:
dataloader = self.model._infer_data_loader
namespace = "evaluate.reader"
class_name = "EvaluateReader"
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
reader = dataloader_instance.dataloader(reader_class, "TRAIN", self._config_yaml) reader = dataloader_instance.dataloader(reader_class, state, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader") reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml) reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins,'generate_batch_from_trainfiles'): if hasattr(reader_ins,'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader) dataloader.set_sample_list_generator(reader)
...@@ -52,18 +59,22 @@ class TranspileTrainer(Trainer): ...@@ -52,18 +59,22 @@ class TranspileTrainer(Trainer):
dataloader.set_sample_generator(reader, batch_size) dataloader.set_sample_generator(reader, batch_size)
return dataloader return dataloader
def _get_dataset(self): def _get_dataset(self, state):
if state == "TRAIN":
inputs = self.model.get_inputs()
namespace = "train.reader" namespace = "train.reader"
train_data_path = envs.get_global_env("train_data_path", None, namespace)
else:
inputs = self.model.get_infer_inputs()
namespace = "evaluate.reader"
train_data_path = envs.get_global_env("test_data_path", None, namespace)
inputs = self.model.get_inputs()
threads = int(envs.get_runtime_environ("train.trainer.threads")) threads = int(envs.get_runtime_environ("train.trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) pipe_cmd = "python {} {} {} {}".format(reader, reader_class, state, self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
if train_data_path.startswith("fleetrec::"): if train_data_path.startswith("fleetrec::"):
package_base = envs.get_runtime_environ("PACKAGE_BASE") package_base = envs.get_runtime_environ("PACKAGE_BASE")
...@@ -104,7 +115,7 @@ class TranspileTrainer(Trainer): ...@@ -104,7 +115,7 @@ class TranspileTrainer(Trainer):
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace) feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace) fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames] fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None, namespace) dirname = envs.get_global_env("save.inference.dirname", None, namespace)
assert dirname is not None assert dirname is not None
...@@ -136,6 +147,7 @@ class TranspileTrainer(Trainer): ...@@ -136,6 +147,7 @@ class TranspileTrainer(Trainer):
save_persistables() save_persistables()
save_inference_model() save_inference_model()
def instance(self, context): def instance(self, context):
models = envs.get_global_env("train.model.models") models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance_by_fliename(models, "Model") model_class = envs.lazy_instance_by_fliename(models, "Model")
......
...@@ -22,13 +22,13 @@ from fleetrec.core.utils.envs import get_runtime_environ ...@@ -22,13 +22,13 @@ from fleetrec.core.utils.envs import get_runtime_environ
def dataloader(readerclass, train, yaml_file): def dataloader(readerclass, train, yaml_file):
namespace = "train.reader"
if train == "TRAIN": if train == "TRAIN":
reader_name = "TrainReader" reader_name = "TrainReader"
namespace = "train.reader"
data_path = get_global_env("train_data_path", None, namespace) data_path = get_global_env("train_data_path", None, namespace)
else: else:
reader_name = "EvaluateReader" reader_name = "EvaluateReader"
namespace = "evaluate.reader"
data_path = get_global_env("test_data_path", None, namespace) data_path = get_global_env("test_data_path", None, namespace)
if data_path.startswith("fleetrec::"): if data_path.startswith("fleetrec::"):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
reader:
batch_size: 1
class: "{workspace}/synthetic_evaluate_reader.py"
test_data_path: "{workspace}/data/train"
train:
trainer:
# for cluster training
strategy: "async"
epochs: 4
workspace: "fleetrec.models.match.dssm"
reader:
batch_size: 4
class: "{workspace}/synthetic_reader.py"
train_data_path: "{workspace}/data/train"
model:
models: "{workspace}/model.py"
hyper_parameters:
TRIGRAM_D: 1000
NEG: 4
fc_sizes: [300, 300, 128]
fc_acts: ['tanh', 'tanh', 'tanh']
learning_rate: 0.01
optimizer: sgd
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
feed_varnames: ["query", "doc_pos"]
fetch_varnames: ["cos_sim_0.tmp_0"]
save_last: True
因为 它太大了无法显示 source diff 。你可以改为 查看blob
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle.fluid as fluid
from fleetrec.core.utils import envs
from fleetrec.core.model import Model as ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def input(self):
TRIGRAM_D = envs.get_global_env("hyper_parameters.TRIGRAM_D", None, self._namespace)
Neg = envs.get_global_env("hyper_parameters.NEG", None, self._namespace)
self.query = fluid.data(name="query", shape=[-1, TRIGRAM_D], dtype='float32', lod_level=0)
self.doc_pos = fluid.data(name="doc_pos", shape=[-1, TRIGRAM_D], dtype='float32', lod_level=0)
self.doc_negs = [fluid.data(name="doc_neg_" + str(i), shape=[-1, TRIGRAM_D], dtype="float32", lod_level=0) for i in range(Neg)]
self._data_var.append(self.query)
self._data_var.append(self.doc_pos)
for input in self.doc_negs:
self._data_var.append(input)
if self._platform != "LINUX":
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False)
def net(self, is_infer=False):
hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes", None, self._namespace)
hidden_acts = envs.get_global_env("hyper_parameters.fc_acts", None, self._namespace)
def fc(data, hidden_layers, hidden_acts, names):
fc_inputs = [data]
for i in range(len(hidden_layers)):
xavier=fluid.initializer.Xavier(uniform=True, fan_in=fc_inputs[-1].shape[1], fan_out=hidden_layers[i])
out = fluid.layers.fc(input=fc_inputs[-1],
size=hidden_layers[i],
act=hidden_acts[i],
param_attr=xavier,
bias_attr=xavier,
name=names[i])
fc_inputs.append(out)
return fc_inputs[-1]
query_fc = fc(self.query, hidden_layers, hidden_acts, ['query_l1', 'query_l2', 'query_l3'])
doc_pos_fc = fc(self.doc_pos, hidden_layers, hidden_acts, ['doc_pos_l1', 'doc_pos_l2', 'doc_pos_l3'])
self.R_Q_D_p = fluid.layers.cos_sim(query_fc, doc_pos_fc)
if is_infer:
return
R_Q_D_ns = []
for i, doc_neg in enumerate(self.doc_negs):
doc_neg_fc_i = fc(doc_neg, hidden_layers, hidden_acts, ['doc_neg_l1_' + str(i), 'doc_neg_l2_' + str(i), 'doc_neg_l3_' + str(i)])
R_Q_D_ns.append(fluid.layers.cos_sim(query_fc, doc_neg_fc_i))
concat_Rs = fluid.layers.concat(input=[self.R_Q_D_p] + R_Q_D_ns, axis=-1)
prob = fluid.layers.softmax(concat_Rs, axis=1)
hit_prob = fluid.layers.slice(prob, axes=[0,1], starts=[0,0], ends=[4, 1])
loss = -fluid.layers.reduce_sum(fluid.layers.log(hit_prob))
self.avg_cost = fluid.layers.mean(x=loss)
def infer_results(self):
self._infer_results['query_doc_sim'] = self.R_Q_D_p
def avg_loss(self):
self._cost = self.avg_cost
def metrics(self):
self._metrics["LOSS"] = self.avg_cost
def train_net(self):
self.input()
self.net(is_infer=False)
self.avg_loss()
self.metrics()
def optimizer(self):
learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace)
optimizer = fluid.optimizer.SGD(learning_rate)
return optimizer
def infer_input(self):
TRIGRAM_D = envs.get_global_env("hyper_parameters.TRIGRAM_D", None, self._namespace)
self.query = fluid.data(name="query", shape=[-1, TRIGRAM_D], dtype='float32', lod_level=0)
self.doc_pos = fluid.data(name="doc_pos", shape=[-1, TRIGRAM_D], dtype='float32', lod_level=0)
self._infer_data_var = [self.query, self.doc_pos]
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False)
def infer_net(self):
self.infer_input()
self.net(is_infer=True)
self.infer_results()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class EvaluateReader(Reader):
def init(self):
pass
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
query = map(float, features[0].split(','))
pos_doc = map(float, features[1].split(','))
feature_names = ['query', 'doc_pos']
yield zip(feature_names, [query] + [pos_doc])
return reader
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class TrainReader(Reader):
def init(self):
pass
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
query = map(float, features[0].split(','))
pos_doc = map(float, features[1].split(','))
feature_names = ['query', 'doc_pos']
neg_docs = []
for i in range(len(features) - 2):
feature_names.append('doc_neg_' + str(i))
neg_docs.append(map(float, features[i+2].split(',')))
yield zip(feature_names, [query] + [pos_doc] + neg_docs)
return reader
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
workspace: "fleetrec.models.recall.word2vec"
reader:
batch_size: 50
class: "{workspace}/w2v_evaluate_reader.py"
test_data_path: "{workspace}/data/test"
word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"
train:
trainer:
# for cluster training
strategy: "async"
epochs: 2
workspace: "fleetrec.models.recall.word2vec"
reader:
batch_size: 100
class: "{workspace}/w2v_reader.py"
train_data_path: "{workspace}/data/train"
word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt"
model:
models: "{workspace}/model.py"
hyper_parameters:
sparse_feature_number: 85
sparse_feature_dim: 300
with_shuffle_batch: False
neg_num: 5
window_size: 5
learning_rate: 1.0
decay_steps: 100000
decay_rate: 0.999
optimizer: sgd
save:
increment:
dirname: "increment"
epoch_interval: 1
save_last: True
inference:
dirname: "inference"
epoch_interval: 1
save_last: True
<UNK> 2541
the 256
to 135
of 122
a 106
in 97
and 94
that 54
for 49
is 47
on 44
s 43
at 37
said 34
be 31
with 27
will 26
are 25
have 24
was 23
it 22
more 20
who 20
an 19
as 19
by 18
his 18
from 18
they 17
not 16
their 16
has 15
there 15
this 15
but 15
we 13
he 13
been 12
out 12
new 11
would 11
than 11
were 11
year 10
or 10
us 10
had 9
first 9
all 9
two 9
after 8
them 8
t 8
most 8
last 8
some 8
so 8
i 8
even 7
when 7
according 7
its 7
during 7
per 7
because 7
up 7
she 7
home 7
about 7
mr 6
do 6
if 6
just 6
no 6
time 6
team 6
may 6
years 6
city 6
only 6
world 6
you 6
including 6
day 6
cent 6
and 6
all 48
because 64
just 72
per 63
when 59
is 9
year 43
some 55
it 20
an 23
as 24
including 82
at 12
have 18
in 5
home 67
its 61
<UNK> 0
even 58
city 78
said 13
from 27
for 8
their 30
there 32
had 46
two 49
been 37
than 41
up 65
to 2
only 79
time 74
new 39
you 81
has 31
was 19
day 83
more 21
be 14
we 35
his 26
may 76
do 70
that 7
mr 69
she 66
team 75
who 22
but 34
if 71
most 53
cent 84
them 51
they 28
not 29
during 62
years 77
with 15
by 25
after 50
he 36
a 4
on 10
about 68
last 54
would 40
world 80
this 33
of 3
no 73
according 60
us 45
will 16
i 57
s 11
so 56
t 52
were 42
the 1
first 47
out 38
or 44
are 17
Athens Greece Baghdad Iraq
Athens Greece Bangkok Thailand
Athens Greece Beijing China
Athens Greece Berlin Germany
Athens Greece Bern Switzerland
Athens Greece Cairo Egypt
Athens Greece Canberra Australia
Athens Greece Hanoi Vietnam
Athens Greece Havana Cuba
Athens Greece Helsinki Finland
Athens Greece Islamabad Pakistan
Athens Greece Kabul Afghanistan
Athens Greece London England
Athens Greece Madrid Spain
Athens Greece Moscow Russia
Athens Greece Oslo Norway
Athens Greece Ottawa Canada
Athens Greece Paris France
Athens Greece Rome Italy
Athens Greece Stockholm Sweden
Athens Greece Tehran Iran
Athens Greece Tokyo Japan
Baghdad Iraq Bangkok Thailand
Baghdad Iraq Beijing China
Baghdad Iraq Berlin Germany
Baghdad Iraq Bern Switzerland
Baghdad Iraq Cairo Egypt
Baghdad Iraq Canberra Australia
Baghdad Iraq Hanoi Vietnam
Baghdad Iraq Havana Cuba
Baghdad Iraq Helsinki Finland
Baghdad Iraq Islamabad Pakistan
Baghdad Iraq Kabul Afghanistan
Baghdad Iraq London England
Baghdad Iraq Madrid Spain
Baghdad Iraq Moscow Russia
Baghdad Iraq Oslo Norway
Baghdad Iraq Ottawa Canada
Baghdad Iraq Paris France
Baghdad Iraq Rome Italy
Baghdad Iraq Stockholm Sweden
Baghdad Iraq Tehran Iran
Baghdad Iraq Tokyo Japan
Baghdad Iraq Athens Greece
Bangkok Thailand Beijing China
Bangkok Thailand Berlin Germany
Bangkok Thailand Bern Switzerland
Bangkok Thailand Cairo Egypt
Bangkok Thailand Canberra Australia
Bangkok Thailand Hanoi Vietnam
Bangkok Thailand Havana Cuba
Bangkok Thailand Helsinki Finland
Bangkok Thailand Islamabad Pakistan
Bangkok Thailand Kabul Afghanistan
Bangkok Thailand London England
Bangkok Thailand Madrid Spain
Bangkok Thailand Moscow Russia
Bangkok Thailand Oslo Norway
Bangkok Thailand Ottawa Canada
Bangkok Thailand Paris France
Bangkok Thailand Rome Italy
Bangkok Thailand Stockholm Sweden
Bangkok Thailand Tehran Iran
Bangkok Thailand Tokyo Japan
Bangkok Thailand Athens Greece
Bangkok Thailand Baghdad Iraq
Beijing China Berlin Germany
Beijing China Bern Switzerland
Beijing China Cairo Egypt
Beijing China Canberra Australia
Beijing China Hanoi Vietnam
Beijing China Havana Cuba
Beijing China Helsinki Finland
Beijing China Islamabad Pakistan
Beijing China Kabul Afghanistan
Beijing China London England
Beijing China Madrid Spain
Beijing China Moscow Russia
Beijing China Oslo Norway
Beijing China Ottawa Canada
Beijing China Paris France
Beijing China Rome Italy
Beijing China Stockholm Sweden
Beijing China Tehran Iran
Beijing China Tokyo Japan
Beijing China Athens Greece
Beijing China Baghdad Iraq
Beijing China Bangkok Thailand
Berlin Germany Bern Switzerland
Berlin Germany Cairo Egypt
Berlin Germany Canberra Australia
Berlin Germany Hanoi Vietnam
Berlin Germany Havana Cuba
Berlin Germany Helsinki Finland
Berlin Germany Islamabad Pakistan
Berlin Germany Kabul Afghanistan
Berlin Germany London England
Berlin Germany Madrid Spain
Berlin Germany Moscow Russia
Berlin Germany Oslo Norway
Berlin Germany Ottawa Canada
Berlin Germany Paris France
Berlin Germany Rome Italy
Berlin Germany Stockholm Sweden
Berlin Germany Tehran Iran
Berlin Germany Tokyo Japan
Berlin Germany Athens Greece
Berlin Germany Baghdad Iraq
Berlin Germany Bangkok Thailand
Berlin Germany Beijing China
Bern Switzerland Cairo Egypt
Bern Switzerland Canberra Australia
Bern Switzerland Hanoi Vietnam
Bern Switzerland Havana Cuba
Bern Switzerland Helsinki Finland
Bern Switzerland Islamabad Pakistan
Bern Switzerland Kabul Afghanistan
Bern Switzerland London England
Bern Switzerland Madrid Spain
Bern Switzerland Moscow Russia
Bern Switzerland Oslo Norway
Bern Switzerland Ottawa Canada
Bern Switzerland Paris France
Bern Switzerland Rome Italy
Bern Switzerland Stockholm Sweden
Bern Switzerland Tehran Iran
Bern Switzerland Tokyo Japan
Bern Switzerland Athens Greece
Bern Switzerland Baghdad Iraq
Bern Switzerland Bangkok Thailand
Bern Switzerland Beijing China
Bern Switzerland Berlin Germany
Cairo Egypt Canberra Australia
Cairo Egypt Hanoi Vietnam
Cairo Egypt Havana Cuba
Cairo Egypt Helsinki Finland
Cairo Egypt Islamabad Pakistan
Cairo Egypt Kabul Afghanistan
Cairo Egypt London England
Cairo Egypt Madrid Spain
Cairo Egypt Moscow Russia
Cairo Egypt Oslo Norway
Cairo Egypt Ottawa Canada
Cairo Egypt Paris France
Cairo Egypt Rome Italy
Cairo Egypt Stockholm Sweden
Cairo Egypt Tehran Iran
Cairo Egypt Tokyo Japan
Cairo Egypt Athens Greece
Cairo Egypt Baghdad Iraq
Cairo Egypt Bangkok Thailand
Cairo Egypt Beijing China
Cairo Egypt Berlin Germany
Cairo Egypt Bern Switzerland
Canberra Australia Hanoi Vietnam
Canberra Australia Havana Cuba
Canberra Australia Helsinki Finland
Canberra Australia Islamabad Pakistan
Canberra Australia Kabul Afghanistan
Canberra Australia London England
Canberra Australia Madrid Spain
Canberra Australia Moscow Russia
Canberra Australia Oslo Norway
Canberra Australia Ottawa Canada
Canberra Australia Paris France
Canberra Australia Rome Italy
Canberra Australia Stockholm Sweden
Canberra Australia Tehran Iran
Canberra Australia Tokyo Japan
Canberra Australia Athens Greece
Canberra Australia Baghdad Iraq
Canberra Australia Bangkok Thailand
Canberra Australia Beijing China
Canberra Australia Berlin Germany
Canberra Australia Bern Switzerland
Canberra Australia Cairo Egypt
Hanoi Vietnam Havana Cuba
Hanoi Vietnam Helsinki Finland
Hanoi Vietnam Islamabad Pakistan
Hanoi Vietnam Kabul Afghanistan
Hanoi Vietnam London England
Hanoi Vietnam Madrid Spain
Hanoi Vietnam Moscow Russia
Hanoi Vietnam Oslo Norway
Hanoi Vietnam Ottawa Canada
Hanoi Vietnam Paris France
Hanoi Vietnam Rome Italy
Hanoi Vietnam Stockholm Sweden
Hanoi Vietnam Tehran Iran
Hanoi Vietnam Tokyo Japan
Hanoi Vietnam Athens Greece
Hanoi Vietnam Baghdad Iraq
Hanoi Vietnam Bangkok Thailand
Hanoi Vietnam Beijing China
Hanoi Vietnam Berlin Germany
Hanoi Vietnam Bern Switzerland
Hanoi Vietnam Cairo Egypt
Hanoi Vietnam Canberra Australia
Havana Cuba Helsinki Finland
Havana Cuba Islamabad Pakistan
45 8 71 53 83 58 71 28 46 3
59 68 5 82 0 81
61
52
80 2 4
18
0 45 10 10 0 8 45 5 0 10
16 16 14 10
71 73 23 32 16 0 49 53
67 6 26 5
18 37 30 65
16 75 30
1 42 25
54 43 0 6 0 10 0 66
20 13 7 49 5 46 37 0
32 1 40 55 74
16 14 3
76 29 14 3 44 13 42 44 34 3
4 80
32 37 0 3 0 22 6 8 3 62
13 75 9 6 65 79 8 24 0 24 6 73
81 0
79 7 40 14 5 6
58 56 38
23 14 6 2 51
12 24 6
18 37 55
0 14 43 50
52 53
22 19 11 6 6 41
20 68 7 66 59 66 31 48
31 2 70 15 24 24 44 72 68 14 27 6
2
28 10
35 51 6 0 64 17 4 21 13 0 11
9 33 43
26 4
4
69 29 4 8
0 76 46 0 51 30 34 20 79 22
1 49 9 25 0 25 78
10
81 57 81 72
8 34 31 29 37
38 13
9 5 6 39 54 43
81 70 18 2 53
55 7 44 21 30 0 60
19 23 3 0 39 82
28 56 27 4 38 55 2
41 17 0 43 6 21 41 27
70 29 59
5 36 36 31 26 17 8 39 78
28 64 11 8 21 41 11 16 7 16 20
8
13 40 61 68
9
57 40 72 7 71 29 2 22 29 38 1 30
0 3
39 0 4 5 39 21 41 5 54 45
22 7 1 1 0 0
46 0 0 20 40 29 3
11 0 78 4 15 82 51
0 2 33
0 21 41 19 29 2 59 36
27 3 14 0
32 63 84 63 84 3 63 84 0 63 84
36 13
13 15
36 57 35 34 54 0
13 22 31 5
3 78 2 2
27 11 57 20 20 11
67
28 70 44 58 0 28 17 7 17 29
53 11 62 17 6 17 12 30
32 81
80 0 35 22 19 6 35 51
55 33 76 0 9 0
0
56
52
42 62 0
50 1 34 38 0 58 21
54 62 0 10
13 1 42 25 4 3 3 0 0
25 26 9
28 18 39
4 49 77 32 49 33
13 0
6 11 56 52 10
15 12 74 1 8 45 44 8 0 14 12 6 12 9 8 45 0 44 76 4 3 12 11 0
35 48 23 1 0
8 5 54 15 5 1
20 38 0 48 7 30 0 17 29 32 76 14
8 46 37
64 53 0 0 24 0 13 6
0 52 0 1 3 0
55 1 43 24 34 24 71 28 42 1
83 15 57 46 24 3 40 14
61 47 23 1
31 0 26 24 25 36 16 27 12 11 33 25 43
20 34 57 52 2 70 56 7 57 52
44
62 26 69 8
74 1 6 51 33 74
49 0 22 1 0 17 32 14 21 22 3 45 26 10
5 78
64 35 18 75 5 80 0 24 53 26 0
48 83
79 61
60 1 23 9 10
50 3 1 3 24 0
1 47 27 30 67 4 83 61
32 15 69 36 19 6 7 42
34 47 33 68
63 16 38 11 67
1 50 4
65 27 78
27 48 39 16 14 76
13 0 42 34 36 20 19 33
7 19 31 37 25
5 42 64
4 42 23 8
77 50 4 31 5
9 14 5 0 3 27
19 27 1 40 2 1 77 40 29 14 2 1 25 69
33 73 7 18 25 35 29 14 58 0 0 35 70
23 6 4
53 3 0 46 4 74 58 42 1
35 27 77 8 4 77 0
0 17 48 0 0 6 22
19 0 2 43 59 0 61
20 71 79 20 14 41 1
37 73 65
9 3
0 5 10
0 6 42 8 47 74
23 9 18 62 23
47
39 50 0 50 26
69 26 66 38 14 72 15 1
6 21 33 65 24 9 2
60
16 25 22 16 15 0
18 37 0 28 50
40 75 5
36 66 11 38 0 3 36
5 26 59 66 0
45 10 6 7 31 21 41 27 4
72 30 10 4 0 83 2 30 47 67 33
17 6 64 29 0
0 30 38 12 5 18 4 0
60 83 3 55 3
0 4 0 33 43 80 8 75
5 77 0 22 30 21 41 27
36 19 3 0 82 49
6 32 17
0 10 0 62 8 82
54 11 38 4 2 19 7 35 18 39 0 16 14
37 0 47 75
61 0 58 1 48 33 32 10
47 10 73 47
17 34 2
7 56 28 0 2
39 23 15 15 6 13
9 15 0 13 45 2 14 15
0 11 0 0 72 11 13 5 26 3 0
0 19 38 12 1 3
67 12
36 26 0 5
56 60 18 37 1 44
11 13 11 40 12
19 56
57 0 22 40 35 0 51 6 28 28 13
73
34 22 65 64 28 52 44
13 1 25 63 84 6 7 12 41 63 84
69 46 4 0
17 3
0 3 13 55 3 26 46
2 2 21 7 67
45 34 0 14 21 60 2
80 11 18 34 29 60 4 14
48
27
21 41 0 66 34
54 43 0 0
79 68 13 23 5 51 8
0 49 31 23 4
59 20 48 35 16 5 8
22 0 8 26 49 39 10
37 4 24 0 5 6 65 68 11 0
11 0 2 25 7
3 82 18 0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from fleetrec.core.utils import envs
from fleetrec.core.model import Model as ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def input(self):
neg_num = int(envs.get_global_env("hyper_parameters.neg_num", None, self._namespace))
self.input_word = fluid.data(name="input_word", shape=[None, 1], dtype='int64')
self.true_word = fluid.data(name='true_label', shape=[None, 1], dtype='int64')
self._data_var.append(self.input_word)
self._data_var.append(self.true_word)
with_shuffle_batch = bool(int(envs.get_global_env("hyper_parameters.with_shuffle_batch", None, self._namespace)))
if not with_shuffle_batch:
self.neg_word = fluid.data(name="neg_label", shape=[None, neg_num], dtype='int64')
self._data_var.append(self.neg_word)
if self._platform != "LINUX":
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False)
def net(self):
is_distributed = True if envs.get_trainer() == "CtrTrainer" else False
neg_num = int(envs.get_global_env("hyper_parameters.neg_num", None, self._namespace))
sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace)
with_shuffle_batch = bool(int(envs.get_global_env("hyper_parameters.with_shuffle_batch", None, self._namespace)))
def embedding_layer(input, table_name, emb_dim, initializer_instance=None, squeeze=False):
emb = fluid.embedding(
input=input,
is_sparse=True,
is_distributed=is_distributed,
size=[sparse_feature_number, emb_dim],
param_attr=fluid.ParamAttr(
name=table_name,
initializer=initializer_instance),
)
if squeeze:
return fluid.layers.squeeze(input=emb, axes=[1])
else:
return emb
init_width = 0.5 / sparse_feature_dim
emb_initializer = fluid.initializer.Uniform(-init_width, init_width)
emb_w_initializer = fluid.initializer.Constant(value=0.0)
input_emb = embedding_layer(self.input_word, "emb", sparse_feature_dim, emb_initializer, True)
true_emb_w = embedding_layer(self.true_word, "emb_w", sparse_feature_dim, emb_w_initializer, True)
true_emb_b = embedding_layer(self.true_word, "emb_b", 1, emb_w_initializer, True)
if with_shuffle_batch:
neg_emb_w_list = []
for i in range(neg_num):
neg_emb_w_list.append(fluid.contrib.layers.shuffle_batch(true_emb_w)) # shuffle true_word
neg_emb_w_concat = fluid.layers.concat(neg_emb_w_list, axis=0)
neg_emb_w = fluid.layers.reshape(neg_emb_w_concat, shape=[-1, neg_num, sparse_feature_dim])
neg_emb_b_list = []
for i in range(neg_num):
neg_emb_b_list.append(fluid.contrib.layers.shuffle_batch(true_emb_b)) # shuffle true_word
neg_emb_b = fluid.layers.concat(neg_emb_b_list, axis=0)
neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])
else:
neg_emb_w = embedding_layer(self.neg_word, "emb_w", sparse_feature_dim, emb_w_initializer)
neg_emb_b = embedding_layer(self.neg_word, "emb_b", 1, emb_w_initializer)
neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])
true_logits = fluid.layers.elementwise_add(
fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(input_emb, true_emb_w),
dim=1,
keep_dim=True),
true_emb_b)
input_emb_re = fluid.layers.reshape(
input_emb, shape=[-1, 1, sparse_feature_dim])
neg_matmul = fluid.layers.matmul(input_emb_re, neg_emb_w, transpose_y=True)
neg_logits = fluid.layers.elementwise_add(
fluid.layers.reshape(neg_matmul, shape=[-1, neg_num]),
neg_emb_b_vec)
label_ones = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, 1], value=1.0, dtype='float32')
label_zeros = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')
true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
label_ones)
neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
label_zeros)
cost = fluid.layers.elementwise_add(
fluid.layers.reduce_sum(
true_xent, dim=1),
fluid.layers.reduce_sum(
neg_xent, dim=1))
self.avg_cost = fluid.layers.reduce_mean(cost)
global_right_cnt = fluid.layers.create_global_var(name="global_right_cnt", persistable=True, dtype='float32', shape=[1], value=0)
global_total_cnt = fluid.layers.create_global_var(name="global_total_cnt", persistable=True, dtype='float32', shape=[1], value=0)
global_right_cnt.stop_gradient = True
global_total_cnt.stop_gradient = True
def avg_loss(self):
self._cost = self.avg_cost
def metrics(self):
self._metrics["LOSS"] = self.avg_cost
def train_net(self):
self.input()
self.net()
self.avg_loss()
self.metrics()
def optimizer(self):
learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace)
decay_steps = envs.get_global_env("hyper_parameters.decay_steps", None, self._namespace)
decay_rate = envs.get_global_env("hyper_parameters.decay_rate", None, self._namespace)
optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=learning_rate,
decay_steps=decay_steps,
decay_rate=decay_rate,
staircase=True))
return optimizer
def analogy_input(self):
sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
self.analogy_a = fluid.data(name="analogy_a", shape=[None], dtype='int64')
self.analogy_b = fluid.data(name="analogy_b", shape=[None], dtype='int64')
self.analogy_c = fluid.data(name="analogy_c", shape=[None], dtype='int64')
self.analogy_d = fluid.data(name="analogy_d", shape=[None], dtype='int64')
self._infer_data_var = [self.analogy_a, self.analogy_b, self.analogy_c, self.analogy_d]
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False)
def infer_net(self):
sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace)
sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
def embedding_layer(input, table_name, initializer_instance=None):
emb = fluid.embedding(
input=input,
size=[sparse_feature_number, sparse_feature_dim],
param_attr=table_name)
return emb
self.analogy_input()
all_label = np.arange(sparse_feature_number).reshape(sparse_feature_number).astype('int32')
self.all_label = fluid.layers.cast(x=fluid.layers.assign(all_label), dtype='int64')
emb_all_label = embedding_layer(self.all_label, "emb")
emb_a = embedding_layer(self.analogy_a, "emb")
emb_b = embedding_layer(self.analogy_b, "emb")
emb_c = embedding_layer(self.analogy_c, "emb")
target = fluid.layers.elementwise_add(
fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
values, pred_idx = fluid.layers.topk(input=dist, k=4)
label = fluid.layers.expand(fluid.layers.unsqueeze(self.analogy_d, axes=[1]), expand_times=[1, 4])
label_ones = fluid.layers.fill_constant_batch_size_like(
label, shape=[-1, 1], value=1.0, dtype='float32')
right_cnt = fluid.layers.reduce_sum(
input=fluid.layers.cast(fluid.layers.equal(pred_idx, label), dtype='float32'))
total_cnt = fluid.layers.reduce_sum(label_ones)
global_right_cnt = fluid.layers.create_global_var(name="global_right_cnt", persistable=True, dtype='float32', shape=[1], value=0)
global_total_cnt = fluid.layers.create_global_var(name="global_total_cnt", persistable=True, dtype='float32', shape=[1], value=0)
global_right_cnt.stop_gradient = True
global_total_cnt.stop_gradient = True
tmp1 = fluid.layers.elementwise_add(right_cnt, global_right_cnt)
fluid.layers.assign(tmp1, global_right_cnt)
tmp2 = fluid.layers.elementwise_add(total_cnt, global_total_cnt)
fluid.layers.assign(tmp2, global_total_cnt)
acc = fluid.layers.elementwise_div(global_right_cnt, global_total_cnt, name="total_acc")
self._infer_results['acc'] = acc
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import io
import six
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class EvaluateReader(Reader):
def init(self):
dict_path = envs.get_global_env("word_id_dict_path", None, "evaluate.reader")
self.word_to_id = dict()
self.id_to_word = dict()
with io.open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
self.word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
self.id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
self.dict_size = len(self.word_to_id)
def native_to_unicode(self, s):
if self._is_unicode(s):
return s
try:
return self._to_unicode(s)
except UnicodeDecodeError:
res = self._to_unicode(s, ignore_errors=True)
return res
def _is_unicode(self, s):
if six.PY2:
if isinstance(s, unicode):
return True
else:
if isinstance(s, str):
return True
return False
def _to_unicode(self, s, ignore_errors=False):
if self._is_unicode(s):
return s
error_mode = "ignore" if ignore_errors else "strict"
return s.decode("utf-8", errors=error_mode)
def strip_lines(self, line, vocab):
return self._replace_oov(vocab, self.native_to_unicode(line))
def _replace_oov(self, original_vocab, line):
"""Replace out-of-vocab words with "<UNK>".
This maintains compatibility with published results.
Args:
original_vocab: a set of strings (The standard vocabulary for the dataset)
line: a unicode string - a space-delimited sequence of words.
Returns:
a unicode string - a space-delimited sequence of words.
"""
return u" ".join([
word if word in original_vocab else u"<UNK>" for word in line.split()
])
def generate_sample(self, line):
def reader():
features = self.strip_lines(line.lower(), self.word_to_id)
features = features.split()
yield [('analogy_a', [self.word_to_id[features[0]]]), ('analogy_b', [self.word_to_id[features[1]]]), ('analogy_c', [self.word_to_id[features[2]]]), ('analogy_d', [self.word_to_id[features[3]]])]
return reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import io
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class NumpyRandomInt(object):
def __init__(self, a, b, buf_size=1000):
self.idx = 0
self.buffer = np.random.random_integers(a, b, buf_size)
self.a = a
self.b = b
def __call__(self):
if self.idx == len(self.buffer):
self.buffer = np.random.random_integers(self.a, self.b,
len(self.buffer))
self.idx = 0
result = self.buffer[self.idx]
self.idx += 1
return result
class TrainReader(Reader):
def init(self):
dict_path = envs.get_global_env("word_count_dict_path", None, "train.reader")
self.window_size = envs.get_global_env("hyper_parameters.window_size", None, "train.model")
self.neg_num = envs.get_global_env("hyper_parameters.neg_num", None, "train.model")
self.with_shuffle_batch = envs.get_global_env("hyper_parameters.with_shuffle_batch", None, "train.model")
self.random_generator = NumpyRandomInt(1, self.window_size + 1)
self.cs = None
if not self.with_shuffle_batch:
id_counts = []
word_all_count = 0
with io.open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, count = line.split()[0], int(line.split()[1])
id_counts.append(count)
word_all_count += count
id_frequencys = [
float(count) / word_all_count for count in id_counts
]
np_power = np.power(np.array(id_frequencys), 0.75)
id_frequencys_pow = np_power / np_power.sum()
self.cs = np.array(id_frequencys_pow).cumsum()
def get_context_words(self, words, idx):
"""
Get the context word list of target word.
words: the words of the current line
idx: input word index
window_size: window size
"""
target_window = self.random_generator()
start_point = idx - target_window # if (idx - target_window) > 0 else 0
if start_point < 0:
start_point = 0
end_point = idx + target_window
targets = words[start_point:idx] + words[idx + 1:end_point + 1]
return targets
def generate_sample(self, line):
def reader():
word_ids = [w for w in line.split()]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx)
for context_id in context_word_ids:
output = [('input_word', [int(target_id)]), ('true_label', [int(context_id)])]
if not self.with_shuffle_batch:
neg_array = self.cs.searchsorted(np.random.sample(self.neg_num))
output += [('neg_label', [int(str(i)) for i in neg_array ])]
yield output
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册