From 8a9c5d2476791edea3a06a3db054a3294d5de2a9 Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 1 Apr 2020 17:50:58 +0800 Subject: [PATCH] add cluster training --- examples/train.py | 25 ++++++++++++ trainer/factory.py | 85 +++++++++++++++++++++++++++++++++++++++++ trainer/single_train.py | 5 +-- trainer/trainer.py | 15 +------- 4 files changed, 113 insertions(+), 17 deletions(-) create mode 100644 trainer/factory.py diff --git a/examples/train.py b/examples/train.py index 7c0f34df..46831437 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,3 +1,28 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import yaml from .. trainer.factory import TrainerFactory diff --git a/trainer/factory.py b/trainer/factory.py new file mode 100644 index 00000000..b96d2443 --- /dev/null +++ b/trainer/factory.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import yaml + +from .single_train import SingleTrainerWithDataloader +from .single_train import SingleTrainerWithDataset + +from .cluster_train import ClusterTrainerWithDataloader +from .cluster_train import ClusterTrainerWithDataset + +from .ctr_trainer import CtrPaddleTrainer + +from ..utils import envs + + +class TrainerFactory(object): + def __init__(self): + pass + + @staticmethod + def _build_trainer(config): + train_mode = envs.get_global_env("train.trainer") + reader_mode = envs.get_global_env("train.reader.mode") + if train_mode == "SingleTraining": + if reader_mode == "dataset": + trainer = SingleTrainerWithDataset() + elif reader_mode == "dataloader": + trainer = SingleTrainerWithDataloader() + else: + raise ValueError("reader only support dataset/dataloader") + elif train_mode == "ClusterTraining": + if reader_mode == "dataset": + trainer = ClusterTrainerWithDataset() + elif reader_mode == "dataloader": + trainer = ClusterTrainerWithDataloader() + else: + raise ValueError("reader only support dataset/dataloader") + elif train_mode == "CtrTrainer": + trainer = CtrPaddleTrainer(config) + else: + raise ValueError("trainer only support SingleTraining/ClusterTraining") + + return trainer + + @staticmethod + def create(config): + _config = None + if isinstance(config, dict): + _config = config + elif isinstance(config, str): + if os.path.exists(config) and os.path.isfile(config): + with open(config, 'r') as rb: + _config = yaml.load(rb.read()) + else: + raise ValueError("unknown config about eleps") + + envs.set_global_envs(_config) + trainer = TrainerFactory._build_trainer(_config) + + return trainer diff --git a/trainer/single_train.py b/trainer/single_train.py index 9422dd5a..0865b419 100644 --- a/trainer/single_train.py +++ b/trainer/single_train.py @@ -39,9 +39,8 @@ def need_save(epoch_id, epoch_interval, is_last=False): class SingleTrainer(Trainer): - - def __init__(self, config=None, yaml_file=None): - Trainer.__init__(self, config, yaml_file) + def __init__(self, config=None): + Trainer.__init__(self, config) self.exe = fluid.Executor(fluid.CPUPlace()) diff --git a/trainer/trainer.py b/trainer/trainer.py index cd8f971c..5aee3bf1 100755 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -24,20 +24,7 @@ class Trainer(object): """ __metaclass__ = abc.ABCMeta - def __init__(self, config=None, yaml_file=None): - - if not config and not yaml_file: - raise ValueError("config and yaml file have at least one not empty") - - if config and yaml_file: - print("config and yaml file are all assigned, will use yaml file: {}".format(yaml_file)) - - if yaml_file: - with open(yaml_file, "r") as rb: - config = yaml.load(rb.read()) - - envs.set_global_envs(config) - + def __init__(self, config=None): self._status_processor = {} self._context = {'status': 'uninit', 'is_exit': False} -- GitLab