提交 65909976 编写于 作者: T tangwei

add windows/mac support

上级 5d92ae8f
......@@ -12,6 +12,7 @@ class Model(object):
self._cost = None
self._metrics = {}
self._data_var = []
self._data_loader = None
self._fetch_interval = 20
self._namespace = "train.model"
......
......@@ -39,7 +39,11 @@ class ClusterTrainer(TranspileTrainer):
else:
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('train_pass', self.train)
if envs.get_platform() == "LINUX":
self.regist_context_processor('train_pass', self.dataset_train)
else:
self.regist_context_processor('train_pass', self.dataloader_train)
self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self):
......@@ -87,7 +91,10 @@ class ClusterTrainer(TranspileTrainer):
fleet.run_server()
context['is_exit'] = True
def train(self, context):
def dataloader_train(self, context):
    """Data-loader training step for the cluster trainer; currently a no-op.

    Args:
        context: Trainer context object; it is neither read nor modified.

    Returns:
        None.
    """
    return None
def dataset_train(self, context):
self._exe.run(fleet.startup_program)
fleet.init_worker()
......
......@@ -57,7 +57,7 @@ class CtrPaddleTrainer(Trainer):
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'reader_instance.py')
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
......
......@@ -22,6 +22,7 @@ import paddle.fluid as fluid
from fleetrec.core.trainers.transpiler_trainer import TranspileTrainer
from fleetrec.core.utils import envs
import numpy as np
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
......@@ -32,7 +33,12 @@ class SingleTrainer(TranspileTrainer):
def processor_register(self):
    """Register the processor that handles each trainer status.

    The 'train_pass' processor is chosen by platform: `dataset_train`
    when `envs.get_platform()` reports "LINUX", `dataloader_train`
    otherwise.
    """
    self.regist_context_processor('uninit', self.instance)
    self.regist_context_processor('init_pass', self.init)

    # Register exactly one 'train_pass' processor. The former `self.train`
    # registration is dropped because `train` was replaced by the two
    # platform-specific methods below.
    if envs.get_platform() == "LINUX":
        train_processor = self.dataset_train
    else:
        train_processor = self.dataloader_train
    self.regist_context_processor('train_pass', train_processor)

    self.regist_context_processor('infer_pass', self.infer)
    self.regist_context_processor('terminal_pass', self.terminal)
......@@ -51,12 +57,51 @@ class SingleTrainer(TranspileTrainer):
self.fetch_alias = metrics.keys()
context['status'] = 'train_pass'
def train(self, context):
# run startup program at once
def dataloader_train(self, context):
    """Train through the model's data loader and print metrics periodically.

    Runs the startup program once, then for every epoch starts the data
    loader and executes the compiled main program until the loader raises
    `fluid.core.EOFException`. Every 10th batch (except batch 0) a line
    with the epoch, the batch id and each metric is printed.

    Args:
        context: Trainer context; its 'status' entry is set to
            'infer_pass' once all epochs have finished.
    """
    # Run the startup program once before training.
    self._exe.run(fluid.default_startup_program())

    reader = self._get_dataloader()
    epochs = envs.get_global_env("train.epochs")

    # NOTE(review): `get_cost_op` is accessed as an attribute here; confirm
    # it is not a method that has to be called to obtain the cost variable.
    program = fluid.compiler.CompiledProgram(
        fluid.default_main_program()).with_data_parallel(
            loss_name=self.model.get_cost_op.name)

    metrics_varnames = []
    metrics_format = [
        "{}: {{}}".format("epoch"),
        "{}: {{}}".format("batch"),
    ]
    for name, var in self.model.get_metrics().items():
        # Fetch every metric so the printed line has one value per
        # placeholder. Assumes each metric value exposes `.name`, like the
        # cost variable above -- TODO confirm.
        metrics_varnames.append(var.name)
        metrics_format.append("{}: {{}}".format(name))
    metrics_format = ", ".join(metrics_format)

    for epoch in range(epochs):
        reader.start()
        batch_id = 0
        try:
            while True:
                metrics_rets = self._exe.run(
                    program=program,
                    fetch_list=metrics_varnames)

                metrics = [epoch, batch_id]
                # Reduce each fetched metric to one scalar on its own;
                # averaging over axis 0 of the whole list would mix
                # different metrics together.
                metrics.extend(float(np.mean(ret)) for ret in metrics_rets)

                if batch_id % 10 == 0 and batch_id != 0:
                    # Unpack: the template has one placeholder per value.
                    print(metrics_format.format(*metrics))
                batch_id += 1
        except fluid.core.EOFException:
            # The loader is exhausted for this epoch; reset it so it can
            # be started again.
            reader.reset()

    context['status'] = 'infer_pass'
