diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 3c73b6cc55c187c3f6e7edd1ce38cc58f4e8413d..4fb4ec38ee965a2790d11378a1ce6befa0ef5a00 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -25,11 +25,12 @@ else() cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) endif() +cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 64e83acb4dc1995800c4ca3caf81668b24a7c9fe..9c2c845c6efb206fb1ad5150189430b9a6fe9ea3 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -33,6 +33,8 @@ struct BuildStrategy { GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; std::string debug_graphviz_path_{""}; + + bool enable_data_balance_{true}; }; } // namespace details diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..b914851fe0add74f6d85589f4686224b668b8064 --- /dev/null +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/data_balance_op_handle.h" +#include +#include "paddle/fluid/framework/details/container_cast.h" + +namespace paddle { +namespace framework { +namespace details { + +#ifdef PADDLE_WITH_CUDA +DataBalanceOpHandle::DataBalanceOpHandle( + const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs) + : local_scopes_(local_scopes), places_(places) { + if (ctxs) { + for (auto &p : places_) { + this->dev_ctxes_[p] = ctxs->DevCtx(p); + } + } +} +#else +DataBalanceOpHandle::DataBalanceOpHandle( + const std::vector &local_scopes, + const std::vector &places) + : local_scopes_(local_scopes), places_(places) {} +#endif + +std::string DataBalanceOpHandle::Name() const { return "data balance"; } + +std::vector> DataBalanceOpHandle::GetBalancePlan( + const std::vector &device_sizes) { + int device_num = device_sizes.size(); + int total_size = 0; + int empty_num = 0; + std::vector> size_device_vec; + size_device_vec.reserve(device_num); + for (int i = 0; i < device_num; ++i) { + if (device_sizes[i] == 0) { + ++empty_num; + } + total_size += device_sizes[i]; + size_device_vec.push_back({{device_sizes[i], i}}); + } + std::vector> res; + if (empty_num == 0) { + // No need to do data balance. + return res; + } + if (total_size < device_num) { + // No enough data. + PADDLE_THROW("There is no next data."); + } + std::sort(size_device_vec.begin(), size_device_vec.end(), + [](const std::array &a, const std::array &b) { + return a[0] > b[0]; + }); + int expected_device_size = total_size / device_num; + int src_idx = 0; + for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) { + if (size_device_vec[src_idx][0] <= expected_device_size) { + ++src_idx; + PADDLE_ENFORCE_LT( + src_idx, device_num - empty_num, + "In current srategy an empty tensor should not be copy source."); + } + size_device_vec[src_idx][0] -= expected_device_size; + size_device_vec[dst_idx][0] += expected_device_size; + res.push_back({{size_device_vec[src_idx][1], size_device_vec[dst_idx][1], + expected_device_size}}); + } + return res; +} + +void DataBalanceOpHandle::RunImpl() { + if (places_.size() == 1) { + return; + } + auto in_var_handles = DynamicCast(inputs_); + auto out_var_handles = DynamicCast(outputs_); + PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0); + PADDLE_ENFORCE_EQ( + in_var_handles.size(), out_var_handles.size(), + "The NoDummyInputSize and NoDummyOutputSize should be equal."); + int data_num = in_var_handles.size() / places_.size(); + WaitInputVarGenerated(); + std::vector> lod_tensors(data_num); + std::vector device_sizes; + for (int i = 0; i < static_cast(in_var_handles.size()); ++i) { + PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_, + "The name of input and output should be equal."); + int place_idx = i / data_num; + int data_idx = i % data_num; + auto *local_scope = + local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get(); + auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name_); + PADDLE_ENFORCE(tensor_var->IsType()); + auto *tensor = tensor_var->GetMutable(); + lod_tensors[data_idx].push_back(tensor); + int ins_size = + tensor->lod().empty() ? tensor->dims()[0] : tensor->NumElements(); + if (data_idx == 0) { + device_sizes.emplace_back(ins_size); + } else { + PADDLE_ENFORCE_EQ( + ins_size, device_sizes.at(place_idx), + "All data on the same device shall have the same batch size."); + } + } + const auto &balance_plan = GetBalancePlan(device_sizes); + + for (const auto &trans : balance_plan) { + for (int data_idx = 0; data_idx < data_num; ++data_idx) { + LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]]; + LoDTensor *dst_tensor = lod_tensors[data_idx][trans[1]]; + int trans_ins_size = trans[2]; + LoD src_lod = src_tensor->lod(); + int src_ins_size = + src_lod.empty() ? src_tensor->dims()[0] : src_tensor->NumElements(); + int cut_point = src_ins_size - trans_ins_size; + if (!src_lod.empty()) { + for (auto &level : src_lod) { + cut_point = level[cut_point]; + } + } + TensorCopySync(src_tensor->Slice(cut_point, src_tensor->dims()[0]), + dst_tensor->place(), dst_tensor); + src_tensor->ShareDataWith(src_tensor->Slice(0, cut_point)); + if (!src_lod.empty()) { + dst_tensor->set_lod(SliceInLevel( + src_lod, 0, src_ins_size - trans_ins_size, src_ins_size)); + src_tensor->set_lod( + SliceInLevel(src_lod, 0, 0, src_ins_size - trans_ins_size)); + } + } + } +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..76a407e3610e8bb48facf1f814779f4c23f92d98 --- /dev/null +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -0,0 +1,59 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +struct DataBalanceOpHandle : public OpHandleBase { + public: +#ifdef PADDLE_WITH_CUDA + DataBalanceOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs); +#else + DataBalanceOpHandle(const std::vector &local_scopes, + const std::vector &places); +#endif + + std::string Name() const override; + + bool IsMultiDeviceTransfer() override { return false; }; + + protected: + void RunImpl() override; + + private: + // std::vector<(src_dev_id, dst_dev_id, trans_size)> + std::vector> GetBalancePlan( + const std::vector &batch_size_per_device); + + const std::vector local_scopes_; + const std::vector places_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 224e8e1f6efd7a894591ac51c929517cae7539ce..d646c944601e81477787740189d7ac60ae97fa80 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -67,8 +67,8 @@ void FetchOpHandle::RunImpl() { #endif } else { tensors_[i].ShareDataWith(t); - tensors_[i].set_lod(t.lod()); } + tensors_[i].set_lod(t.lod()); } this->WaitAndMergeCPUTensors(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index cc7b94d0653e34c8ac711a7db7ab6ab1a9ac46a2..46d0c2769cb334f5cb75ae0ef5e48da45448c48f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/data_balance_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" @@ -215,7 +216,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else { // This op runs on all devices, and its output may have parameter's // gradients. - CreateComputationalOps(&result, *op, places_.size()); + if (op->Type() == "read" && strategy_.enable_data_balance_) { + op->SetAttr("throw_eof_exp", false); + CreateComputationalOps(&result, *op, places_.size()); + const auto &data_var_names = op->Output("Out"); + InsertDataBalanceOp(&result, data_var_names); + } else { + CreateComputationalOps(&result, *op, places_.size()); + } if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be @@ -360,6 +368,29 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, } } +void MultiDevSSAGraphBuilder::InsertDataBalanceOp( + SSAGraph *result, const std::vector &datas) const { +#ifdef PADDLE_WITH_CUDA + result->ops_.emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); +#else + result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); +#endif + auto *op_handle = result->ops_.back().get(); + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + SetCommunicationContext(op_handle, p); + for (const std::string &d_name : datas) { + auto &vars = result->vars_[i][d_name]; + PADDLE_ENFORCE(!vars.empty()); + op_handle->AddInput(vars.back().get()); + auto var = new VarHandle(vars.size(), i, d_name, p); + vars.emplace_back(var); + op_handle->AddOutput(var); + } + } +} + bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const { @@ -512,7 +543,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); // the variable name which contains .block means it was splited by // split_byref op - // so that we can balance the variable blocks to all the pserver instances. + // so that we can balance the variable blocks to all the pserver + // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && op.InputArgumentNames()[0].find(".block") == std::string::npos) { op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 0b6347bf51dc1c347073a0fdcf4ddd91865d846d..a964e024885e56693224a6199e00ff30beaa1df4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -101,6 +101,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertDataBalanceOp(SSAGraph *result, + const std::vector &datas) const; + void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 1f84c3b9e2d7ee9ae51959988fceeb3451b7b3b8..3560fabb424375a770432586fe7c8e51210b3d0c 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -58,8 +58,10 @@ void OpHandleBase::Run(bool use_cuda) { void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_NOT_NULL(waited_ctx); if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { + PADDLE_ENFORCE_NOT_NULL(dev_ctx.second); dev_ctx.second->Wait(); } } else { diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 2d0611326be2fb2ddfef5cea0c11d2f935d20888..cba0064f38f89c1dd27cfac1ddb2339a5ee6c93f 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -90,6 +90,7 @@ std::string LoDToString(const LoD &lod) { LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, size_t elem_end) { PADDLE_ENFORCE_LT(level, in.size()); + PADDLE_ENFORCE_LT(elem_begin, elem_end); PADDLE_ENFORCE_LT(elem_end, in[level].size()); LoD res; @@ -393,6 +394,7 @@ void LoDTensor::MergeLoDTensor( new_dim[0] += t->dims()[0]; auto &lod = t->lod(); + PADDLE_ENFORCE_EQ(new_lod.size(), lod.size()); for (size_t j = 0; j < lod.size(); ++j) { auto &sub_lod = new_lod[j]; auto &offset = sub_lod.back(); diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 72a27d43584d55cd0859c63577ae85ff0f5fdfa8..60e4eb757668e1482090f02aea529aaad3a674d8 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -66,9 +66,19 @@ class ReadOp : public framework::OperatorBase { std::vector out_arg_names = Outputs("Out"); std::vector ins; reader->ReadNext(&ins); - PADDLE_ENFORCE(!ins.empty(), "There is no next data."); + if (ins.empty()) { + if (Attr("throw_eof_exp")) { + PADDLE_THROW("There is no next data."); + } else { + ins.resize(out_arg_names.size()); + for (auto& tensor : ins) { + // data type is not important for subsequent DataBalanceOpHandle + tensor.mutable_data(framework::make_ddim({0}), dev_place); + } + } + } PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); - for (size_t i = 0; i < ins.size(); ++i) { + for (size_t i = 0; i < out_arg_names.size(); ++i) { auto* out = scope.FindVar(out_arg_names[i])->GetMutable(); out->ShareDataWith(ins[i]); @@ -82,6 +92,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("Reader", "(ReaderHolder) The executed reader."); AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); + AddAttr("throw_eof_exp", + "If set true, an exception will be thrown when the Reader " + "yields empty (which means there is no next data).") + .SetDefault(true); AddComment(R"DOC( Read Operator diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 36d080996831d4ad90d92baeafbe964693e2332a..9fc647a7d2a2bdfbaeeb91b00b4183f5c80b5aba 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle. [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](BuildStrategy &self, const std::string &path) { self.debug_graphviz_path_ = path; - }); + }) + .def_property( + "enable_data_balance", + [](const BuildStrategy &self) { return self.enable_data_balance_; }, + [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }); pe.def(py::init &, const std::unordered_set &, diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 3538a9c2009bb133609153427981fb66974377fa..b1e8fda03aa42f5f7528eafb46c16d55b868bae5 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -4,3 +4,5 @@ mnist_1.recordio mnist_2.recordio flowers.recordio wmt16.recordio +data_balance_test.recordio +data_balance_with_lod_test.recordio diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py new file mode 100644 index 0000000000000000000000000000000000000000..b558d7c2ea172d9c7526c865a4bc54c32f8998b6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -0,0 +1,187 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import paddle.v2 as paddle +import numpy as np + + +class TestDataBalance(unittest.TestCase): + def prepare_data(self): + def fake_data_generator(): + for n in xrange(self.total_ins_num): + yield np.ones((3, 4)) * n, n + + # Prepare data + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch( + fake_data_generator, batch_size=self.batch_size) + feeder = fluid.DataFeeder( + feed_list=[ + fluid.layers.data( + name='image', shape=[3, 4], dtype='float32'), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file( + self.data_file_name, reader, feeder) + + def prepare_lod_data(self): + def fake_data_generator(): + for n in xrange(1, self.total_ins_num + 1): + d1 = (np.ones((n, 3)) * n).astype('float32') + d2 = (np.array(n).reshape((1, 1))).astype('int32') + yield d1, d2 + + # Prepare lod data + with fluid.program_guard(fluid.Program(), fluid.Program()): + with fluid.recordio_writer.create_recordio_writer( + filename=self.lod_data_file_name) as writer: + eof = False + generator = fake_data_generator() + while (not eof): + data_batch = [ + np.array([]).reshape((0, 3)), np.array([]).reshape( + (0, 1)) + ] + lod = [0] + for _ in xrange(self.batch_size): + try: + ins = generator.next() + except StopIteration: + eof = True + break + for i, d in enumerate(ins): + data_batch[i] = np.concatenate( + (data_batch[i], d), axis=0) + lod.append(lod[-1] + ins[0].shape[0]) + if data_batch[0].shape[0] > 0: + for i, d in enumerate(data_batch): + t = fluid.LoDTensor() + t.set(data_batch[i], fluid.CPUPlace()) + if i == 0: + t.set_lod([lod]) + writer.append_tensor(t) + writer.complete_append_tensor() + + def setUp(self): + self.use_cuda = fluid.core.is_compiled_with_cuda() + self.data_file_name = './data_balance_test.recordio' + self.lod_data_file_name = './data_balance_with_lod_test.recordio' + self.total_ins_num = 50 + self.batch_size = 10 + self.prepare_data() + self.prepare_lod_data() + + def main(self): + main_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + data_reader = fluid.layers.io.open_files( + filenames=[self.data_file_name], + shapes=[[-1, 3, 4], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + if self.use_cuda: + data_reader = fluid.layers.double_buffer(data_reader) + image, label = fluid.layers.read_file(data_reader) + + place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + parallel_exe = fluid.ParallelExecutor( + use_cuda=self.use_cuda, main_program=main_prog) + + if (parallel_exe.device_count > self.batch_size): + print("WARNING: Unittest TestDataBalance skipped. \ + For the result is not correct when device count \ + is larger than batch size.") + exit(0) + fetch_list = [image.name, label.name] + + data_appeared = [False] * self.total_ins_num + while (True): + try: + image_val, label_val = parallel_exe.run(fetch_list, + return_numpy=True) + except fluid.core.EnforceNotMet as ex: + self.assertIn("There is no next data.", ex.message) + break + ins_num = image_val.shape[0] + broadcasted_label = np.ones( + (ins_num, 3, 4)) * label_val.reshape((ins_num, 1, 1)) + self.assertEqual(image_val.all(), broadcasted_label.all()) + for l in label_val: + self.assertFalse(data_appeared[l[0]]) + data_appeared[l[0]] = True + for i in data_appeared: + self.assertTrue(i) + + def main_lod(self): + main_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + data_reader = fluid.layers.io.open_files( + filenames=[self.lod_data_file_name], + shapes=[[-1, 3], [-1, 1]], + lod_levels=[1, 0], + dtypes=['float32', 'int32'], + thread_num=1) + ins, label = fluid.layers.read_file(data_reader) + + place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + parallel_exe = fluid.ParallelExecutor( + use_cuda=self.use_cuda, main_program=main_prog) + + if (parallel_exe.device_count > self.batch_size): + print("WARNING: Unittest TestDataBalance skipped. \ + For the result is not correct when device count \ + is larger than batch size.") + exit(0) + fetch_list = [ins.name, label.name] + + data_appeared = [False] * self.total_ins_num + while (True): + try: + ins_tensor, label_tensor = parallel_exe.run( + fetch_list, return_numpy=False) + except fluid.core.EnforceNotMet as ex: + self.assertIn("There is no next data.", ex.message) + break + + ins_val = np.array(ins_tensor) + label_val = np.array(label_tensor) + ins_lod = ins_tensor.lod()[0] + self.assertEqual(ins_val.shape[1], 3) + self.assertEqual(label_val.shape[1], 1) + self.assertEqual(len(ins_lod) - 1, label_val.shape[0]) + for i in range(0, len(ins_lod) - 1): + ins_elem = ins_val[ins_lod[i]:ins_lod[i + 1]][:] + label_elem = label_val[i][0] + self.assertEqual(ins_elem.all(), label_elem.all()) + self.assertFalse(data_appeared[int(label_elem - 1)]) + data_appeared[int(label_elem - 1)] = True + + for i in data_appeared: + self.assertTrue(i) + + def test_all(self): + self.main() + self.main_lod()