From daba57f752b72be55fc5cdad86de2d5f52bb261c Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 4 Dec 2018 16:50:59 +0800 Subject: [PATCH] complete ctr_reader --- .../operators/reader/create_ctr_reader_op.cc | 5 ++- paddle/fluid/operators/reader/ctr_reader.cc | 6 +-- paddle/fluid/operators/reader/ctr_reader.h | 36 +++++++++++++++- paddle/fluid/operators/reader/read_op.cc | 41 +++++++++++------- .../operators/reader/reader_op_registry.cc | 34 +++++++++------ paddle/fluid/pybind/pybind.cc | 1 + python/paddle/fluid/contrib/__init__.py | 3 ++ .../paddle/fluid/contrib/reader/__init__.py | 19 +++++++++ .../paddle/fluid/contrib/reader/ctr_reader.py | 42 +++++++++++++++---- python/setup.py.in | 1 + 10 files changed, 143 insertions(+), 45 deletions(-) create mode 100644 python/paddle/fluid/contrib/reader/__init__.py diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc index 5b9e2ba69..2a3e80c91 100644 --- a/paddle/fluid/operators/reader/create_ctr_reader_op.cc +++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc @@ -51,6 +51,7 @@ class CreateCTRReaderOp : public framework::OperatorBase { auto file_list = Attr>("file_list"); DataDesc data_desc(batch_size, file_list, file_type, file_format, dense_slot_index, sparse_slot_index, sparse_slots); + VLOG(1) << data_desc; out->Reset(std::make_shared(queue_holder->GetQueue(), thread_num, data_desc)); } @@ -69,10 +70,10 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { "The list of files that need to read"); AddAttr>( "dense_slot_index", - "the sparse slots id that should be extract from file") + "the dense slots id that should be extract from file") .SetDefault({}); AddAttr>( - "dense_slot_index", + "sparse_slot_index", "the sparse slots id that should be extract from file") .SetDefault({}); AddAttr>("sparse_slots", diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index 3595d771b..946f17750 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -157,8 +157,8 @@ void MonitorThread(std::vector* thread_status, } std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } - VLOG(3) << "all reader thread is stopped, push empty data into queue"; - queue->Push({}); + VLOG(3) << "all reader thread is stopped, close the queue"; + queue->Close(); VLOG(3) << "monitor thread exited"; } @@ -247,7 +247,7 @@ static inline void parse_csv_line( int slot_idx = data_desc.dense_slot_index_[i]; auto& slot_data = ret[slot_idx]; std::vector data_in_slot_str; - string_split(ret[slot_idx], ',', &data_in_slot_str); + string_split(slot_data, ',', &data_in_slot_str); std::vector data_in_slot; for (auto& data_str : data_in_slot_str) { (*dense_datas)[i].push_back(std::stof(data_str)); diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 1f4663e3b..eef6c11aa 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -60,6 +60,35 @@ struct DataDesc { const std::vector sparse_slot_ids_; }; +inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) { + os << "data_desc:\n"; + os << "\tbatch_size -> " << data_desc.batch_size_ << "\n"; + os << "\tfile_type -> " << data_desc.file_type_ << "\n"; + os << "\tfile_format -> " << data_desc.file_format_ << "\n"; + os << "\tfile_names -> {"; + for (auto& file_name : data_desc.file_names_) { + os << file_name << ","; + } + os << "}\n"; + os << "\tdense_slot_index -> {"; + for (auto& slot : data_desc.dense_slot_index_) { + os << slot << ","; + } + os << "}\n"; + os << "\tsparse_slot_index_ -> {"; + for (auto& slot : data_desc.sparse_slot_index_) { + os << slot << ","; + } + os << "}\n"; + os << "\tsparse_slot_ids_ -> {"; + for (auto& slot : data_desc.sparse_slot_ids_) { + os << slot << ","; + } + os << "}\n"; + + return os; +} + void ReadThread(const std::vector& file_list, const DataDesc& data_desc, int thread_id, std::vector* thread_status, @@ -89,7 +118,7 @@ class CTRReader : public framework::FileReader { } } - ~CTRReader() {} + ~CTRReader() { Shutdown(); } void ReadNext(std::vector* out) override { bool success; @@ -106,7 +135,10 @@ class CTRReader : public framework::FileReader { for (auto& read_thread : read_threads_) { read_thread->join(); } - monitor_thread_->join(); + + if (monitor_thread_) { + monitor_thread_->join(); + } read_threads_.clear(); monitor_thread_.reset(nullptr); diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc index a0b70938d..97faade04 100644 --- a/paddle/fluid/operators/reader/read_op.cc +++ b/paddle/fluid/operators/reader/read_op.cc @@ -27,15 +27,16 @@ class ReadInferShape : public framework::InferShapeBase { "The ReadOp must take a reader as input."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "The ReadOp should be assigned with output."); - std::vector reader_dims = ctx->GetReaderDims("Reader"); - std::vector out_names = ctx->Outputs("Out"); - PADDLE_ENFORCE_EQ( - reader_dims.size(), out_names.size(), - "The reader's dim number doesn't match the output number."); - ctx->SetOutputsDim("Out", reader_dims); - if (!ctx->IsRuntime()) { + if (!ctx->IsRuntime() && ctx->Attrs().Get("infer_out")) { + std::vector reader_dims = ctx->GetReaderDims("Reader"); + std::vector out_names = ctx->Outputs("Out"); + PADDLE_ENFORCE_EQ( + reader_dims.size(), out_names.size(), + "The reader's dim number doesn't match the output number."); + ctx->SetOutputsDim("Out", reader_dims); auto in_desc = boost::get(ctx->GetInputVarPtrs("Reader")[0]); + std::cout << in_desc->Proto()->SerializeAsString() << std::endl; auto in_lod_levels = in_desc->GetLoDLevels(); auto out_var_ptrs = ctx->GetOutputVarPtrs("Out"); PADDLE_ENFORCE_EQ(in_lod_levels.size(), out_var_ptrs.size(), @@ -53,15 +54,18 @@ class ReadInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDesc& op_desc, framework::BlockDesc* block) const override { - std::string reader_name = op_desc.Input("Reader")[0]; - std::vector out_names = op_desc.Output("Out"); - framework::VarDesc* reader = block->FindVarRecursive(reader_name); - auto dtypes = reader->GetDataTypes(); - PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); - for (size_t i = 0; i < dtypes.size(); ++i) { - framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); - out.SetType(framework::proto::VarType::LOD_TENSOR); - out.SetDataType(dtypes[i]); + bool infer_out = boost::get(op_desc.GetAttr("infer_out")); + if (infer_out) { + std::string reader_name = op_desc.Input("Reader")[0]; + std::vector out_names = op_desc.Output("Out"); + framework::VarDesc* reader = block->FindVarRecursive(reader_name); + auto dtypes = reader->GetDataTypes(); + PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); + for (size_t i = 0; i < dtypes.size(); ++i) { + framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); + out.SetType(framework::proto::VarType::LOD_TENSOR); + out.SetDataType(dtypes[i]); + } } } }; @@ -73,6 +77,7 @@ class ReadOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { + VLOG(3) << "read op in"; framework::ReaderHolder* reader = detail::Ref(scope.FindVar(Input("Reader")), "Cannot find reader variable %s", Input("Reader")) @@ -87,7 +92,9 @@ class ReadOp : public framework::OperatorBase { reader->ReadNext(&ins); if (ins.empty()) { + VLOG(3) << "read empty data in"; if (Attr("throw_eof_exp")) { + VLOG(3) << "throw_eof_exp"; PADDLE_THROW_EOF(); } else { ins.resize(out_arg_names.size()); @@ -96,6 +103,7 @@ class ReadOp : public framework::OperatorBase { tensor.mutable_data(framework::make_ddim({0}), dev_place); } } + VLOG(3) << "read empty data out"; } PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); for (size_t i = 0; i < out_arg_names.size(); ++i) { @@ -120,6 +128,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { " only when the data-balance is enabled in ParallelExecutor" " and it is set by ParallelExecutor instance, not users.") .SetDefault(true); + AddAttr("infer_out", "").SetDefault(true); AddComment(R"DOC( Read Operator diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index b82aab121..3921eedf9 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() { "It means the reader will generate two data each time," "whose shapes are [2,3,4] and [5,6] respectively."); AddAttr>("lod_levels", "The LoD levels of each data."); + AddAttr( + "use_data_config", + "Use the config of all datas like shape_concat/ranks/lod_levels") + .SetDefault(true); Apply(); } @@ -75,19 +79,23 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasOutput("Out"), "The output file reader should not be null."); - const auto shape_concat = ctx->Attrs().Get>("shape_concat"); - const auto ranks = ctx->Attrs().Get>("ranks"); - std::vector shapes = RestoreShapes(shape_concat, ranks); - ctx->SetReaderDims("Out", shapes); - - const auto lod_levels = ctx->Attrs().Get>("lod_levels"); - PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(), - "The number of 'lod_levels'(%d) doesn't match the number " - "of 'shapes'(%d).", - lod_levels.size(), shapes.size()); - framework::VarDesc* reader = - boost::get(ctx->GetOutputVarPtrs("Out")[0]); - reader->SetLoDLevels(lod_levels); + bool use_data_config = ctx->Attrs().Get("use_data_config"); + if (use_data_config) { + const auto shape_concat = + ctx->Attrs().Get>("shape_concat"); + const auto ranks = ctx->Attrs().Get>("ranks"); + std::vector shapes = RestoreShapes(shape_concat, ranks); + ctx->SetReaderDims("Out", shapes); + + const auto lod_levels = ctx->Attrs().Get>("lod_levels"); + PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(), + "The number of 'lod_levels'(%d) doesn't match the number " + "of 'shapes'(%d).", + lod_levels.size(), shapes.size()); + framework::VarDesc* reader = + boost::get(ctx->GetOutputVarPtrs("Out")[0]); + reader->SetLoDLevels(lod_levels); + } } void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f0a5d1afc..681b213b4 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -364,6 +364,7 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference); py::class_(m, "Reader", "") + .def("start", &framework::ReaderHolder::Start) .def("reset", &framework::ReaderHolder::ResetAll); using LoDTensorBlockingQueue = diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 3bf2fe5db..5d4b15772 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -22,9 +22,12 @@ from . import op_frequence from .op_frequence import * from . import quantize from .quantize import * +from . import reader +from .reader import * __all__ = [] __all__ += decoder.__all__ __all__ += memory_usage_calc.__all__ __all__ += op_frequence.__all__ __all__ += quantize.__all__ +__all__ += reader.__all__ diff --git a/python/paddle/fluid/contrib/reader/__init__.py b/python/paddle/fluid/contrib/reader/__init__.py new file mode 100644 index 000000000..4cf85ffc1 --- /dev/null +++ b/python/paddle/fluid/contrib/reader/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import ctr_reader + +__all__ = ctr_reader.__all__ diff --git a/python/paddle/fluid/contrib/reader/ctr_reader.py b/python/paddle/fluid/contrib/reader/ctr_reader.py index d7133562d..aad8ded87 100644 --- a/python/paddle/fluid/contrib/reader/ctr_reader.py +++ b/python/paddle/fluid/contrib/reader/ctr_reader.py @@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \ default_startup_program, Variable from paddle.fluid.unique_name import generate as unique_name +__all__ = ['ctr_reader'] + def monkey_patch_reader_methods(reader): def __get_reader__(): @@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader): def reset(): return __get_reader__().reset() + def start(): + return __get_reader__().start() + reader.reset = reset + reader.start = start reader.stop_gradient = True reader.persistable = True return reader @@ -44,13 +50,18 @@ def _copy_reader_var_(block, var): return new_var -def ctr_reader(feed_data, - capacity, - thread_num, - batch_size, - file_list, - slots, - name=None): +def ctr_reader( + feed_dict, + file_type, # gzip or plain + file_format, # csv or svm + dense_slot_indexs, + sparse_slot_indexs, + capacity, + thread_num, + batch_size, + file_list, + slots, + name=None): """ Create a CTR reader for data feeding in Python @@ -99,12 +110,22 @@ def ctr_reader(feed_data, inputs={'blocking_queue': [queue_name]}, outputs={'Out': [reader_var]}, attrs={ + 'use_data_config': False, 'thread_num': thread_num, 'batch_size': batch_size, 'file_list': file_list, - 'slots': slots, + 'file_type': file_type, + 'file_format': file_format, + 'dense_slot_index': dense_slot_indexs, + 'sparse_slot_index': sparse_slot_indexs, + 'sparse_slots': slots, + 'ranks': [], + 'lod_levels': [], + 'shape_concat': [] }) + dtypes = [data.dtype for data in feed_dict] + reader_var.desc.set_dtypes(dtypes) reader_var.persistable = True main_prog_reader_var = _copy_reader_var_( @@ -118,6 +139,9 @@ def ctr_reader(feed_data, main_blk = default_main_program().current_block() main_blk.append_op( - type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data}) + type='read', + inputs={'Reader': [reader]}, + attrs={'infer_out': False}, + outputs={'Out': feed_dict}) return reader diff --git a/python/setup.py.in b/python/setup.py.in index 200b96ec5..d5d82f643 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -107,6 +107,7 @@ packages=['paddle', 'paddle.fluid.contrib', 'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.quantize', + 'paddle.fluid.contrib.reader', 'paddle.fluid.transpiler', 'paddle.fluid.transpiler.details'] -- GitLab