Unverified · Commit e42057cd · Authored by: hutuxian · Committed by: GitHub

add ut for pipeline training (#18289)

Parent: 5826b72e
@@ -808,25 +808,6 @@ class Executor(object):
         else:
             trainer._set_thread(thread)
-        # Adjust the reader size for small file num
-        if program._pipeline_opt:
-            dataset.set_thread(thread *
-                               program._pipeline_opt["concurrency_list"][0])
-            file_size = len(dataset.dataset.get_filelist())
-            if file_size < thread:
-                thread = file_size
-                print(
-                    "Pipeline: setting the pipeline num to %d is enough because there are only %d files"
-                    % (file_size, file_size))
-            if file_size < thread * program._pipeline_opt["concurrency_list"][
-                    0]:
-                print(
-                    "Pipeline: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
-                    % (file_size / thread, file_size))
-                program._pipeline_opt["concurrency_list"][
-                    0] = file_size / thread
-            dataset.set_thread(
-                program._pipeline_opt["concurrency_list"][0] * thread)
         trainer._set_debug(debug)
         trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
         return scope, trainer
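Note: the block removed here is re-inserted unchanged in the second hunk below, ahead of dataset._prepare_to_run(), so the reader-size adjustment now runs before the dataset is prepared rather than inside _prepare_trainer.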
@@ -970,6 +951,25 @@ class Executor(object):
         if dataset == None:
             raise RuntimeError("dataset is need and should be initialized")
+        # Adjust the reader size for small file num
+        if program._pipeline_opt:
+            dataset.set_thread(thread *
+                               program._pipeline_opt["concurrency_list"][0])
+            file_size = len(dataset.dataset.get_filelist())
+            if file_size < thread:
+                thread = file_size
+                print(
+                    "Pipeline: setting the pipeline num to %d is enough because there are only %d files"
+                    % (file_size, file_size))
+            if file_size < thread * program._pipeline_opt["concurrency_list"][
+                    0]:
+                print(
+                    "Pipeline: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
+                    % (file_size / thread, file_size))
+                program._pipeline_opt["concurrency_list"][
+                    0] = file_size / thread
+            dataset.set_thread(
+                program._pipeline_opt["concurrency_list"][0] * thread)
         dataset._prepare_to_run()
         scope, trainer = self._prepare_trainer(
             program=program,
...
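To make the adjustment concrete, here is a minimal standalone sketch of the same logic (the function name and the worked example are illustrative, not part of the commit). One caveat worth noting: the commit computes file_size / thread with "/", which is integer division under Python 2 but would yield a float under Python 3.

def adjust_pipeline_reader(thread, concurrency_list, file_size):
    # Cap the pipeline num so every pipeline has at least one file.
    if file_size < thread:
        thread = file_size
    # Cap the first concurrency entry so every reader thread has a file.
    if file_size < thread * concurrency_list[0]:
        concurrency_list[0] = file_size // thread  # "//" keeps it integral
    return thread, concurrency_list[0] * thread  # new dataset thread count

# With thread=4, concurrency_list=[2, 1, 1] and only 2 input files:
# thread becomes 2, concurrency_list[0] becomes 1, and the dataset
# is set to 2 reader threads.
print(adjust_pipeline_reader(4, [2, 1, 1], 2))  # -> (2, 2)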
@@ -29,6 +29,9 @@ elseif(${CUDNN_VERSION} VERSION_LESS 7100)
     LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
 endif()
+if(NOT WITH_GPU OR WIN32)
+    LIST(REMOVE_ITEM TEST_OPS test_pipeline)
+endif()
 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
 list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
...
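Excluding test_pipeline when Paddle is built without GPU support, or on Windows, is consistent with the test itself: its place_list pins the middle pipeline section to fluid.CUDAPlace(0).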
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import numpy as np
import os
import shutil
import unittest


class TestPipeline(unittest.TestCase):
    """TestCases for Pipeline Training."""

    def test_pipeline(self):
        x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
        y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
        emb_x = layers.embedding(
            input=x,
            param_attr=fluid.ParamAttr(name="embx"),
            size=[10, 2],
            is_sparse=False)
        emb_y = layers.embedding(
            input=y,
            param_attr=fluid.ParamAttr(
                name="emby", learning_rate=0.9),
            size=[10, 2],
            is_sparse=False)
        concat = layers.concat([emb_x, emb_y], axis=1)
        fc = layers.fc(input=concat,
                       name="fc",
                       size=1,
                       num_flatten_dims=1,
                       bias_attr=False)
        loss = layers.reduce_mean(fc)
        optimizer = fluid.optimizer.SGD(learning_rate=0.5)
        optimizer = fluid.optimizer.PipelineOptimizer(
            optimizer,
            cut_list=[[emb_x, emb_y], [loss]],
            place_list=[
                fluid.CPUPlace(), fluid.CUDAPlace(0), fluid.CPUPlace()
            ],
            concurrency_list=[1, 1, 1],
            queue_size=1,
            sync_steps=10000000, )
        optimizer.minimize(loss)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # prepare data
        batch_size = 100

        def binary_print(slot, fout):
            # Each slot is written as an int16 value count followed by
            # that many int64 values (the batch size plus the slot entries).
            num = np.int16(len(slot) + 1)
            num.tofile(fout)
            a = np.int64(batch_size)
            a.tofile(fout)
            slot.tofile(fout)

        batch1 = np.ones(
            (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
        batch2 = np.ones(
            (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
        data = [batch1, batch2]
        filelist = []
        for i in range(2):
            filelist.append("test_pipeline_input_" + str(i))
        for f in filelist:
            with open(f, "wb") as fout:
                for batch_data in data:
                    for ins in batch_data:
                        for slot in ins:
                            binary_print(slot, fout)

        dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset")
        dataset.set_use_var([x, y])
        dataset.set_batch_size(batch_size)
        dataset.set_filelist(filelist)

        for epoch in range(1):
            exe.train_from_dataset(
                fluid.default_main_program(),
                dataset,
                thread=1,
                debug=False,
                fetch_list=[],
                fetch_info=[],
                print_period=1)

        for f in filelist:
            os.remove(f)


if __name__ == '__main__':
    unittest.main()
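For reference, the byte layout produced by binary_print above can be sanity-checked with a small reader. This helper is hypothetical (not part of the commit) and simply mirrors the writer: each record is an int16 value count followed by that many int64 values.

import numpy as np

def read_slots(path):
    # Inverse of binary_print: int16 count N, then N int64 values.
    slots = []
    with open(path, "rb") as fin:
        while True:
            num = np.fromfile(fin, dtype=np.int16, count=1)
            if num.size == 0:  # end of file
                break
            slots.append(np.fromfile(fin, dtype=np.int64, count=int(num[0])))
    return slots

# For the all-ones data written above, every record reads back as
# array([100, 1]): the batch size followed by the single slot value.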