提交 91dad673 编写于 作者: X xujiaqi01

add slot reader

上级 13e62b7f
......@@ -23,6 +23,35 @@ class Model(object):
self._fetch_interval = 20
self._namespace = "train.model"
self._platform = envs.get_platform()
self._init_slots()
def _init_slots(self):
sparse_slots = envs.get_global_env("sparse_slots", None, "train.reader")
dense_slots = envs.get_global_env("dense_slots", None, "train.reader")
if sparse_slots is not None or dense_slots is not None:
sparse_slots = sparse_slots.strip().split(" ")
dense_slots = dense_slots.strip().split(" ")
dense_slots_shape = [[int(j) for j in i.split(":")[1].strip("[]").split(",")] for i in dense_slots]
dense_slots = [i.split(":")[0] for i in dense_slots]
self._dense_data_var = []
for i in range(len(dense_slots)):
l = fluid.layers.data(name=dense_slots[i], shape=dense_slots_shape[i], dtype="float32")
self._data_var.append(l)
self._dense_data_var.append(l)
self._sparse_data_var = []
for name in sparse_slots:
l = fluid.layers.data(name=name, shape=[1], lod_level=1, dtype="int64")
self._data_var.append(l)
self._sparse_data_var.append(l)
dataset_class = envs.get_global_env("dataset_class", None, "train.reader")
if dataset_class == "DataLoader":
self._init_dataloader()
def _init_dataloader(self):
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False)
def get_inputs(self):
return self._data_var
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
import sys
import abc
import os
......@@ -44,3 +45,58 @@ class Reader(dg.MultiSlotDataGenerator):
@abc.abstractmethod
def generate_sample(self, line):
pass
class SlotReader(dg.MultiSlotDataGenerator):
__metaclass__ = abc.ABCMeta
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
envs.set_global_envs(_config)
envs.update_workspace()
def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul
self.sparse_slots = sparse_slots.strip().split(" ")
self.dense_slots = dense_slots.strip().split(" ")
self.dense_slots_shape = [reduce(mul, [int(j) for j in i.split(":")[1].strip("[]").split(",")]) for i in self.dense_slots]
self.dense_slots = [i.split(":")[0] for i in self.dense_slots]
self.slots = self.dense_slots + self.sparse_slots
self.slot2index = {}
self.visit = {}
for i in range(len(self.slots)):
self.slot2index[self.slots[i]] = i
self.visit[self.slots[i]] = False
self.padding = padding
def generate_sample(self, l):
def reader():
line = l.strip().split(" ")
output = [(i, []) for i in self.slots]
for i in line:
slot_feasign = i.split(":")
slot = slot_feasign[0]
if slot not in self.slots:
continue
if slot in self.sparse_slots:
feasign = int(slot_feasign[1])
else:
feasign = float(slot_feasign[1])
output[self.slot2index[slot]][1].append(feasign)
self.visit[slot] = True
for i in self.visit:
slot = i
if not self.visit[slot]:
if i in self.dense_slots:
output[self.slot2index[i]][1].extend([self.padding] * self.dense_slots_shape[self.slot2index[i]])
else:
output[self.slot2index[i]][1].extend([self.padding])
else:
self.visit[slot] = False
yield output
return reader
......@@ -23,6 +23,7 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f
from paddlerec.core.trainer import Trainer
from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader
class TranspileTrainer(Trainer):
......@@ -50,14 +51,23 @@ class TranspileTrainer(Trainer):
namespace = "evaluate.reader"
class_name = "EvaluateReader"
sparse_slots = envs.get_global_env("sparse_slots", None, namespace)
dense_slots = envs.get_global_env("dense_slots", None, namespace)
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
print("batch_size: {}".format(batch_size))
reader = dataloader_instance.dataloader(
reader_class, state, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
if sparse_slots is None and dense_slots is None:
reader_class = envs.get_global_env("class", None, namespace)
reader = dataloader_instance.dataloader(
reader_class, state, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, class_name)
reader_ins = reader_class(self._config_yaml)
else:
reader = dataloader_instance.slotdataloader("", state, self._config_yaml)
reader_ins = SlotReader(self._config_yaml)
if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
dataloader.set_sample_list_generator(reader)
else:
......@@ -93,13 +103,23 @@ class TranspileTrainer(Trainer):
train_data_path = envs.get_global_env(
"test_data_path", None, namespace)
sparse_slots = envs.get_global_env("sparse_slots", None, namespace)
dense_slots = envs.get_global_env("dense_slots", None, namespace)
threads = int(envs.get_runtime_environ("train.trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
pipe_cmd = "python {} {} {} {}".format(
reader, reader_class, state, self._config_yaml)
if sparse_slots is None and dense_slots is None:
pipe_cmd = "python {} {} {} {}".format(
reader, reader_class, state, self._config_yaml)
else:
padding = envs.get_global_env("padding", 0, namespace)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, namespace, \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
if train_data_path.startswith("paddlerec::"):
package_base = envs.get_runtime_environ("PACKAGE_BASE")
......@@ -147,9 +167,6 @@ class TranspileTrainer(Trainer):
if not need_save(epoch_id, save_interval, False):
return
# print("save inference model is not supported now.")
# return
feed_varnames = envs.get_global_env(
"save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env(
......@@ -218,6 +235,7 @@ class TranspileTrainer(Trainer):
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(infer_program, startup_program):
self.model._init_slots()
self.model.infer_net()
if self.model._infer_data_loader is None:
......
......@@ -19,6 +19,7 @@ import sys
from paddlerec.core.utils.envs import lazy_instance_by_fliename
from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader
def dataloader(readerclass, train, yaml_file):
......@@ -63,3 +64,49 @@ def dataloader(readerclass, train, yaml_file):
if hasattr(reader, 'generate_batch_from_trainfiles'):
return gen_batch_reader()
return gen_reader
def slotdataloader(readerclass, train, yaml_file):
if train == "TRAIN":
reader_name = "SlotReader"
namespace = "train.reader"
data_path = get_global_env("train_data_path", None, namespace)
else:
reader_name = "SlotReader"
namespace = "evaluate.reader"
data_path = get_global_env("test_data_path", None, namespace)
if data_path.startswith("paddlerec::"):
package_base = get_runtime_environ("PACKAGE_BASE")
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
sparse = get_global_env("sparse_slots", None, namespace)
dense = get_global_env("dense_slots", None, namespace)
padding = get_global_env("padding", 0, namespace)
reader = SlotReader(yaml_file)
reader.init(sparse, dense, int(padding))
def gen_reader():
for file in files:
with open(file, 'r') as f:
for line in f:
line = line.rstrip('\n')
iter = reader.generate_sample(line)
for parsed_line in iter():
if parsed_line is None:
continue
else:
values = []
for pased in parsed_line:
values.append(pased[1])
yield values
def gen_batch_reader():
return reader.generate_batch_from_trainfiles(files)
if hasattr(reader, 'generate_batch_from_trainfiles'):
return gen_batch_reader()
return gen_reader
......@@ -15,19 +15,33 @@ from __future__ import print_function
import sys
from paddlerec.core.utils.envs import lazy_instance_by_fliename
from paddlerec.core.reader import SlotReader
from paddlerec.core.utils import envs
if len(sys.argv) != 4:
raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path")
if len(sys.argv) < 4:
raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate/slotreader 3.yaml_abs_path")
reader_package = sys.argv[1]
if sys.argv[2] == "TRAIN":
if sys.argv[2].upper() == "TRAIN":
reader_name = "TrainReader"
else:
elif sys.argv[2].upper() == "EVALUATE":
reader_name = "EvaluateReader"
else:
reader_name = "SlotReader"
namespace = sys.argv[4]
sparse_slots = sys.argv[5].replace("#", " ")
dense_slots = sys.argv[6].replace("#", " ")
padding = int(sys.argv[7])
yaml_abs_path = sys.argv[3]
reader_class = lazy_instance_by_fliename(reader_package, reader_name)
reader = reader_class(yaml_abs_path)
reader.init()
reader.run_from_stdin()
if reader_name != "SlotReader":
reader_class = lazy_instance_by_fliename(reader_package, reader_name)
reader = reader_class(yaml_abs_path)
reader.init()
reader.run_from_stdin()
else:
reader = SlotReader(yaml_abs_path)
reader.init(sparse_slots, dense_slots, padding)
reader.run_from_stdin()
......@@ -88,18 +88,3 @@ python -m paddlerec.run -m paddlerec.models.contentunderstanding.classification
| ag news dataset | TagSpace | -- | -- | -- | -- |
| -- | Classification | -- | -- | -- | -- |
## 分布式
### 模型训练性能 (样本/s)
| 数据集 | 模型 | 单机 | 同步 (4节点) | 同步 (8节点) | 同步 (16节点) | 同步 (32节点) |
| :------------------: | :--------------------: | :---------: |:---------: |:---------: |:---------: |:---------: |
| -- | TagSpace | -- | -- | -- | -- | -- |
| -- | Classification | -- | -- | -- | -- | -- |
----
| 数据集 | 模型 | 单机 | 异步 (4节点) | 异步 (8节点) | 异步 (16节点) | 异步 (32节点) |
| :------------------: | :--------------------: | :---------: |:---------: |:---------: |:---------: |:---------: |
| -- | TagSpace | -- | -- | -- | -- | -- |
| -- | Classification | -- | -- | -- | -- | -- |
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册