未验证 提交 12fc8c82 编写于 作者: C Chengmo 提交者: GitHub

Fix k8s_datasplit & save_model (#167)

* fix save inference
上级 d4a280b5
......@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/"
# 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193
PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=K8S
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
......
......@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $>
thirdparty_path=<$ THIRDPARTY_PATH $>
PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=MPI
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
......
......@@ -107,6 +107,7 @@ class Trainer(object):
self.device = Device.GPU
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
self._place = fluid.CUDAPlace(gpu_id)
print("PaddleRec run on device GPU: {}".format(gpu_id))
self._exe = fluid.Executor(self._place)
elif device == "CPU":
self.device = Device.CPU
......@@ -146,6 +147,7 @@ class Trainer(object):
elif engine.upper() == "CLUSTER":
self.engine = EngineMode.CLUSTER
self.is_fleet = True
self.which_cluster_type()
else:
raise ValueError("Not Support Engine {}".format(engine))
self._context["is_fleet"] = self.is_fleet
......@@ -165,6 +167,14 @@ class Trainer(object):
self._context["is_pslib"] = (fleet_mode.upper() == "PSLIB")
self._context["fleet_mode"] = fleet_mode
def which_cluster_type(self):
cluster_type = os.getenv("PADDLEREC_CLUSTER_TYPE", "MPI")
print("PADDLEREC_CLUSTER_TYPE: {}".format(cluster_type))
if cluster_type and cluster_type.upper() == "K8S":
self._context["cluster_type"] = "K8S"
else:
self._context["cluster_type"] = "MPI"
def which_executor_mode(self):
executor_mode = envs.get_runtime_environ("train.trainer.executor_mode")
if executor_mode.upper() not in ["TRAIN", "INFER"]:
......
......@@ -123,10 +123,21 @@ class QueueDataset(DatasetBase):
os.path.join(train_data_path, x)
for x in os.listdir(train_data_path)
]
file_list.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount afs, split files for every node
need_split_files = True
if need_split_files:
file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("File_list: {}".format(file_list))
dataset.set_filelist(file_list)
for model_dict in context["phases"]:
if model_dict["dataset_name"] == dataset_name:
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import os
import time
import warnings
import numpy as np
import paddle.fluid as fluid
......@@ -284,6 +285,7 @@ class RunnerBase(object):
return (epoch_id + 1) % epoch_interval == 0
def save_inference_model():
# get global env
name = "runner." + context["runner_name"] + "."
save_interval = int(
envs.get_global_env(name + "save_inference_interval", -1))
......@@ -296,18 +298,44 @@ class RunnerBase(object):
if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
len(feed_varnames) == 0 or len(fetch_varnames) == 0:
return
fetch_vars = [
fluid.default_main_program().global_block().vars[varname]
for varname in fetch_varnames
]
# check feed var exist
for var_name in feed_varnames:
if var_name not in fluid.default_main_program().global_block(
).vars:
raise ValueError(
"Feed variable: {} not in default_main_program, global block has follow vars: {}".
format(var_name,
fluid.default_main_program().global_block()
.vars.keys()))
# check fetch var exist
fetch_vars = []
for var_name in fetch_varnames:
if var_name not in fluid.default_main_program().global_block(
).vars:
raise ValueError(
"Fetch variable: {} not in default_main_program, global block has follow vars: {}".
format(var_name,
fluid.default_main_program().global_block()
.vars.keys()))
else:
fetch_vars.append(fluid.default_main_program()
.global_block().vars[var_name])
dirname = envs.get_global_env(name + "save_inference_path", None)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
context["fleet"].save_inference_model(
context["exe"], dirname, feed_varnames, fetch_vars)
warnings.warn(
"Save inference model in cluster training is not recommended! Using save checkpoint instead.",
category=UserWarning,
stacklevel=2)
if context["fleet"].worker_index() == 0:
context["fleet"].save_inference_model(
context["exe"], dirname, feed_varnames, fetch_vars)
else:
fluid.io.save_inference_model(dirname, feed_varnames,
fetch_vars, context["exe"])
......@@ -323,7 +351,8 @@ class RunnerBase(object):
return
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
context["fleet"].save_persistables(context["exe"], dirname)
if context["fleet"].worker_index() == 0:
context["fleet"].save_persistables(context["exe"], dirname)
else:
fluid.io.save_persistables(context["exe"], dirname)
......
......@@ -39,9 +39,21 @@ def dataloader_by_name(readerclass,
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
print("need_split_files: {}".format(need_split_files))
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list : {}".format(files))
reader = reader_class(yaml_file)
......@@ -81,10 +93,20 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env(name + "sparse_slots", "#")
if sparse == "":
......@@ -135,10 +157,20 @@ def slotdataloader(readerclass, train, yaml_file, context):
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env("sparse_slots", "#", namespace)
if sparse == "":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册