提交 3cfb57a9 编写于 作者: X xjqbest

test

上级 74625fc4
...@@ -14,9 +14,7 @@ ...@@ -14,9 +14,7 @@
import os import os
import sys import sys
import yaml import yaml
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
trainer_abs = os.path.join( trainer_abs = os.path.join(
...@@ -66,16 +64,9 @@ class TrainerFactory(object): ...@@ -66,16 +64,9 @@ class TrainerFactory(object):
@staticmethod @staticmethod
def create(config): def create(config):
_config = None _config = envs.load_yaml(config)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("paddlerec's config only support yaml")
envs.set_global_envs(_config) envs.set_global_envs(_config)
envs.update_workspace() envs.update_workspace()
trainer = TrainerFactory._build_trainer(config) trainer = TrainerFactory._build_trainer(config)
return trainer return trainer
......
...@@ -13,13 +13,11 @@ ...@@ -13,13 +13,11 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import abc import abc
import os import os
from functools import reduce
import paddle.fluid.incubate.data_generator as dg import paddle.fluid.incubate.data_generator as dg
import yaml import yaml
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
...@@ -28,12 +26,9 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -28,12 +26,9 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config)
if os.path.isfile(config): envs.set_global_envs(_config)
with open(config, 'r') as rb: envs.update_workspace()
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
...@@ -50,11 +45,9 @@ class SlotReader(dg.MultiSlotDataGenerator): ...@@ -50,11 +45,9 @@ class SlotReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config): _config = envs.load_yaml(config)
with open(config, 'r') as rb: envs.set_global_envs(_config)
_config = yaml.load(rb.read(), Loader=yaml.FullLoader) envs.update_workspace()
else:
raise ValueError("reader config only support yaml")
def init(self, sparse_slots, dense_slots, padding=0): def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul from operator import mul
......
...@@ -30,16 +30,12 @@ class Trainer(object): ...@@ -30,16 +30,12 @@ class Trainer(object):
def __init__(self, config=None): def __init__(self, config=None):
self._status_processor = {} self._status_processor = {}
self._place = fluid.CPUPlace() self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place) self._exe = fluid.Executor(self._place)
self._exector_context = {} self._exector_context = {}
self._context = {'status': 'uninit', 'is_exit': False} self._context = {'status': 'uninit', 'is_exit': False}
self._config_yaml = config self._config_yaml = config
self._config = envs.load_yaml(config)
with open(config, 'r') as rb:
self._config = yaml.load(rb.read(), Loader=yaml.FullLoader)
def regist_context_processor(self, status_name, processor): def regist_context_processor(self, status_name, processor):
""" """
...@@ -87,12 +83,8 @@ class Trainer(object): ...@@ -87,12 +83,8 @@ class Trainer(object):
def user_define_engine(engine_yaml): def user_define_engine(engine_yaml):
with open(engine_yaml, 'r') as rb: _config = envs.load_yaml(engine_yaml)
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
envs.set_runtime_environs(_config) envs.set_runtime_environs(_config)
train_location = envs.get_global_env("engine.file") train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location) train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0] base_name = os.path.splitext(os.path.basename(train_location))[0]
......
...@@ -203,3 +203,26 @@ def find_free_port(): ...@@ -203,3 +203,26 @@ def find_free_port():
new_port = __free_port() new_port = __free_port()
return new_port return new_port
def load_yaml(config):
vs = [int(i) for i in yaml.__version__.split(".")]
if vs[0] < 5:
use_full_loader = False
elif vs[0] > 5:
use_full_loader = True
else:
if vs[1] >= 1:
use_full_loader = True
else:
use_full_loader = False
if os.path.isfile(config):
with open(config, 'r') as rb:
if use_full_loader:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_config = yaml.load(rb.read())
return _config
else:
raise ValueError("config {} can not be supported".format(config))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlerec.core.utils import envs
class ValueFormat:
def __init__(self, type, value, value_handler):
self.type = type
self.value = value
self.value_handler = value_handler
self.help = help
def is_valid(self, name, value):
ret = self.is_type_valid(name, value)
if not ret:
return ret
ret = self.is_value_valid(name, value)
return ret
def is_type_valid(self, name, value):
if self.type == "int":
if not isinstance(value, int):
print("\nattr {} should be int, but {} now\n".format(
name, self.type))
return False
return True
elif self.type == "str":
if not isinstance(value, str):
print("\nattr {} should be str, but {} now\n".format(
name, self.type))
return False
return True
elif self.type == "strs":
if not isinstance(value, list):
print("\nattr {} should be list(str), but {} now\n".format(
name, self.type))
return False
for v in value:
if not isinstance(v, str):
print("\nattr {} should be list(str), but list({}) now\n".
format(name, type(v)))
return False
return True
elif self.type == "ints":
if not isinstance(value, list):
print("\nattr {} should be list(int), but {} now\n".format(
name, self.type))
return False
for v in value:
if not isinstance(v, int):
print("\nattr {} should be list(int), but list({}) now\n".
format(name, type(v)))
return False
return True
else:
print("\nattr {}'s type is {}, can not be supported now\n".format(
name, type(value)))
return False
def is_value_valid(self, name, value):
ret = self.value_handler(value)
return ret
def in_value_handler(name, value, values):
if value not in values:
print("\nattr {}'s value is {}, but {} is expected\n".format(
name, value, values))
return False
return True
def eq_value_handler(name, value, values):
if value != values:
print("\nattr {}'s value is {}, but == {} is expected\n".format(
name, value, values))
return False
return True
def ge_value_handler(name, value, values):
if value < values:
print("\nattr {}'s value is {}, but >= {} is expected\n".format(
name, value, values))
return False
return True
def le_value_handler(name, value, values):
if value > values:
print("\nattr {}'s value is {}, but <= {} is expected\n".format(
name, value, values))
return False
return True
def register():
validations = {}
validations["train.workspace"] = ValueFormat("str", None, eq_value_handler)
validations["train.device"] = ValueFormat("str", ["cpu", "gpu"],
in_value_handler)
validations["train.epochs"] = ValueFormat("int", 1, ge_value_handler)
validations["train.engine"] = ValueFormat(
"str", ["single", "local_cluster", "cluster"], in_value_handler)
requires = [
"train.namespace", "train.device", "train.epochs", "train.engine"
]
return validations, requires
def yaml_validation(config):
all_checkers, require_checkers = register()
_config = envs.load_yaml(config)
flattens = envs.flatten_environs(_config)
for required in require_checkers:
if required not in flattens.keys():
print("\ncan not find {} in yaml, which is required\n".format(
required))
return False
for name, flatten in flattens.items():
checker = all_checkers.get(name, None)
if not checker:
continue
ret = checker.is_valid(name, flattens)
if not ret:
return False
return True
...@@ -197,13 +197,7 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -197,13 +197,7 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
envs.set_global_envs(_config) envs.set_global_envs(_config)
envs.update_workspace() envs.update_workspace()
......
...@@ -12,18 +12,11 @@ ...@@ -12,18 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
import sys
import yaml
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import math
import os import os
try: try:
import cPickle as pickle import cPickle as pickle
except ImportError: except ImportError:
import pickle import pickle
from collections import Counter
import os
import paddle.fluid.incubate.data_generator as dg import paddle.fluid.incubate.data_generator as dg
...@@ -31,12 +24,6 @@ class TrainReader(dg.MultiSlotDataGenerator): ...@@ -31,12 +24,6 @@ class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self): def init(self):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [ self.cont_max_ = [
......
...@@ -12,10 +12,7 @@ ...@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import yaml, os import os
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import paddle.fluid.incubate.data_generator as dg import paddle.fluid.incubate.data_generator as dg
try: try:
import cPickle as pickle import cPickle as pickle
...@@ -27,12 +24,6 @@ class TrainReader(dg.MultiSlotDataGenerator): ...@@ -27,12 +24,6 @@ class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self): def init(self):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [ self.cont_max_ = [
......
...@@ -32,7 +32,7 @@ class Model(ModelBase): ...@@ -32,7 +32,7 @@ class Model(ModelBase):
self.sparse_feature_dim = envs.get_global_env( self.sparse_feature_dim = envs.get_global_env(
"hyper_parameters.sparse_feature_dim") "hyper_parameters.sparse_feature_dim")
self.learning_rate = envs.get_global_env( self.learning_rate = envs.get_global_env(
"hyper_parameters.learning_rate") "hyper_parameters.optimizer.learning_rate")
def net(self, input, is_infer=False): def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:] self.sparse_inputs = self._sparse_data_var[1:]
......
...@@ -11,10 +11,7 @@ ...@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import yaml, os import os
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
try: try:
import cPickle as pickle import cPickle as pickle
except ImportError: except ImportError:
...@@ -26,12 +23,6 @@ class TrainReader(dg.MultiSlotDataGenerator): ...@@ -26,12 +23,6 @@ class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self): def init(self):
pass pass
......
...@@ -12,9 +12,7 @@ ...@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import yaml, os import os
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
try: try:
import cPickle as pickle import cPickle as pickle
except ImportError: except ImportError:
...@@ -25,11 +23,6 @@ import paddle.fluid.incubate.data_generator as dg ...@@ -25,11 +23,6 @@ import paddle.fluid.incubate.data_generator as dg
class TrainReader(dg.MultiSlotDataGenerator): class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self): def init(self):
pass pass
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import subprocess import subprocess
import sys
import argparse import argparse
import tempfile import tempfile
import yaml import yaml
...@@ -22,6 +22,7 @@ import copy ...@@ -22,6 +22,7 @@ import copy
from paddlerec.core.factory import TrainerFactory from paddlerec.core.factory import TrainerFactory
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.utils import util from paddlerec.core.utils import util
from paddlerec.core.utils import validation
engines = {} engines = {}
device = ["CPU", "GPU"] device = ["CPU", "GPU"]
...@@ -48,9 +49,7 @@ def engine_registry(): ...@@ -48,9 +49,7 @@ def engine_registry():
def get_inters_from_yaml(file, filters): def get_inters_from_yaml(file, filters):
with open(file, 'r') as rb: _envs = envs.load_yaml(file)
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
flattens = envs.flatten_environs(_envs) flattens = envs.flatten_environs(_envs)
inters = {} inters = {}
for k, v in flattens.items(): for k, v in flattens.items():
...@@ -197,9 +196,7 @@ def cluster_engine(args): ...@@ -197,9 +196,7 @@ def cluster_engine(args):
def master(): def master():
role = "MASTER" role = "MASTER"
from paddlerec.core.engine.cluster.cluster import ClusterEngine from paddlerec.core.engine.cluster.cluster import ClusterEngine
with open(args.backend, 'r') as rb: _envs = envs.load_yaml(args.backend)
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
flattens = envs.flatten_environs(_envs, "_") flattens = envs.flatten_environs(_envs, "_")
flattens["engine_role"] = role flattens["engine_role"] = role
flattens["engine_run_config"] = args.model flattens["engine_run_config"] = args.model
...@@ -322,8 +319,9 @@ if __name__ == "__main__": ...@@ -322,8 +319,9 @@ if __name__ == "__main__":
model_name = args.model.split('.')[-1] model_name = args.model.split('.')[-1]
args.model = get_abs_model(args.model) args.model = get_abs_model(args.model)
if not validation.yaml_validation(args.model):
sys.exit(-1)
engine_registry() engine_registry()
which_engine = get_engine(args) which_engine = get_engine(args)
engine = which_engine(args) engine = which_engine(args)
engine.run() engine.run()
...@@ -21,7 +21,7 @@ from setuptools import setup, find_packages ...@@ -21,7 +21,7 @@ from setuptools import setup, find_packages
import shutil import shutil
import tempfile import tempfile
requires = ["paddlepaddle == 1.7.2", "pyyaml >= 5.1.1"] requires = ["paddlepaddle == 1.7.2", "PyYAML >= 5.1.1"]
about = {} about = {}
about["__title__"] = "paddle-rec" about["__title__"] = "paddle-rec"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册