diff --git a/examples/train.py b/examples/train.py index 7c0f34df2051441a83a86f704aee25d6f96226b6..46831437da506f24ecfa6ac90554dbef802a178c 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,3 +1,28 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import yaml from .. trainer.factory import TrainerFactory diff --git a/trainer/factory.py b/trainer/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..b96d244365bf6ff4e2703ccf6b73fb572dfe986a --- /dev/null +++ b/trainer/factory.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import yaml
+
+from .single_train import SingleTrainerWithDataloader
+from .single_train import SingleTrainerWithDataset
+
+from .cluster_train import ClusterTrainerWithDataloader
+from .cluster_train import ClusterTrainerWithDataset
+
+from .ctr_trainer import CtrPaddleTrainer
+
+from ..utils import envs
+
+
+class TrainerFactory(object):
+    """Create trainer instances from a config dict or a YAML config file."""
+
+    @staticmethod
+    def _build_trainer(config):
+        """Return the trainer selected by the global ``train.*`` settings.
+
+        Only "CtrTrainer" receives ``config``; it ignores the reader mode.
+        Raises ValueError for an unsupported trainer or reader mode.
+        """
+        train_mode = envs.get_global_env("train.trainer")
+        reader_mode = envs.get_global_env("train.reader.mode")
+        if train_mode == "CtrTrainer":
+            return CtrPaddleTrainer(config)
+        if train_mode == "SingleTraining":
+            with_dataset = SingleTrainerWithDataset
+            with_dataloader = SingleTrainerWithDataloader
+        elif train_mode == "ClusterTraining":
+            with_dataset = ClusterTrainerWithDataset
+            with_dataloader = ClusterTrainerWithDataloader
+        else:
+            raise ValueError("trainer only support "
+                             "SingleTraining/ClusterTraining/CtrTrainer")
+        if reader_mode == "dataset":
+            return with_dataset()
+        if reader_mode == "dataloader":
+            return with_dataloader()
+        raise ValueError("reader only support dataset/dataloader")
+
+    @staticmethod
+    def create(config):
+        """Set the global envs from ``config`` and return the matching trainer.
+
+        ``config``: a dict or a YAML file path; anything else raises ValueError.
+        """
+        if isinstance(config, dict):
+            _config = config
+        elif isinstance(config, str) and os.path.isfile(config):
+            with open(config, 'r') as rb:
+                # Bare yaml.load() is unsafe and rejected by PyYAML >= 6.
+                _config = yaml.safe_load(rb)
+        else:
+            raise ValueError("config must be a dict or a yaml file path")
+        envs.set_global_envs(_config)
+        return TrainerFactory._build_trainer(_config)
diff --git a/trainer/single_train.py b/trainer/single_train.py index 9422dd5ab855bd880fff906ef413e748c52d52ad..0865b419ce16927e6b22a805c69f6bd7c41c16cc 100644 --- a/trainer/single_train.py +++ b/trainer/single_train.py @@ -39,9 +39,8 @@ def need_save(epoch_id, epoch_interval, is_last=False): class SingleTrainer(Trainer): - - def __init__(self, 
config=None, yaml_file=None): - Trainer.__init__(self, config, yaml_file) + def __init__(self, config=None): + Trainer.__init__(self, config) self.exe = fluid.Executor(fluid.CPUPlace()) diff --git a/trainer/trainer.py b/trainer/trainer.py index cd8f971cb00aabe5def41f44b52a656b8a525d52..5aee3bf1b0282b1006a26979160926739de0996f 100755 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -24,20 +24,7 @@ class Trainer(object): """ __metaclass__ = abc.ABCMeta - def __init__(self, config=None, yaml_file=None): - - if not config and not yaml_file: - raise ValueError("config and yaml file have at least one not empty") - - if config and yaml_file: - print("config and yaml file are all assigned, will use yaml file: {}".format(yaml_file)) - - if yaml_file: - with open(yaml_file, "r") as rb: - config = yaml.load(rb.read()) - - envs.set_global_envs(config) - + def __init__(self, config=None): self._status_processor = {} self._context = {'status': 'uninit', 'is_exit': False}