提交 daba57f7 编写于 作者: Q Qiao Longfei

complete ctr_reader

上级 9f53aad1
......@@ -51,6 +51,7 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto file_list = Attr<std::vector<std::string>>("file_list");
DataDesc data_desc(batch_size, file_list, file_type, file_format,
dense_slot_index, sparse_slot_index, sparse_slots);
VLOG(1) << data_desc;
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num,
data_desc));
}
......@@ -69,10 +70,10 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"The list of files that need to read");
AddAttr<std::vector<int>>(
"dense_slot_index",
"the sparse slots id that should be extract from file")
"the dense slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<int>>(
"dense_slot_index",
"sparse_slot_index",
"the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>("sparse_slots",
......
......@@ -157,8 +157,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "all reader thread is stopped, push empty data into queue";
queue->Push({});
VLOG(3) << "all reader thread is stopped, close the queue";
queue->Close();
VLOG(3) << "monitor thread exited";
}
......@@ -247,7 +247,7 @@ static inline void parse_csv_line(
int slot_idx = data_desc.dense_slot_index_[i];
auto& slot_data = ret[slot_idx];
std::vector<std::string> data_in_slot_str;
string_split(ret[slot_idx], ',', &data_in_slot_str);
string_split(slot_data, ',', &data_in_slot_str);
std::vector<float> data_in_slot;
for (auto& data_str : data_in_slot_str) {
(*dense_datas)[i].push_back(std::stof(data_str));
......
......@@ -60,6 +60,35 @@ struct DataDesc {
const std::vector<std::string> sparse_slot_ids_;
};
// Writes a multi-line, human-readable dump of `data_desc` to `os` and returns
// `os` so that insertions can be chained.
//
// Scalar fields (batch size, file type, file format) are emitted one per line.
// Each list field is emitted as `{a,b,}`: every element, including the last
// one, is followed by a comma, and an empty list prints as `{}`.
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
  // Header followed by the scalar fields.
  os << "data_desc:\n"
     << "\tbatch_size -> " << data_desc.batch_size_ << "\n"
     << "\tfile_type -> " << data_desc.file_type_ << "\n"
     << "\tfile_format -> " << data_desc.file_format_ << "\n";

  // Input files.
  os << "\tfile_names -> {";
  for (const auto& name : data_desc.file_names_) {
    os << name << ",";
  }
  os << "}\n";

  // Indices of the dense slots.
  os << "\tdense_slot_index -> {";
  for (const auto& dense_idx : data_desc.dense_slot_index_) {
    os << dense_idx << ",";
  }
  os << "}\n";

  // Indices of the sparse slots.
  os << "\tsparse_slot_index_ -> {";
  for (const auto& sparse_idx : data_desc.sparse_slot_index_) {
    os << sparse_idx << ",";
  }
  os << "}\n";

  // Ids of the sparse slots.
  os << "\tsparse_slot_ids_ -> {";
  for (const auto& sparse_id : data_desc.sparse_slot_ids_) {
    os << sparse_id << ",";
  }
  os << "}\n";

  return os;
}
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
......@@ -89,7 +118,7 @@ class CTRReader : public framework::FileReader {
}
}
~CTRReader() {}
~CTRReader() { Shutdown(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success;
......@@ -106,7 +135,10 @@ class CTRReader : public framework::FileReader {
for (auto& read_thread : read_threads_) {
read_thread->join();
}
if (monitor_thread_) {
monitor_thread_->join();
}
read_threads_.clear();
monitor_thread_.reset(nullptr);
......
......@@ -27,15 +27,16 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"The ReadOp should be assigned with output.");
if (!ctx->IsRuntime() && ctx->Attrs().Get<bool>("infer_out")) {
std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
std::vector<std::string> out_names = ctx->Outputs("Out");
PADDLE_ENFORCE_EQ(
reader_dims.size(), out_names.size(),
"The reader's dim number doesn't match the output number.");
ctx->SetOutputsDim("Out", reader_dims);
if (!ctx->IsRuntime()) {
auto in_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Reader")[0]);
std::cout << in_desc->Proto()->SerializeAsString() << std::endl;
auto in_lod_levels = in_desc->GetLoDLevels();
auto out_var_ptrs = ctx->GetOutputVarPtrs("Out");
PADDLE_ENFORCE_EQ(in_lod_levels.size(), out_var_ptrs.size(),
......@@ -53,6 +54,8 @@ class ReadInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
if (infer_out) {
std::string reader_name = op_desc.Input("Reader")[0];
std::vector<std::string> out_names = op_desc.Output("Out");
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
......@@ -64,6 +67,7 @@ class ReadInferVarType : public framework::VarTypeInference {
out.SetDataType(dtypes[i]);
}
}
}
};
class ReadOp : public framework::OperatorBase {
......@@ -73,6 +77,7 @@ class ReadOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
VLOG(3) << "read op in";
framework::ReaderHolder* reader =
detail::Ref(scope.FindVar(Input("Reader")),
"Cannot find reader variable %s", Input("Reader"))
......@@ -87,7 +92,9 @@ class ReadOp : public framework::OperatorBase {
reader->ReadNext(&ins);
if (ins.empty()) {
VLOG(3) << "read empty data in";
if (Attr<bool>("throw_eof_exp")) {
VLOG(3) << "throw_eof_exp";
PADDLE_THROW_EOF();
} else {
ins.resize(out_arg_names.size());
......@@ -96,6 +103,7 @@ class ReadOp : public framework::OperatorBase {
tensor.mutable_data<float>(framework::make_ddim({0}), dev_place);
}
}
VLOG(3) << "read empty data out";
}
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < out_arg_names.size(); ++i) {
......@@ -120,6 +128,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users.")
.SetDefault(true);
AddAttr<bool>("infer_out", "").SetDefault(true);
AddComment(R"DOC(
Read Operator
......
......@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<bool>(
"use_data_config",
"Use the config of all datas like shape_concat/ranks/lod_levels")
.SetDefault(true);
Apply();
}
......@@ -75,7 +79,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
bool use_data_config = ctx->Attrs().Get<bool>("use_data_config");
if (use_data_config) {
const auto shape_concat =
ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);
......@@ -88,6 +95,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
......
......@@ -364,6 +364,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue =
......
......@@ -22,9 +22,12 @@ from . import op_frequence
from .op_frequence import *
from . import quantize
from .quantize import *
from . import reader
from .reader import *
__all__ = []
__all__ += decoder.__all__
__all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__
__all__ += quantize.__all__
__all__ += reader.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import ctr_reader
__all__ = ctr_reader.__all__
......@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \
default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name
__all__ = ['ctr_reader']
def monkey_patch_reader_methods(reader):
def __get_reader__():
......@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader):
def reset():
return __get_reader__().reset()
def start():
return __get_reader__().start()
reader.reset = reset
reader.start = start
reader.stop_gradient = True
reader.persistable = True
return reader
......@@ -44,7 +50,12 @@ def _copy_reader_var_(block, var):
return new_var
def ctr_reader(feed_data,
def ctr_reader(
feed_dict,
file_type, # gzip or plain
file_format, # csv or svm
dense_slot_indexs,
sparse_slot_indexs,
capacity,
thread_num,
batch_size,
......@@ -99,12 +110,22 @@ def ctr_reader(feed_data,
inputs={'blocking_queue': [queue_name]},
outputs={'Out': [reader_var]},
attrs={
'use_data_config': False,
'thread_num': thread_num,
'batch_size': batch_size,
'file_list': file_list,
'slots': slots,
'file_type': file_type,
'file_format': file_format,
'dense_slot_index': dense_slot_indexs,
'sparse_slot_index': sparse_slot_indexs,
'sparse_slots': slots,
'ranks': [],
'lod_levels': [],
'shape_concat': []
})
dtypes = [data.dtype for data in feed_dict]
reader_var.desc.set_dtypes(dtypes)
reader_var.persistable = True
main_prog_reader_var = _copy_reader_var_(
......@@ -118,6 +139,9 @@ def ctr_reader(feed_data,
main_blk = default_main_program().current_block()
main_blk.append_op(
type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data})
type='read',
inputs={'Reader': [reader]},
attrs={'infer_out': False},
outputs={'Out': feed_dict})
return reader
......@@ -107,6 +107,7 @@ packages=['paddle',
'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.reader',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册