未验证 提交 3583efec 编写于 作者: W wuzhihua 提交者: GitHub

Merge branch 'master' into add_FM

......@@ -14,9 +14,7 @@
import os
import sys
import yaml
from paddlerec.core.utils import envs
trainer_abs = os.path.join(
......@@ -66,16 +64,9 @@ class TrainerFactory(object):
@staticmethod
def create(config):
_config = None
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("paddlerec's config only support yaml")
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
trainer = TrainerFactory._build_trainer(config)
return trainer
......
......@@ -13,13 +13,11 @@
# limitations under the License.
from __future__ import print_function
import abc
import os
from functools import reduce
import paddle.fluid.incubate.data_generator as dg
import yaml
from paddlerec.core.utils import envs
......@@ -28,12 +26,9 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
@abc.abstractmethod
def init(self):
......@@ -50,11 +45,9 @@ class SlotReader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul
......
......@@ -30,16 +30,12 @@ class Trainer(object):
def __init__(self, config=None):
self._status_processor = {}
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
self._exector_context = {}
self._context = {'status': 'uninit', 'is_exit': False}
self._config_yaml = config
with open(config, 'r') as rb:
self._config = yaml.load(rb.read(), Loader=yaml.FullLoader)
self._config = envs.load_yaml(config)
def regist_context_processor(self, status_name, processor):
"""
......@@ -87,12 +83,8 @@ class Trainer(object):
def user_define_engine(engine_yaml):
with open(engine_yaml, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
_config = envs.load_yaml(engine_yaml)
envs.set_runtime_environs(_config)
train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
......
......@@ -203,3 +203,26 @@ def find_free_port():
new_port = __free_port()
return new_port
def load_yaml(config):
    """Read a YAML config file and return the parsed object.

    ``yaml.FullLoader`` is used whenever the installed PyYAML exposes it;
    otherwise the call falls back to ``yaml.load`` without an explicit
    loader, as older PyYAML releases require.

    Args:
        config: path of the YAML file to read.

    Returns:
        The object produced by ``yaml.load`` for the file content.

    Raises:
        ValueError: if ``config`` is not an existing regular file.
    """
    # Reject bad paths first so the error does not depend on the parser.
    if not os.path.isfile(config):
        raise ValueError("config {} can not be supported".format(config))

    # Feature-detect the loader instead of parsing yaml.__version__:
    # pre-release version strings such as "5.1b1" contain non-numeric
    # components, which made the previous int()-based parsing raise.
    full_loader = getattr(yaml, "FullLoader", None)

    with open(config, 'r') as rb:
        content = rb.read()

    if full_loader is not None:
        return yaml.load(content, Loader=full_loader)

    # NOTE(review): legacy PyYAML path. yaml.load without a Loader can
    # construct arbitrary Python objects, so only pass trusted configs.
    return yaml.load(content)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlerec.core.utils import envs
class ValueFormat:
    """Expected format of one config attribute, plus the checks for it.

    An instance pairs a type tag ("int", "str", "ints" or "strs") with an
    optional value constraint. The constraint is enforced by calling
    ``value_handler(name, value, self.value)``, which must return a truthy
    result for the attribute to be accepted.
    """

    def __init__(self, type, value, value_handler):
        # Expected type tag: "int", "str", "ints" (list of int) or
        # "strs" (list of str).
        self.type = type
        # Reference value passed to value_handler as its third argument;
        # None means "no value constraint".
        self.value = value
        # Callable taking (name, value, expected) and returning a bool.
        self.value_handler = value_handler
        # There is no `help` parameter; the previous code stored the
        # builtin help() function here by accident.
        self.help = None

    def is_valid(self, name, value):
        """Return True when ``value`` passes the type and value checks.

        The value check only runs after the type check has succeeded.
        """
        if not self.is_type_valid(name, value):
            return False
        return self.is_value_valid(name, value)

    def _is_list_of(self, name, value, elem_type, label):
        """Check that ``value`` is a list whose items are all ``elem_type``.

        ``label`` is the element type name used in the printed diagnostics.
        """
        if not isinstance(value, list):
            print("\nattr {} should be list({}), but {} now\n".format(
                name, label, type(value)))
            return False
        for v in value:
            if not isinstance(v, elem_type):
                print("\nattr {} should be list({}), but list({}) now\n".
                      format(name, label, type(v)))
                return False
        return True

    def is_type_valid(self, name, value):
        """Return True when ``value`` matches the declared type tag.

        A diagnostic naming the attribute and the actual type found is
        printed on failure. An unknown type tag is always rejected.
        """
        if self.type == "int":
            if not isinstance(value, int):
                # Report the actual type, not the expected tag.
                print("\nattr {} should be int, but {} now\n".format(
                    name, type(value)))
                return False
            return True
        elif self.type == "str":
            if not isinstance(value, str):
                print("\nattr {} should be str, but {} now\n".format(
                    name, type(value)))
                return False
            return True
        elif self.type == "strs":
            return self._is_list_of(name, value, str, "str")
        elif self.type == "ints":
            return self._is_list_of(name, value, int, "int")
        else:
            # The unsupported thing is the declared tag, so print that.
            print("\nattr {}'s type is {}, can not be supported now\n".format(
                name, self.type))
            return False

    def is_value_valid(self, name, value):
        """Return the handler's verdict on ``value``.

        Handlers take ``(name, value, expected)``. When no handler or no
        reference value is configured there is nothing to compare against,
        so the value is accepted.
        """
        if self.value_handler is None or self.value is None:
            return True
        return self.value_handler(name, value, self.value)
def in_value_handler(name, value, values):
    """Accept ``value`` only if it is a member of ``values``.

    Prints a diagnostic naming the attribute and returns False otherwise.
    """
    if value in values:
        return True
    print("\nattr {}'s value is {}, but {} is expected\n".format(
        name, value, values))
    return False
def eq_value_handler(name, value, values):
    """Accept ``value`` only if it compares equal to ``values``.

    Prints a diagnostic naming the attribute and returns False otherwise.
    """
    mismatch = value != values
    if mismatch:
        print("\nattr {}'s value is {}, but == {} is expected\n".format(
            name, value, values))
        return False
    return True
def ge_value_handler(name, value, values):
    """Accept ``value`` unless it is smaller than the lower bound ``values``.

    Prints a diagnostic naming the attribute and returns False otherwise.
    """
    too_small = value < values
    if too_small:
        print("\nattr {}'s value is {}, but >= {} is expected\n".format(
            name, value, values))
        return False
    return True
def le_value_handler(name, value, values):
    """Accept ``value`` unless it is larger than the upper bound ``values``.

    Prints a diagnostic naming the attribute and returns False otherwise.
    """
    too_large = value > values
    if too_large:
        print("\nattr {}'s value is {}, but <= {} is expected\n".format(
            name, value, values))
        return False
    return True
def register():
    """Build the attribute checkers and the list of required attributes.

    Returns:
        A ``(validations, requires)`` pair: ``validations`` maps a dotted
        attribute name to its ``ValueFormat`` checker, and ``requires``
        lists the dotted names that must be present in the config.
    """
    # NOTE(review): "train.namespace" is required but has no checker, while
    # "train.workspace" has a checker but is not required; confirm intended.
    validations = {
        "train.workspace": ValueFormat("str", None, eq_value_handler),
        "train.device": ValueFormat("str", ["cpu", "gpu"],
                                    in_value_handler),
        "train.epochs": ValueFormat("int", 1, ge_value_handler),
        "train.engine": ValueFormat(
            "str", ["single", "local_cluster", "cluster"], in_value_handler),
    }
    requires = [
        "train.namespace",
        "train.device",
        "train.epochs",
        "train.engine",
    ]
    return validations, requires
def yaml_validation(config):
    """Validate the YAML config at ``config`` against the registered checkers.

    The file is loaded with ``envs.load_yaml`` and flattened with
    ``envs.flatten_environs``. Every name listed as required by
    ``register()`` must be present, and every attribute that has a
    registered checker must pass it.

    Args:
        config: path of the YAML config file.

    Returns:
        True when all checks pass. On the first missing or invalid
        attribute a diagnostic is printed and False is returned.
    """
    all_checkers, require_checkers = register()

    # Presumably a mapping of dotted attribute names to their values; it
    # is only used through membership tests and .items() here.
    flattens = envs.flatten_environs(envs.load_yaml(config))

    for required in require_checkers:
        if required not in flattens:
            print("\ncan not find {} in yaml, which is required\n".format(
                required))
            return False

    for name, flatten in flattens.items():
        checker = all_checkers.get(name, None)
        if not checker:
            # Attributes without a registered checker are not validated.
            continue
        # Check this attribute's own value. Passing the whole mapping made
        # every type check compare a dict against int/str/list.
        if not checker.is_valid(name, flatten):
            return False
    return True
......@@ -197,13 +197,7 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
......
......@@ -12,18 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
import yaml
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import math
import os
try:
import cPickle as pickle
except ImportError:
import pickle
from collections import Counter
import os
import paddle.fluid.incubate.data_generator as dg
......@@ -31,12 +24,6 @@ class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [
......
......@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml, os
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import os
import paddle.fluid.incubate.data_generator as dg
try:
import cPickle as pickle
......@@ -27,12 +24,6 @@ class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [
......
......@@ -32,7 +32,7 @@ class Model(ModelBase):
self.sparse_feature_dim = envs.get_global_env(
"hyper_parameters.sparse_feature_dim")
self.learning_rate = envs.get_global_env(
"hyper_parameters.learning_rate")
"hyper_parameters.optimizer.learning_rate")
def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:]
......
......@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml, os
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import os
try:
import cPickle as pickle
except ImportError:
......@@ -26,12 +23,6 @@ class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self):
pass
......
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml, os
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import os
try:
import cPickle as pickle
except ImportError:
......@@ -25,11 +23,6 @@ import paddle.fluid.incubate.data_generator as dg
class TrainReader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
if os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
def init(self):
pass
......
......@@ -14,7 +14,7 @@
import os
import subprocess
import sys
import argparse
import tempfile
import yaml
......@@ -22,6 +22,7 @@ import copy
from paddlerec.core.factory import TrainerFactory
from paddlerec.core.utils import envs
from paddlerec.core.utils import util
from paddlerec.core.utils import validation
engines = {}
device = ["CPU", "GPU"]
......@@ -48,9 +49,7 @@ def engine_registry():
def get_inters_from_yaml(file, filters):
with open(file, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
_envs = envs.load_yaml(file)
flattens = envs.flatten_environs(_envs)
inters = {}
for k, v in flattens.items():
......@@ -197,9 +196,7 @@ def cluster_engine(args):
def master():
role = "MASTER"
from paddlerec.core.engine.cluster.cluster import ClusterEngine
with open(args.backend, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
_envs = envs.load_yaml(args.backend)
flattens = envs.flatten_environs(_envs, "_")
flattens["engine_role"] = role
flattens["engine_run_config"] = args.model
......@@ -322,8 +319,9 @@ if __name__ == "__main__":
model_name = args.model.split('.')[-1]
args.model = get_abs_model(args.model)
if not validation.yaml_validation(args.model):
sys.exit(-1)
engine_registry()
which_engine = get_engine(args)
engine = which_engine(args)
engine.run()
......@@ -21,7 +21,7 @@ from setuptools import setup, find_packages
import shutil
import tempfile
requires = ["paddlepaddle == 1.7.2", "pyyaml >= 5.1.1"]
requires = ["paddlepaddle == 1.7.2", "PyYAML >= 5.1.1"]
about = {}
about["__title__"] = "paddle-rec"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册