Unverified commit c5aec2fe authored by hutuxian, committed by GitHub

Paddlebox Related to Framework (#21586)

* Add a single_process_multi_thread transpiler.
* Add some UTs.
* Fix some API descriptions.
Parent 9da7e6b4
@@ -433,6 +433,12 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num) {
total_data_channel->Close();
total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num +
1);
+  // will discard the remaining instances,
+  // TODO(hutuxian): should add a config here to choose how to deal with
+  // remaining instances
+  if (static_cast<int>(input_channel_->Size()) >= channel_num) {
+    input_channel_->SetBlockSize(input_channel_->Size() / channel_num);
+  }
for (int i = 0; i < channel_num; ++i) {
local_vec.clear();
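Note on the rebalancing added above: the per-channel block size comes from plain integer division, so up to channel_num - 1 trailing instances are silently dropped, which is exactly what the TODO flags. A minimal sketch of the arithmetic (split_block_size is a hypothetical helper, not a Paddle API):

    # Sketch of the block-size arithmetic above; split_block_size is a
    # hypothetical helper, not part of the Paddle API.
    def split_block_size(total, channel_num):
        assert channel_num > 0
        return total // channel_num  # the remainder is discarded

    # e.g. 103 instances over 4 channels -> block size 25; 3 instances dropped
    print(split_block_size(103, 4))  # 25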
@@ -121,13 +121,27 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
void PipelineTrainer::InitFirstScopeQueue(ScopeQueue* scope_queue,
int pipeline_id,
-                                          const ProgramDesc& main_program) {
+                                          const ProgramDesc& main_program,
+                                          const Scope& root_scope) {
for (int i = 0; i < scope_queue_size_; ++i) {
Scope* scope = &pipeline_scopes_[pipeline_id]->NewScope();
for (auto& var : main_program.Block(0).AllVars()) {
if (!var->Persistable()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
+      } else {
+        if (section_num_ == 1) {  // Means only one section and it must be
+                                  // CUDAPlace, so copy all persistable vars to
+                                  // pipeline scope
+          const LoDTensor& root_tensor =
+              root_scope.FindVar(var->Name())->Get<LoDTensor>();
+          LoDTensor* gpu_tensor = pipeline_scopes_[pipeline_id]
+                                      ->Var(var->Name())
+                                      ->GetMutable<LoDTensor>();
+          platform::Place place = platform::CUDAPlace(pipeline_id);
+          TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
+                     static_cast<Tensor*>(gpu_tensor));
+        }
}
}
scope_queue->Send(scope);
@@ -162,7 +176,8 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
if (i == 0) {
pipeline_scopes_[j] = &root_scope_->NewScope();
CopyParameters(*root_scope_, j);
-      InitFirstScopeQueue(scope_queues_[0].back().get(), j, main_program);
+      InitFirstScopeQueue(scope_queues_[0].back().get(), j, main_program,
+                          *root_scope_);
}
}
}
@@ -192,7 +207,7 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
}
-  if (pipeline_num_ > 1) {
+  if (pipeline_num_ > 1 && sync_steps_ != -1) {
construct_sync_functor();
}
}
@@ -159,7 +159,8 @@ class PipelineTrainer : public TrainerBase {
std::vector<DataFeed*> readers_;
void InitFirstScopeQueue(ScopeQueue* scope_queue, int pipeline_id,
-                            const ProgramDesc& main_program);
+                            const ProgramDesc& main_program,
+                            const Scope& root_scope);
void CopyParameters(const Scope& root_scope, int pipeline_id);
void construct_sync_functor();
};
@@ -40,11 +40,16 @@ void BindBoxHelper(py::module* m) {
.def(py::init([](paddle::framework::Dataset* dataset) {
return std::make_shared<paddle::framework::BoxHelper>(dataset);
}))
.def("begin_pass", &framework::BoxHelper::BeginPass)
.def("end_pass", &framework::BoxHelper::EndPass)
.def("wait_feed_pass_done", &framework::BoxHelper::WaitFeedPassDone)
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory)
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory);
.def("begin_pass", &framework::BoxHelper::BeginPass,
py::call_guard<py::gil_scoped_release>())
.def("end_pass", &framework::BoxHelper::EndPass,
py::call_guard<py::gil_scoped_release>())
.def("wait_feed_pass_done", &framework::BoxHelper::WaitFeedPassDone,
py::call_guard<py::gil_scoped_release>())
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>());
} // end BoxHelper
} // end namespace pybind
} // end namespace paddle
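The py::call_guard<py::gil_scoped_release>() added to each binding releases the Python GIL for the duration of the C++ call, so a long-running load no longer stalls other Python threads. A hedged sketch of the pattern this enables (the file names and dataset setup are illustrative, not from the commit):

    # Sketch: with the GIL released inside load_into_memory, a Python thread
    # can keep making progress while the C++ loader runs.
    import threading
    import paddle.fluid as fluid

    dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
    dataset.set_filelist(["a.txt", "b.txt"])  # illustrative file names

    loader = threading.Thread(target=dataset.load_into_memory)
    loader.start()
    # ... other Python-side work proceeds concurrently here ...
    loader.join()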
@@ -772,7 +772,7 @@ class BoxPSDataset(InMemoryDataset):
.. code-block:: python
import paddle.fluid as fluid
-          dataset = fluid.DatasetFactory.create_dataset("BoxPSDataset")
+          dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
"""
def __init__(self):
@@ -786,34 +786,72 @@ class BoxPSDataset(InMemoryDataset):
def begin_pass(self):
"""
Begin Pass
-        Notify BoxPS to begin next pass
-        """
+        Notify BoxPS to load sparse parameters of next pass to GPU Memory
+        Examples:
+            .. code-block:: python
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+              dataset.begin_pass()
+        """
self.boxps.begin_pass()
def end_pass(self):
"""
End Pass
-        Notify BoxPS to end current pass
-        """
+        Notify BoxPS that current pass ended
+        Examples:
+            .. code-block:: python
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+              dataset.end_pass()
+        """
self.boxps.end_pass()
def wait_preload_done(self):
"""
-        Wait async proload done
-        """
+        Wait Until Feed Pass Done
+        Examples:
+            .. code-block:: python
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.preload_into_memory()
+              dataset.wait_preload_done()
+        """
self.boxps.wait_feed_pass_done()
def load_into_memory(self):
"""
-        Load next pass into memory and notify boxps to fetch its emb from SSD
-        """
+        Load next pass into memory and notify boxps to fetch its emb from SSD
+        Examples:
+            .. code-block:: python
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.load_into_memory()
+        """
self._prepare_to_run()
self.boxps.load_into_memory()
def preload_into_memory(self):
"""
-        begin async preload next pass while current pass may be training
-        """
+        Begin async preload next pass while current pass may be training
+        Examples:
+            .. code-block:: python
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.preload_into_memory()
+        """
self._prepare_to_run()
self.boxps.preload_into_memory()
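Taken together, the documented methods form a preload/train pass cycle: read pass data asynchronously, wait for the feed to finish, begin the pass on BoxPS, train, then end it. A hedged end-to-end sketch combining the documented calls (train_one_pass is a placeholder, not a Paddle API):

    # Hedged sketch of one pass of the BoxPSDataset cycle;
    # train_one_pass() is a placeholder, e.g. exe.train_from_dataset.
    import paddle.fluid as fluid

    dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
    dataset.set_filelist(["a.txt", "b.txt"])
    dataset.preload_into_memory()  # async read of the pass data
    dataset.wait_preload_done()    # block until the feed pass is done
    dataset.begin_pass()           # load the pass's sparse params to GPU
    train_one_pass(dataset)        # placeholder training step
    dataset.end_pass()             # notify BoxPS the pass has ended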
@@ -986,7 +986,9 @@ class Executor(object):
if dataset is None:
raise RuntimeError("dataset is need and should be initialized")
-        if program._pipeline_opt:
+        if program._pipeline_opt is not None and program._pipeline_opt[
+                "sync_steps"] != -1:
+            # hack for paddlebox: sync_steps(-1) denotes paddlebox
thread = self._adjust_pipeline_resource(program._pipeline_opt,
dataset, thread)
@@ -3455,9 +3455,14 @@ class PipelineOptimizer(object):
self._optimizer.minimize(loss, startup_program, parameter_list,
no_grad_set)
program = loss.block.program
-        program_list = self._split_program(program, self._cut_list)
-        for p in program_list:
-            self._create_vars(p["program"].block(0), program)
+        if len(self._cut_list) == 0:
+            program_list = []
+            ptmp = {"program": program, "input_set": set(), "output_set": set()}
+            program_list.append(ptmp)
+        else:
+            program_list = self._split_program(program, self._cut_list)
+            for p in program_list:
+                self._create_vars(p["program"].block(0), program)
whole_parameters = [e.name for e in program.block(0).all_parameters()]
param_need_sync = []
for i, section_p in enumerate(program_list):
@@ -19,6 +19,61 @@ import os
import paddle.fluid.core as core
import unittest
from paddle.fluid.layers.nn import _pull_box_sparse
+from paddle.fluid.transpiler import collective
+class TestTranspile(unittest.TestCase):
+    """ TestCases for BoxPS Preload """
+    def get_transpile(self, mode, trainers="127.0.0.1:6174"):
+        config = fluid.DistributeTranspilerConfig()
+        config.mode = 'collective'
+        config.collective_mode = mode
+        t = fluid.DistributeTranspiler(config=config)
+        return t
+    def test_transpile(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        t = self.get_transpile("single_process_multi_thread")
+        t.transpile(
+            trainer_id=0,
+            startup_program=startup_program,
+            trainers="127.0.0.1:6174",
+            program=main_program)
+        t = self.get_transpile("grad_allreduce")
+        try:
+            t.transpile(
+                trainer_id=0,
+                startup_program=startup_program,
+                trainers="127.0.0.1:6174",
+                program=main_program)
+        except ValueError as e:
+            print(e)
+    def test_single_trainers(self):
+        transpiler = collective.GradAllReduce(0)
+        try:
+            transpiler.transpile(
+                startup_program=fluid.Program(),
+                main_program=fluid.Program(),
+                rank=1,
+                endpoints="127.0.0.1:6174",
+                current_endpoint="127.0.0.1:6174",
+                wait_port="6174")
+        except ValueError as e:
+            print(e)
+        transpiler = collective.LocalSGD(0)
+        try:
+            transpiler.transpile(
+                startup_program=fluid.Program(),
+                main_program=fluid.Program(),
+                rank=1,
+                endpoints="127.0.0.1:6174",
+                current_endpoint="127.0.0.1:6174",
+                wait_port="6174")
+        except ValueError as e:
+            print(e)
class TestBoxPSPreload(unittest.TestCase):
@@ -147,6 +147,92 @@ class TestPipeline(unittest.TestCase):
for f in filelist:
os.remove(f)
+    def test_pipeline_single_section(self):
+        program = fluid.Program()
+        with fluid.program_guard(program):
+            x = fluid.layers.data(
+                name='x', shape=[1], dtype='int64', lod_level=0)
+            y = fluid.layers.data(
+                name='y', shape=[1], dtype='int64', lod_level=0)
+            emb_x = layers.embedding(
+                input=x,
+                param_attr=fluid.ParamAttr(name="embx"),
+                size=[10, 2],
+                is_sparse=False)
+            emb_y = layers.embedding(
+                input=y,
+                param_attr=fluid.ParamAttr(
+                    name="emby", learning_rate=0.9),
+                size=[10, 2],
+                is_sparse=False)
+            concat = layers.concat([emb_x, emb_y], axis=1)
+            fc = layers.fc(input=concat,
+                           name="fc",
+                           size=1,
+                           num_flatten_dims=1,
+                           bias_attr=False)
+            loss = layers.reduce_mean(fc)
+            optimizer = fluid.optimizer.SGD(learning_rate=0.5)
+            optimizer = fluid.optimizer.PipelineOptimizer(
+                optimizer,
+                cut_list=[],
+                place_list=[fluid.CUDAPlace(0)],
+                concurrency_list=[1],
+                queue_size=1,
+                sync_steps=-1)
+            optimizer.minimize(loss)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            #prepare data
+            batch_size = 100
+            def binary_print(slot, fout):
+                num = np.int16(len(slot) + 1)
+                num.tofile(fout)
+                a = np.int64(batch_size)
+                a.tofile(fout)
+                slot.tofile(fout)
+            #batch1 = np.array([[0,1], [1,2], [2,3]]).astype("int64").reshape(batch_size,2,1)
+            #batch2 = np.array([[1,2], [2,3], [3,4]]).astype("int64").reshape(batch_size,2,1)
+            batch1 = np.ones(
+                (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
+            batch2 = np.ones(
+                (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
+            data = [batch1, batch2]
+            filelist = []
+            for i in range(2):
+                filelist.append("test_pipeline_input_" + str(i))
+            for f in filelist:
+                with open(f, "wb") as fout:
+                    for batch_data in data:
+                        for ins in batch_data:
+                            for slot in ins:
+                                binary_print(slot, fout)
+            dataset = fluid.DatasetFactory().create_dataset(
+                "FileInstantDataset")
+            dataset.set_use_var([x, y])
+            dataset.set_batch_size(batch_size)
+            dataset.set_filelist(filelist)
+            for epoch in range(1):
+                exe.train_from_dataset(
+                    fluid.default_main_program(),
+                    dataset,
+                    thread=1,
+                    debug=False,
+                    fetch_list=[],
+                    fetch_info=[],
+                    print_period=1)
+            for f in filelist:
+                os.remove(f)
if __name__ == '__main__':
unittest.main()
@@ -64,7 +64,7 @@ class Collective(object):
self.main_program = default_main_program()
self.nranks = len(endpoints)
-        if self.nranks == 1:
+        if self.nranks == 1 and self.mode != "single_process_multi_thread":
raise ValueError('the number of endpoints must > 1')
if rank < 0:
@@ -181,6 +181,7 @@ class GradAllReduce(Collective):
def __init__(self, nrings=2):
Collective.__init__(self, nrings)
self.mode = "grad_allreduce"
def _transpile_main_program(self):
self._insert_scale_loss_grad_ops()
@@ -273,6 +274,7 @@ class LocalSGD(Collective):
def __init__(self, nrings=2):
Collective.__init__(self, nrings)
self.snapshot_key = '@SNAPSHOT'
self.mode = "local_sgd"
def _transpile_startup_program(self):
Collective._transpile_startup_program(self)
@@ -370,3 +372,16 @@ class LocalSGD(Collective):
inputs={'X': [param]},
outputs={'Out': [snapshot]},
attrs={self.op_role_key: OpRole.Optimize})
+class SingleProcessMultiThread(GradAllReduce):
+    '''
+    '''
+    def __init__(self):
+        GradAllReduce.__init__(self, -1)
+        self.mode = "single_process_multi_thread"
+    def _transpile_startup_program(self):
+        block = self.startup_program.global_block()
+        block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
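The new class plugs into the existing DistributeTranspiler front end via collective_mode; a condensed usage sketch mirroring the new TestTranspile case above:

    # Condensed from the new unit test: selecting the
    # single_process_multi_thread collective mode.
    import paddle.fluid as fluid

    config = fluid.DistributeTranspilerConfig()
    config.mode = "collective"
    config.collective_mode = "single_process_multi_thread"

    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id=0,
        trainers="127.0.0.1:6174",  # a single endpoint is allowed in this mode
        startup_program=fluid.Program(),
        program=fluid.Program())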
@@ -370,10 +370,11 @@ class DistributeTranspiler(object):
endpoints = trainers.split(",")
elif isinstance(trainers, list):
endpoints = trainers
-        else:
+        elif collective_mode != "single_process_multi_thread":
raise ValueError('invalid trainers config: ' + str(trainers))
-        if len(endpoints) == 1:
+        if len(endpoints
+               ) == 1 and collective_mode != "single_process_multi_thread":
raise ValueError('invalid trainer number in distributed: 1')
if startup_program is None:
@@ -387,6 +388,8 @@ class DistributeTranspiler(object):
transpiler = collective.GradAllReduce(self.config.nccl_comm_num)
elif collective_mode == 'local_sgd':
transpiler = collective.LocalSGD(self.config.nccl_comm_num)
elif collective_mode == "single_process_multi_thread":
transpiler = collective.SingleProcessMultiThread()
else:
raise ValueError('invalid collective_mode: %s' % collective_mode)