def dataset_train(self, context):
# run startup program at once
self._exe.run(fluid.default_startup_program())
dataset = self._get_dataset()
epochs = envs.get_global_env("train.epochs")
for i in range(epochs):
......
......@@ -22,6 +22,7 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f
from fleetrec.core.trainer import Trainer
from fleetrec.core.utils import envs
from fleetrec.core.utils import dataloader_instance
class TranspileTrainer(Trainer):
......@@ -35,6 +36,16 @@ class TranspileTrainer(Trainer):
def processor_register(self):
print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")
def _get_dataloader(self):
    """Attach the configured training sample generator to the model's loader.

    Reads `batch_size` and `class` from the "train.reader" namespace,
    builds a sample generator with `dataloader_instance.dataloader` in
    "TRAIN" mode and hands it to the model's `_data_loader`.

    Returns:
        The model's `_data_loader` after the generator has been set.
    """
    reader_ns = "train.reader"
    loader = self.model._data_loader

    samples_per_batch = envs.get_global_env("batch_size", None, reader_ns)
    reader_cls = envs.get_global_env("class", None, reader_ns)

    sample_generator = dataloader_instance.dataloader(
        reader_cls, "TRAIN", self._config_yaml)
    loader.set_sample_generator(sample_generator, samples_per_batch)
    return loader
def _get_dataset(self):
namespace = "train.reader"
......@@ -43,7 +54,7 @@ class TranspileTrainer(Trainer):
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'reader_instance.py')
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
......@@ -123,7 +134,11 @@ class TranspileTrainer(Trainer):
print("Need to be implement")
context['is_exit'] = True
def train(self, context):
def dataloader_train(self, context):
    """Placeholder for data-loader based training.

    Prints a notice and flags the context for exit.

    Args:
        context: Mapping-like trainer context; 'is_exit' is set to True.
    """
    print("Need to be implement")
    context['is_exit'] = True
def dataset_train(self, context):
    """Placeholder for dataset based training.

    Prints a notice and flags the context for exit.

    Args:
        context: Mapping-like trainer context; 'is_exit' is set to True.
    """
    print("Need to be implement")
    context['is_exit'] = True
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
from fleetrec.core.utils.envs import lazy_instance
from fleetrec.core.utils.envs import get_global_env
def dataloader(readerclass, train, yaml_file):
    """Build a sample generator that reads every file of the data directory.

    Args:
        readerclass: Reader location forwarded to `lazy_instance`.
        train: "TRAIN" selects the `TrainReader` class and the
            `train_data_path` setting; any other value selects
            `EvaluateReader` and `test_data_path`.
        yaml_file: Config file path passed to the reader's constructor.

    Returns:
        A zero-argument generator function. For each non-None sample
        produced by the reader it yields a list holding the element at
        index 1 of every item in that sample.
    """
    namespace = "train.reader"

    if train == "TRAIN":
        reader_name = "TrainReader"
        data_path = get_global_env("train_data_path", None, namespace)
    else:
        reader_name = "EvaluateReader"
        data_path = get_global_env("test_data_path", None, namespace)

    # Join with os.path so the separator is right on every platform.
    base_dir = str(data_path)
    files = [os.path.join(base_dir, name) for name in os.listdir(data_path)]

    reader_class = lazy_instance(readerclass, reader_name)
    reader = reader_class(yaml_file)
    reader.init()

    def gen_reader():
        """Yield one list of values per parsed sample across all files."""
        for path in files:
            with open(path, 'r') as f:
                for line in f:
                    fields = line.rstrip('\n').split('\t')
                    # `generate_sample` returns a callable that produces
                    # the parsed samples for this line.
                    sample_iter = reader.generate_sample(fields)
                    for parsed_line in sample_iter():
                        if parsed_line is None:
                            continue
                        # Keep only the second element of each item
                        # (presumably (name, value) pairs -- TODO confirm).
                        yield [item[1] for item in parsed_line]

    return gen_reader
......@@ -57,6 +57,8 @@ class Model(ModelBase):
self._data_var.append(input)
self._data_var.append(self.label_input)
self._data_loader = fluid.io.PyReader(
feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False)
def net(self):
trainer = envs.get_trainer()
......
......@@ -85,6 +85,7 @@ def single_engine(args):
single_envs["train.trainer.threads"] = "2"
single_envs["train.trainer.engine"] = "single"
single_envs["train.trainer.device"] = args.device
single_envs["train.trainer.platform"] = envs.get_platform()
set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model)
......@@ -98,6 +99,7 @@ def cluster_engine(args):
cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
cluster_envs["train.trainer.engine"] = "cluster"
cluster_envs["train.trainer.device"] = args.device
cluster_envs["train.trainer.platform"] = envs.get_platform()
set_runtime_envs(cluster_envs, args.model)
......@@ -111,6 +113,7 @@ def cluster_mpi_engine(args):
cluster_envs = {}
cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["train.trainer.device"] = args.device
cluster_envs["train.trainer.platform"] = envs.get_platform()
set_runtime_envs(cluster_envs, args.model)
......@@ -131,7 +134,10 @@ def local_cluster_engine(args):
cluster_envs["train.trainer.strategy"] = "async"
cluster_envs["train.trainer.threads"] = "2"
cluster_envs["train.trainer.engine"] = "local_cluster"
cluster_envs["train.trainer.device"] = args.device
cluster_envs["train.trainer.platform"] = envs.get_platform()
cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.model)
......@@ -154,7 +160,9 @@ def local_mpi_engine(args):
cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer.engine"] = "local_cluster"
cluster_envs["train.trainer.device"] = args.device
cluster_envs["train.trainer.platform"] = envs.get_platform()
set_runtime_envs(cluster_envs, args.model)
launch = LocalMPIEngine(cluster_envs, args.model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册