Unverified commit b76594a0, authored by wangzhen38, committed by GitHub

[Add GPUPS CI] GPUBox unittest (#50130)

Parent: 494431c6
@@ -9,3 +9,7 @@ foreach(TEST_OP ${TEST_OPS})
  list(APPEND TEST_OPS ${TEST_OP})
  set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 50)
endforeach()
if(WITH_PSCORE)
  set_tests_properties(test_gpubox_ps PROPERTIES LABELS "RUN_TYPE=GPUPS")
endif()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# workspace
#workspace: "models/rank/dnn"
runner:
  train_data_dir: "data/sample_data/train"
  train_reader_path: "criteo_reader"  # importlib format
  use_gpu: True
  use_auc: False
  train_batch_size: 32
  epochs: 3
  print_interval: 10
  model_save_path: "output_model_dnn_queue"
  sync_mode: "gpubox"
  thread_num: 30
  reader_type: "InmemoryDataset"  # DataLoader / QueueDataset / RecDataset / InmemoryDataset
  pipe_command: "python3.7 dataset_generator_criteo.py"
  dataset_debug: False
  split_file_list: False

  infer_batch_size: 2
  infer_reader_path: "criteo_reader"  # importlib format
  test_data_dir: "data/sample_data/train"
  infer_load_path: "output_model_dnn_queue"
  infer_start_epoch: 0
  infer_end_epoch: 3

# hyper parameters of user-defined network
hyper_parameters:
  # optimizer config
  optimizer:
    class: Adam
    learning_rate: 0.001
    strategy: async
  # user-defined <key, value> pairs
  sparse_inputs_slots: 27
  sparse_feature_number: 1024
  sparse_feature_dim: 9
  dense_input_dim: 13
  fc_sizes: [512, 256, 128, 32]
  distributed_embedding: 0
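The trainer below loads this file (config_gpubox.yaml) through ps_dnn_trainer.YamlHelper and then reads values with dotted keys such as config.get("runner.thread_num"). As a rough illustration of how the nested YAML above maps onto those dotted keys, here is a minimal, hypothetical flattening sketch; it is not the YamlHelper implementation, and the PyYAML dependency is an assumption.

# Minimal sketch only: flatten nested YAML into dotted keys such as
# "runner.thread_num". Not the actual ps_dnn_trainer.YamlHelper code.
import yaml  # assumes PyYAML is installed


def flatten(node, prefix=""):
    flat = {}
    for key, value in node.items():
        name = "{}.{}".format(prefix, key) if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


if __name__ == "__main__":
    with open("config_gpubox.yaml") as f:
        config = flatten(yaml.safe_load(f))
    print(config.get("runner.sync_mode"))                  # gpubox
    print(config.get("hyper_parameters.optimizer.class"))  # Adam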
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import paddle.distributed.fleet as fleet

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO
)
logger = logging.getLogger(__name__)


class Reader(fleet.MultiSlotDataGenerator):
    def init(self):
        padding = 0
        sparse_slots = "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
        self.sparse_slots = sparse_slots.strip().split(" ")
        self.dense_slots = ["dense_feature"]
        self.dense_slots_shape = [13]
        self.slots = self.sparse_slots + self.dense_slots
        self.slot2index = {}
        self.visit = {}
        for i in range(len(self.slots)):
            self.slot2index[self.slots[i]] = i
            self.visit[self.slots[i]] = False
        self.padding = padding
        logger.info("pipe init success")

    def line_process(self, line):
        line = line.strip().split(" ")
        output = [(i, []) for i in self.slots]
        for i in line:
            slot_feasign = i.split(":")
            slot = slot_feasign[0]
            if slot not in self.slots:
                continue
            if slot in self.sparse_slots:
                feasign = int(slot_feasign[1])
            else:
                feasign = float(slot_feasign[1])
            output[self.slot2index[slot]][1].append(feasign)
            self.visit[slot] = True
        for i in self.visit:
            slot = i
            if not self.visit[slot]:
                if i in self.dense_slots:
                    output[self.slot2index[i]][1].extend(
                        [self.padding]
                        * self.dense_slots_shape[self.slot2index[i]]
                    )
                else:
                    output[self.slot2index[i]][1].extend([self.padding])
            else:
                self.visit[slot] = False
        return output
        # return [label] + sparse_feature + [dense_feature]

    def generate_sample(self, line):
        r"Dataset Generator"

        def reader():
            output_dict = self.line_process(line)
            # {key, value} dict format: {'labels': [1], 'sparse_slot1': [2, 3], 'sparse_slot2': [4, 5, 6, 8], 'dense_slot': [1,2,3,4]}
            # dict must match static_model.create_feed()
            yield output_dict

        return reader


if __name__ == "__main__":
    r = Reader()
    r.init()
    r.run_from_stdin()
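To make the expected input format concrete: each training line is a space-separated list of "slot:feasign" pairs, with integer feasigns for the sparse slots ("click" and "1" through "26") and floats for "dense_feature". The short illustration below shows what line_process produces for one made-up line; it assumes the generator above is importable as dataset_generator_criteo, and the sample values are invented.

# Illustration only; the sample line and its values are made up.
from dataset_generator_criteo import Reader  # the generator defined above

r = Reader()
r.init()
line = "click:0 1:719 5:213 dense_feature:0.05 dense_feature:1.5"
out = r.line_process(line)
# Sparse slots missing from the line are filled with the padding value 0.
print(out[0])   # ('click', [0])
print(out[2])   # ('2', [0])  -- slot "2" was absent, so it got padding
print(out[-1])  # ('dense_feature', [0.05, 1.5])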
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
wget --no-check-certificate https://paddlerec.bj.bcebos.com/benchmark/sample_train.txt
mkdir train_data
mv sample_train.txt train_data/
#!/bin/bash
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ ! -d "./log" ]; then
mkdir ./log
echo "Create log floder for store running log"
fi
export FLAGS_LAUNCH_BARRIER=0
export PADDLE_TRAINER_ID=0
export PADDLE_PSERVER_NUMS=1
export PADDLE_TRAINERS=1
export PADDLE_TRAINERS_NUM=${PADDLE_TRAINERS}
export POD_IP=127.0.0.1
# set free port if 29011 is occupied
export PADDLE_PSERVERS_IP_PORT_LIST="127.0.0.1:29011"
export PADDLE_PSERVER_PORT_ARRAY=(29011)
# set gpu numbers according to your device
export FLAGS_selected_gpus="0,1,2,3,4,5,6,7"
#export FLAGS_selected_gpus="0,1"
# set your model yaml
#SC="gpubox_ps_trainer.py"
SC="static_gpubox_trainer.py"
# run pserver
export TRAINING_ROLE=PSERVER
for((i=0;i<$PADDLE_PSERVER_NUMS;i++))
do
cur_port=${PADDLE_PSERVER_PORT_ARRAY[$i]}
echo "PADDLE WILL START PSERVER "$cur_port
export PADDLE_PORT=${cur_port}
python3.7 -u $SC &> ./log/pserver.$i.log &
done
# run trainer
export TRAINING_ROLE=TRAINER
for((i=0;i<$PADDLE_TRAINERS;i++))
do
echo "PADDLE WILL START Trainer "$i
export PADDLE_TRAINER_ID=$i
python3.7 -u $SC &> ./log/worker.$i.log
done
echo "Training log stored in ./log/"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import logging
import os
import sys
import time

import paddle
import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil

fleet_util = FleetUtil()

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO
)
logger = logging.getLogger(__name__)


def get_dataset(inputs, config):
    dataset = paddle.distributed.InMemoryDataset()
    dataset._set_use_ps_gpu(config.get('runner.use_gpu'))
    pipe_cmd = config.get('runner.pipe_command')
    dataset.init(
        use_var=inputs,
        pipe_command=pipe_cmd,
        batch_size=32,
        thread_num=int(config.get('runner.thread_num')),
        fs_name=config.get("runner.fs_name", ""),
        fs_ugi=config.get("runner.fs_ugi", ""),
    )
    dataset.set_filelist(["train_data/sample_train.txt"])
    dataset.update_settings(
        parse_ins_id=config.get("runner.parse_ins_id", False),
        parse_content=config.get("runner.parse_content", False),
    )
    return dataset


class Main(object):
    def __init__(self):
        self.metrics = {}
        self.input_data = None
        self.reader = None
        self.exe = None
        self.model = None
        self.PSGPU = None
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []
        self.train_result_dict["auc"] = []

    def run(self):
        from ps_dnn_trainer import YamlHelper

        yaml_helper = YamlHelper()
        config_yaml_path = 'config_gpubox.yaml'
        self.config = yaml_helper.load_yaml(config_yaml_path)

        os.environ["CPU_NUM"] = str(self.config.get("runner.thread_num"))
        fleet.init()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_worker()
            fleet.stop_worker()
        logger.info("Run Success, Exit.")
        logger.info("-" * 100)

    def network(self):
        from ps_dnn_trainer import StaticModel, get_user_defined_strategy

        # self.model = get_model(self.config)
        self.model = StaticModel(self.config)
        self.input_data = self.model.create_feeds()
        self.init_reader()
        self.metrics = self.model.net(self.input_data)
        self.inference_target_var = self.model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        # self.model.create_optimizer(get_strategy(self.config))
        user_defined_strategy = get_user_defined_strategy(self.config)
        optimizer = paddle.optimizer.Adam(0.01, lazy_mode=True)
        optimizer = fleet.distributed_optimizer(
            optimizer, user_defined_strategy
        )
        optimizer.minimize(self.model._cost)
        logger.info("end network.....")

    def run_server(self):
        logger.info("Run Server Begin")
        fleet.init_server(self.config.get("runner.warmup_model_path"))
        fleet.run_server()

    def run_worker(self):
        logger.info("Run Worker Begin")
        use_cuda = int(self.config.get("runner.use_gpu"))
        use_auc = self.config.get("runner.use_auc", False)
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)
        '''
        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))
        '''
        self.exe.run(paddle.static.default_startup_program())
        fleet.init_worker()
        '''
        save_model_path = self.config.get("runner.model_save_path")
        if save_model_path and (not os.path.exists(save_model_path)):
            os.makedirs(save_model_path)
        '''
        reader_type = self.config.get("runner.reader_type", None)
        epochs = int(self.config.get("runner.epochs"))
        sync_mode = self.config.get("runner.sync_mode")

        gpus_env = os.getenv("FLAGS_selected_gpus")
        self.PSGPU = paddle.framework.core.PSGPU()
        gpuslot = [int(i) for i in range(1, self.model.sparse_inputs_slots)]
        gpu_mf_sizes = [self.model.sparse_feature_dim - 1] * (
            self.model.sparse_inputs_slots - 1
        )
        self.PSGPU.set_slot_vector(gpuslot)
        self.PSGPU.set_slot_dim_vector(gpu_mf_sizes)
        self.PSGPU.init_gpu_ps([int(s) for s in gpus_env.split(",")])
        gpu_num = len(gpus_env.split(","))
        opt_info = paddle.static.default_main_program()._fleet_opt
        if use_auc is True:
            opt_info['stat_var_names'] = [
                self.model.stat_pos.name,
                self.model.stat_neg.name,
            ]
        else:
            opt_info['stat_var_names'] = []

        for epoch in range(epochs):
            epoch_start_time = time.time()
            self.dataset_train_loop(epoch)
            epoch_time = time.time() - epoch_start_time

            self.PSGPU.end_pass()
            fleet.barrier_worker()
            self.reader.release_memory()
            logger.info("finish {} epoch training....".format(epoch))
        self.PSGPU.finalize()

    def init_reader(self):
        if fleet.is_server():
            return
        # self.reader, self.file_list = get_reader(self.input_data, config)
        self.reader = get_dataset(self.input_data, self.config)

    def dataset_train_loop(self, epoch):
        start_time = time.time()
        self.reader.load_into_memory()
        print(
            "self.reader.load_into_memory cost :{} seconds".format(
                time.time() - start_time
            )
        )

        begin_pass_time = time.time()
        self.PSGPU.begin_pass()
        print(
            "begin_pass cost:{} seconds".format(time.time() - begin_pass_time)
        )

        logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
        fetch_info = [
            "Epoch {} Var {}".format(epoch, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        print_step = int(self.config.get("runner.print_interval"))
        self.exe.train_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            debug=self.config.get("runner.dataset_debug"),
        )


if __name__ == "__main__":
    paddle.enable_static()
    benchmark_main = Main()
    benchmark_main.run()
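For reference, under config_gpubox.yaml (sparse_inputs_slots: 27, sparse_feature_dim: 9) the two PSGPU lists built in run_worker evaluate as shown below; this is just the two expressions from the trainer worked out by hand, not additional trainer code.

# Worked example of the PSGPU slot setup in run_worker above,
# using the values from config_gpubox.yaml.
sparse_inputs_slots = 27   # 27 slots in the reader: "click" plus features 1-26
sparse_feature_dim = 9

gpuslot = [int(i) for i in range(1, sparse_inputs_slots)]
gpu_mf_sizes = [sparse_feature_dim - 1] * (sparse_inputs_slots - 1)

assert gpuslot == list(range(1, 27))   # slot ids 1..26 passed to set_slot_vector
assert gpu_mf_sizes == [8] * 26        # mf dim per slot passed to set_slot_dim_vector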
#!/bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shlex  # noqa: F401
import unittest


class GpuBoxTest(unittest.TestCase):
    def test_gpubox(self):
        exitcode = os.system('sh gpubox_run.sh')
        os.system('rm *_train_desc.prototxt')
        if os.path.exists('./train_data'):
            os.system('rm -rf train_data')
        if os.path.exists('./log'):
            os.system('rm -rf log')


if __name__ == '__main__':
    if not os.path.exists('./train_data'):
        os.system('sh download_criteo_data.sh')
    unittest.main()