提交 daba57f7 编写于 作者: Q Qiao Longfei

complete ctr_reader

上级 9f53aad1
...@@ -51,6 +51,7 @@ class CreateCTRReaderOp : public framework::OperatorBase { ...@@ -51,6 +51,7 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto file_list = Attr<std::vector<std::string>>("file_list"); auto file_list = Attr<std::vector<std::string>>("file_list");
DataDesc data_desc(batch_size, file_list, file_type, file_format, DataDesc data_desc(batch_size, file_list, file_type, file_format,
dense_slot_index, sparse_slot_index, sparse_slots); dense_slot_index, sparse_slot_index, sparse_slots);
VLOG(1) << data_desc;
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num, out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num,
data_desc)); data_desc));
} }
...@@ -69,10 +70,10 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { ...@@ -69,10 +70,10 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"The list of files that need to read"); "The list of files that need to read");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"dense_slot_index", "dense_slot_index",
"the sparse slots id that should be extract from file") "the dense slots id that should be extract from file")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"dense_slot_index", "sparse_slot_index",
"the sparse slots id that should be extract from file") "the sparse slots id that should be extract from file")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>("sparse_slots", AddAttr<std::vector<std::string>>("sparse_slots",
......
...@@ -157,8 +157,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, ...@@ -157,8 +157,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
} }
std::this_thread::sleep_for(std::chrono::milliseconds(1000)); std::this_thread::sleep_for(std::chrono::milliseconds(1000));
} }
VLOG(3) << "all reader thread is stopped, push empty data into queue"; VLOG(3) << "all reader thread is stopped, close the queue";
queue->Push({}); queue->Close();
VLOG(3) << "monitor thread exited"; VLOG(3) << "monitor thread exited";
} }
...@@ -247,7 +247,7 @@ static inline void parse_csv_line( ...@@ -247,7 +247,7 @@ static inline void parse_csv_line(
int slot_idx = data_desc.dense_slot_index_[i]; int slot_idx = data_desc.dense_slot_index_[i];
auto& slot_data = ret[slot_idx]; auto& slot_data = ret[slot_idx];
std::vector<std::string> data_in_slot_str; std::vector<std::string> data_in_slot_str;
string_split(ret[slot_idx], ',', &data_in_slot_str); string_split(slot_data, ',', &data_in_slot_str);
std::vector<float> data_in_slot; std::vector<float> data_in_slot;
for (auto& data_str : data_in_slot_str) { for (auto& data_str : data_in_slot_str) {
(*dense_datas)[i].push_back(std::stof(data_str)); (*dense_datas)[i].push_back(std::stof(data_str));
......
...@@ -60,6 +60,35 @@ struct DataDesc { ...@@ -60,6 +60,35 @@ struct DataDesc {
const std::vector<std::string> sparse_slot_ids_; const std::vector<std::string> sparse_slot_ids_;
}; };
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
os << "data_desc:\n";
os << "\tbatch_size -> " << data_desc.batch_size_ << "\n";
os << "\tfile_type -> " << data_desc.file_type_ << "\n";
os << "\tfile_format -> " << data_desc.file_format_ << "\n";
os << "\tfile_names -> {";
for (auto& file_name : data_desc.file_names_) {
os << file_name << ",";
}
os << "}\n";
os << "\tdense_slot_index -> {";
for (auto& slot : data_desc.dense_slot_index_) {
os << slot << ",";
}
os << "}\n";
os << "\tsparse_slot_index_ -> {";
for (auto& slot : data_desc.sparse_slot_index_) {
os << slot << ",";
}
os << "}\n";
os << "\tsparse_slot_ids_ -> {";
for (auto& slot : data_desc.sparse_slot_ids_) {
os << slot << ",";
}
os << "}\n";
return os;
}
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id, const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status, std::vector<ReaderThreadStatus>* thread_status,
...@@ -89,7 +118,7 @@ class CTRReader : public framework::FileReader { ...@@ -89,7 +118,7 @@ class CTRReader : public framework::FileReader {
} }
} }
~CTRReader() {} ~CTRReader() { Shutdown(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
...@@ -106,7 +135,10 @@ class CTRReader : public framework::FileReader { ...@@ -106,7 +135,10 @@ class CTRReader : public framework::FileReader {
for (auto& read_thread : read_threads_) { for (auto& read_thread : read_threads_) {
read_thread->join(); read_thread->join();
} }
monitor_thread_->join();
if (monitor_thread_) {
monitor_thread_->join();
}
read_threads_.clear(); read_threads_.clear();
monitor_thread_.reset(nullptr); monitor_thread_.reset(nullptr);
......
...@@ -27,15 +27,16 @@ class ReadInferShape : public framework::InferShapeBase { ...@@ -27,15 +27,16 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input."); "The ReadOp must take a reader as input.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"The ReadOp should be assigned with output."); "The ReadOp should be assigned with output.");
std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader"); if (!ctx->IsRuntime() && ctx->Attrs().Get<bool>("infer_out")) {
std::vector<std::string> out_names = ctx->Outputs("Out"); std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
PADDLE_ENFORCE_EQ( std::vector<std::string> out_names = ctx->Outputs("Out");
reader_dims.size(), out_names.size(), PADDLE_ENFORCE_EQ(
"The reader's dim number doesn't match the output number."); reader_dims.size(), out_names.size(),
ctx->SetOutputsDim("Out", reader_dims); "The reader's dim number doesn't match the output number.");
if (!ctx->IsRuntime()) { ctx->SetOutputsDim("Out", reader_dims);
auto in_desc = auto in_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Reader")[0]); boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Reader")[0]);
std::cout << in_desc->Proto()->SerializeAsString() << std::endl;
auto in_lod_levels = in_desc->GetLoDLevels(); auto in_lod_levels = in_desc->GetLoDLevels();
auto out_var_ptrs = ctx->GetOutputVarPtrs("Out"); auto out_var_ptrs = ctx->GetOutputVarPtrs("Out");
PADDLE_ENFORCE_EQ(in_lod_levels.size(), out_var_ptrs.size(), PADDLE_ENFORCE_EQ(in_lod_levels.size(), out_var_ptrs.size(),
...@@ -53,15 +54,18 @@ class ReadInferVarType : public framework::VarTypeInference { ...@@ -53,15 +54,18 @@ class ReadInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Input("Reader")[0]; bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
std::vector<std::string> out_names = op_desc.Output("Out"); if (infer_out) {
framework::VarDesc* reader = block->FindVarRecursive(reader_name); std::string reader_name = op_desc.Input("Reader")[0];
auto dtypes = reader->GetDataTypes(); std::vector<std::string> out_names = op_desc.Output("Out");
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); framework::VarDesc* reader = block->FindVarRecursive(reader_name);
for (size_t i = 0; i < dtypes.size(); ++i) { auto dtypes = reader->GetDataTypes();
framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
out.SetType(framework::proto::VarType::LOD_TENSOR); for (size_t i = 0; i < dtypes.size(); ++i) {
out.SetDataType(dtypes[i]); framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetDataType(dtypes[i]);
}
} }
} }
}; };
...@@ -73,6 +77,7 @@ class ReadOp : public framework::OperatorBase { ...@@ -73,6 +77,7 @@ class ReadOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
VLOG(3) << "read op in";
framework::ReaderHolder* reader = framework::ReaderHolder* reader =
detail::Ref(scope.FindVar(Input("Reader")), detail::Ref(scope.FindVar(Input("Reader")),
"Cannot find reader variable %s", Input("Reader")) "Cannot find reader variable %s", Input("Reader"))
...@@ -87,7 +92,9 @@ class ReadOp : public framework::OperatorBase { ...@@ -87,7 +92,9 @@ class ReadOp : public framework::OperatorBase {
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) { if (ins.empty()) {
VLOG(3) << "read empty data in";
if (Attr<bool>("throw_eof_exp")) { if (Attr<bool>("throw_eof_exp")) {
VLOG(3) << "throw_eof_exp";
PADDLE_THROW_EOF(); PADDLE_THROW_EOF();
} else { } else {
ins.resize(out_arg_names.size()); ins.resize(out_arg_names.size());
...@@ -96,6 +103,7 @@ class ReadOp : public framework::OperatorBase { ...@@ -96,6 +103,7 @@ class ReadOp : public framework::OperatorBase {
tensor.mutable_data<float>(framework::make_ddim({0}), dev_place); tensor.mutable_data<float>(framework::make_ddim({0}), dev_place);
} }
} }
VLOG(3) << "read empty data out";
} }
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < out_arg_names.size(); ++i) { for (size_t i = 0; i < out_arg_names.size(); ++i) {
...@@ -120,6 +128,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -120,6 +128,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" only when the data-balance is enabled in ParallelExecutor" " only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users.") " and it is set by ParallelExecutor instance, not users.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>("infer_out", "").SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Read Operator Read Operator
......
...@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() { ...@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time," "It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."); "whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data."); AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<bool>(
"use_data_config",
"Use the config of all datas like shape_concat/ranks/lod_levels")
.SetDefault(true);
Apply(); Apply();
} }
...@@ -75,19 +79,23 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { ...@@ -75,19 +79,23 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null."); "The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat"); bool use_data_config = ctx->Attrs().Get<bool>("use_data_config");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks"); if (use_data_config) {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); const auto shape_concat =
ctx->SetReaderDims("Out", shapes); ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels"); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(), ctx->SetReaderDims("Out", shapes);
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).", const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
lod_levels.size(), shapes.size()); PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
framework::VarDesc* reader = "The number of 'lod_levels'(%d) doesn't match the number "
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]); "of 'shapes'(%d).",
reader->SetLoDLevels(lod_levels); lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
} }
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
......
...@@ -364,6 +364,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -364,6 +364,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll); .def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
......
...@@ -22,9 +22,12 @@ from . import op_frequence ...@@ -22,9 +22,12 @@ from . import op_frequence
from .op_frequence import * from .op_frequence import *
from . import quantize from . import quantize
from .quantize import * from .quantize import *
from . import reader
from .reader import *
__all__ = [] __all__ = []
__all__ += decoder.__all__ __all__ += decoder.__all__
__all__ += memory_usage_calc.__all__ __all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__ __all__ += op_frequence.__all__
__all__ += quantize.__all__ __all__ += quantize.__all__
__all__ += reader.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import ctr_reader
__all__ = ctr_reader.__all__
...@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \ ...@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \
default_startup_program, Variable default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name from paddle.fluid.unique_name import generate as unique_name
__all__ = ['ctr_reader']
def monkey_patch_reader_methods(reader): def monkey_patch_reader_methods(reader):
def __get_reader__(): def __get_reader__():
...@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader): ...@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader):
def reset(): def reset():
return __get_reader__().reset() return __get_reader__().reset()
def start():
return __get_reader__().start()
reader.reset = reset reader.reset = reset
reader.start = start
reader.stop_gradient = True reader.stop_gradient = True
reader.persistable = True reader.persistable = True
return reader return reader
...@@ -44,13 +50,18 @@ def _copy_reader_var_(block, var): ...@@ -44,13 +50,18 @@ def _copy_reader_var_(block, var):
return new_var return new_var
def ctr_reader(feed_data, def ctr_reader(
capacity, feed_dict,
thread_num, file_type, # gzip or plain
batch_size, file_format, # csv or svm
file_list, dense_slot_indexs,
slots, sparse_slot_indexs,
name=None): capacity,
thread_num,
batch_size,
file_list,
slots,
name=None):
""" """
Create a CTR reader for data feeding in Python Create a CTR reader for data feeding in Python
...@@ -99,12 +110,22 @@ def ctr_reader(feed_data, ...@@ -99,12 +110,22 @@ def ctr_reader(feed_data,
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [reader_var]}, outputs={'Out': [reader_var]},
attrs={ attrs={
'use_data_config': False,
'thread_num': thread_num, 'thread_num': thread_num,
'batch_size': batch_size, 'batch_size': batch_size,
'file_list': file_list, 'file_list': file_list,
'slots': slots, 'file_type': file_type,
'file_format': file_format,
'dense_slot_index': dense_slot_indexs,
'sparse_slot_index': sparse_slot_indexs,
'sparse_slots': slots,
'ranks': [],
'lod_levels': [],
'shape_concat': []
}) })
dtypes = [data.dtype for data in feed_dict]
reader_var.desc.set_dtypes(dtypes)
reader_var.persistable = True reader_var.persistable = True
main_prog_reader_var = _copy_reader_var_( main_prog_reader_var = _copy_reader_var_(
...@@ -118,6 +139,9 @@ def ctr_reader(feed_data, ...@@ -118,6 +139,9 @@ def ctr_reader(feed_data,
main_blk = default_main_program().current_block() main_blk = default_main_program().current_block()
main_blk.append_op( main_blk.append_op(
type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data}) type='read',
inputs={'Reader': [reader]},
attrs={'infer_out': False},
outputs={'Out': feed_dict})
return reader return reader
...@@ -107,6 +107,7 @@ packages=['paddle', ...@@ -107,6 +107,7 @@ packages=['paddle',
'paddle.fluid.contrib', 'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize', 'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.reader',
'paddle.fluid.transpiler', 'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details'] 'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册