Commit a59b7ac7 authored by J JiabinYang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/imperative

...@@ -213,6 +213,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.shuffle_channel ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0))
...@@ -359,6 +360,7 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b
paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.contrib.reader.ctr_reader.ctr_reader ArgSpec(args=['feed_dict', 'file_type', 'file_format', 'dense_slot_index', 'sparse_slot_index', 'capacity', 'thread_num', 'batch_size', 'file_list', 'slots', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.build_compressor ArgSpec(args=['place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'config'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None))
paddle.fluid.contrib.CompressPass.__init__ ArgSpec(args=['self', 'place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'program_exe'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None))
paddle.fluid.contrib.CompressPass.add_strategy ArgSpec(args=['self', 'strategy'], varargs=None, keywords=None, defaults=None)
......
...@@ -25,7 +25,10 @@ struct ExecutionStrategy {
size_t num_threads_{0};
bool use_cuda_{true};
bool allow_op_delay_{false};
- size_t num_iteration_per_drop_scope_{1};
+ // If we set this to 1, all variables will be deleted when a batch finishes, and
+ // this will lose 15%+ performance.
+ // Please be aware of this parameter.
+ size_t num_iteration_per_drop_scope_{100};
ExecutorType type_{kDefault};
bool dry_run_{false};
};
......
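The hunk above raises the default of num_iteration_per_drop_scope_ from 1 to 100, so local execution scopes are cleared only once every 100 iterations instead of after every batch. A minimal usage sketch, assuming the header path paddle/fluid/framework/details/execution_strategy.h and the paddle::framework::details namespace:

#include "paddle/fluid/framework/details/execution_strategy.h"

int main() {
  paddle::framework::details::ExecutionStrategy strategy;
  // With the new default (100), variables in local scopes are dropped only
  // every 100 iterations; setting this back to 1 restores per-batch cleanup
  // at the reported cost of 15%+ performance.
  strategy.num_iteration_per_drop_scope_ = 1;
  return 0;
}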
...@@ -117,8 +117,9 @@ bool VariableResponse::CopyLodTensorData(
tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
- << ", Buffer Size = " << length;
- PADDLE_ENFORCE_EQ(tensor->memory_size(), static_cast<unsigned int>(length));
+ << ", Buffer Size = " << length << ", dims:" << dims
+ << ", numel:" << tensor->numel();
+ PADDLE_ENFORCE_GE(tensor->memory_size(), static_cast<unsigned int>(length));
return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
}
......
...@@ -21,5 +21,5 @@ endif()
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
if(NOT WIN32)
- cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper device_tracer)
+ cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper device_tracer tensor)
endif()
...@@ -18,6 +18,7 @@
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
+ #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
...@@ -155,14 +156,22 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
LOG(INFO) << loginfos.str();
}

+ using Tensor = paddle::framework::Tensor;
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
- std::vector<T> x(d), y(d), z(d);
- RandomVec<T>(d, x.data());
- RandomVec<T>(d, y.data());
- BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(),
- z.data(), d);
+ Tensor x, y, z;
+ x.Resize({d});
+ y.Resize({d});
+ z.Resize({d});
+ T* x_data = x.mutable_data<T>(PlaceType());
+ T* y_data = y.mutable_data<T>(PlaceType());
+ T* z_data = z.mutable_data<T>(PlaceType());
+ RandomVec<T>(d, x_data);
+ RandomVec<T>(d, y_data);
+ BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
+ y.data<T>(), z_data, d);
}
}
...@@ -170,9 +179,13 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
- std::vector<T> x(d), y(d);
- RandomVec<T>(d, x.data());
- BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(),
+ Tensor x, y;
+ x.Resize({d});
+ y.Resize({d});
+ T* x_data = x.mutable_data<T>(PlaceType());
+ T* y_data = y.mutable_data<T>(PlaceType());
+ RandomVec<T>(d, x_data);
+ BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
d);
}
}
...@@ -180,9 +193,13 @@ void BenchAXYNKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
- std::vector<T> x(d), y(d);
- RandomVec<T>(d, x.data());
- BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d);
+ Tensor x, y;
+ x.Resize({d});
+ y.Resize({d});
+ T* x_data = x.mutable_data<T>(PlaceType());
+ T* y_data = y.mutable_data<T>(PlaceType());
+ RandomVec<T>(d, x_data);
+ BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d);
}
}
...@@ -192,16 +209,23 @@ void BenchLSTMKernel() {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
use_peephole);
- std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
- RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
- RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
- RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
- const T* ct_1_data = ct_1.data();
- const T* wp_data = wp.data();
- T* x_data = x.data();
- T* checked_data = checked.data();
- T* ct_data = ct.data();
- T* ht_data = ht.data();
+ Tensor x, ct_1, ct, ht, wp, checked;
+ x.Resize({4 * d});
+ ct_1.Resize({d});
+ ct.Resize({d});
+ ht.Resize({d});
+ wp.Resize({3 * d});
+ checked.Resize({2 * d});
+ auto place = PlaceType();
+ RandomVec<T>(x.numel(), x.mutable_data<T>(place), -2.f, 2.f);
+ RandomVec<T>(wp.numel(), wp.mutable_data<T>(place), -2.f, 2.f);
+ RandomVec<T>(ct_1.numel(), ct_1.mutable_data<T>(place), -2.f, 2.f);
+ const T* ct_1_data = ct_1.data<T>();
+ const T* wp_data = wp.data<T>();
+ T* x_data = x.mutable_data<T>(place);
+ T* checked_data = checked.mutable_data<T>(place);
+ T* ct_data = ct.mutable_data<T>(place);
+ T* ht_data = ht.mutable_data<T>(place);
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
...@@ -220,12 +244,16 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
- std::vector<T> x(3 * d), ht_1(d), ht(d);
- RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
- RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
- const T* ht_1_data = ht_1.data();
- T* x_data = x.data();
- T* ht_data = ht.data();
+ auto place = PlaceType();
+ Tensor x, ht_1, ht;
+ x.Resize({3 * d});
+ ht_1.Resize({d});
+ ht.Resize({d});
+ RandomVec<T>(3 * d, x.mutable_data<T>(place), -2.f, 2.f);
+ RandomVec<T>(d, ht_1.mutable_data<T>(place), -2.f, 2.f);
+ const T* ht_1_data = ht_1.data<T>();
+ T* x_data = x.mutable_data<T>(place);
+ T* ht_data = ht.mutable_data<T>(place);
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
...@@ -243,10 +271,12 @@ void BenchSeqPoolKernel() {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
- std::vector<T> x(h * w), y(w);
- RandomVec<T>(h * w, x.data(), -2.f, 2.f);
- const T* x_data = x.data();
- T* y_data = y.data();
+ Tensor x, y;
+ x.Resize({h * w});
+ y.Resize({w});
+ RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
+ const T* x_data = x.data<T>();
+ T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
}
...@@ -259,12 +289,15 @@ void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
for (int k : TestSizes()) {
- std::vector<T> a(m * k), b(k * n), c(m * n);
- RandomVec<T>(m * k, a.data(), -2.f, 2.f);
- RandomVec<T>(k * n, b.data(), -2.f, 2.f);
- const T* a_data = a.data();
- const T* b_data = b.data();
- T* c_data = c.data();
+ Tensor a, b, c;
+ a.Resize({m * k});
+ b.Resize({k * n});
+ c.Resize({m * n});
+ RandomVec<T>(m * k, a.mutable_data<T>(PlaceType()), -2.f, 2.f);
+ RandomVec<T>(k * n, b.mutable_data<T>(PlaceType()), -2.f, 2.f);
+ const T* a_data = a.data<T>();
+ const T* b_data = b.data<T>();
+ T* c_data = c.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(k, a_data, b_data,
c_data, m, n, k);
}
......
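The benchmark hunks above replace std::vector<T> buffers with framework::Tensor so jit_kernel_benchmark allocates memory the same way real operators do (hence the new tensor dependency in the CMake hunk). A minimal sketch of that pattern, assuming the Paddle source tree is on the include path:

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

template <typename T, typename PlaceType>
T* PrepareBuffer(paddle::framework::Tensor* t, int d) {
  t->Resize({d});                          // set the shape first
  return t->mutable_data<T>(PlaceType());  // then allocate on the target place
}

// Usage sketch: paddle::framework::Tensor x;
//   float* x_data = PrepareBuffer<float, paddle::platform::CPUPlace>(&x, 256);
//   // fill x_data (e.g. with RandomVec<float>) and read it back via x.data<float>().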
...@@ -41,13 +41,19 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
- int thread_num = Attr<int>("thread_num");
- std::vector<std::string> slots = Attr<std::vector<std::string>>("slots");
- int batch_size = Attr<int>("batch_size");
- std::vector<std::string> file_list =
- Attr<std::vector<std::string>>("file_list");
- out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), batch_size,
- thread_num, slots, file_list));
+ auto thread_num = Attr<int>("thread_num");
+ auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots");
+ auto dense_slot_index = Attr<std::vector<int>>("dense_slot_index");
+ auto sparse_slot_index = Attr<std::vector<int>>("sparse_slot_index");
+ auto batch_size = Attr<int>("batch_size");
+ auto file_type = Attr<std::string>("file_type");
+ auto file_format = Attr<std::string>("file_format");
+ auto file_list = Attr<std::vector<std::string>>("file_list");
+ DataDesc data_desc(batch_size, file_list, file_type, file_format,
+ dense_slot_index, sparse_slot_index, sparse_slots);
+ VLOG(1) << data_desc;
+ out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num,
+ data_desc));
}
};
...@@ -58,10 +64,22 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("thread_num", "the thread num to read data");
AddAttr<int>("batch_size", "the batch size of read data");
+ AddAttr<std::string>("file_type", "plain or gzip").SetDefault("plain");
+ AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv");
AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read");
- AddAttr<std::vector<std::string>>(
- "slots", "the slots that should be extract from file");
+ AddAttr<std::vector<int>>(
+ "dense_slot_index",
+ "the dense slots id that should be extract from file")
+ .SetDefault({});
+ AddAttr<std::vector<int>>(
+ "sparse_slot_index",
+ "the sparse slots id that should be extract from file")
+ .SetDefault({});
+ AddAttr<std::vector<std::string>>("sparse_slots",
+ "the sparse slots id that should be "
+ "extract from file, used when file "
+ "format is svm");
AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp.
......
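The new attributes above (file_type, file_format, dense_slot_index, sparse_slot_index, sparse_slots) are bundled into a DataDesc before the reader is reset. A hypothetical set of values for a plain csv input, expressed with the DataDesc constructor defined in ctr_reader.h below; the header path and file name are illustrative assumptions, not taken from this diff:

#include "paddle/fluid/operators/reader/ctr_reader.h"  // assumed header path

void ExampleDataDesc() {
  paddle::operators::reader::DataDesc data_desc(
      /*batch_size=*/32, /*file_names=*/{"part-000.csv"},  // hypothetical file
      /*file_type=*/"plain", /*file_format=*/"csv",
      /*dense_slot_index=*/{1}, /*sparse_slot_index=*/{2},
      /*sparse_slot_ids=*/{});
}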
...@@ -73,6 +73,9 @@ static inline void parse_line(
}
}

+ // label slot1:fea_sign slot2:fea_sign slot1:fea_sign
+ static inline void parse_svm_line(const std::string& line) {}
class Reader {
public:
virtual ~Reader() {}
...@@ -95,11 +98,27 @@ class GzipReader : public Reader {
igzstream gzstream_;
};

- class MultiGzipReader : public Reader {
+ class PlainFileReader : public Reader {
public:
- explicit MultiGzipReader(const std::vector<std::string>& file_list) {
+ explicit PlainFileReader(const std::string& file_name)
+ : stream_(file_name.c_str()) {}
+ ~PlainFileReader() {}
+ bool HasNext() override { return stream_.peek() != EOF; }
+ void NextLine(std::string* line) override { std::getline(stream_, *line); }
+ private:
+ std::ifstream stream_;
+ };
+ template <typename SingleFileReader>
+ class MultiFileReader : public Reader {
+ public:
+ explicit MultiFileReader(const std::vector<std::string>& file_list) {
for (auto& file : file_list) {
- readers_.emplace_back(std::make_shared<GzipReader>(file));
+ readers_.emplace_back(std::make_shared<SingleFileReader>(file));
}
}
...@@ -119,46 +138,35 @@ class MultiGzipReader : public Reader {
}
private:
- std::vector<std::shared_ptr<GzipReader>> readers_;
+ std::vector<std::shared_ptr<SingleFileReader>> readers_;
size_t current_reader_index_ = 0;
};

void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
- VLOG(30) << "monitor thread in";
+ VLOG(3) << "monitor thread in";
bool reader_thread_is_running = true;
while (reader_thread_is_running) {
- VLOG(30) << "reader_thread_is_running";
+ VLOG(3) << "reader_thread_is_running";
reader_thread_is_running = false;
for (size_t i = 0; i < (*thread_status).size(); ++i) {
if ((*thread_status)[i] == Running) {
- VLOG(30) << "reader is running!";
+ VLOG(3) << "reader is running!";
reader_thread_is_running = true;
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
- VLOG(30) << "all reader thread is stopped, push empty data into queue";
- queue->Push({});
- VLOG(30) << "monitor thread exited";
+ VLOG(3) << "all reader thread is stopped, close the queue";
+ queue->Close();
+ VLOG(3) << "monitor thread exited";
}

- void ReadThread(const std::vector<std::string>& file_list,
- const std::vector<std::string>& slots, int batch_size,
- int thread_id, std::vector<ReaderThreadStatus>* thread_status,
- std::shared_ptr<LoDTensorBlockingQueue> queue) {
- VLOG(30) << "[" << thread_id << "]"
- << " reader thread start! thread_id = " << thread_id;
- for (auto& file : file_list) {
- VLOG(30) << "[" << thread_id << "]"
- << " file " << file;
- }
- (*thread_status)[thread_id] = Running;
- VLOG(30) << "set status to running";
+ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
+ std::shared_ptr<LoDTensorBlockingQueue> queue) {
std::unordered_map<std::string, size_t> slot_to_index;
- for (size_t i = 0; i < slots.size(); ++i) {
- slot_to_index[slots[i]] = i;
+ for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) {
+ slot_to_index[data_desc.sparse_slot_ids_[i]] = i;
}
std::string line;
...@@ -166,21 +174,17 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label;
- MultiGzipReader reader(file_list);
- VLOG(30) << "reader inited";
- while (reader.HasNext()) {
+ while (reader->HasNext()) {
batch_data.clear();
- batch_data.reserve(batch_size);
+ batch_data.reserve(data_desc.batch_size_);
batch_label.clear();
- batch_label.reserve(batch_size);
+ batch_label.reserve(data_desc.batch_size_);
// read batch_size data
- for (int i = 0; i < batch_size; ++i) {
- if (reader.HasNext()) {
- reader.NextLine(&line);
+ for (int i = 0; i < data_desc.batch_size_; ++i) {
+ if (reader->HasNext()) {
+ reader->NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
int64_t label;
parse_line(line, slot_to_index, &label, &slot_to_data);
...@@ -193,8 +197,8 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<framework::LoDTensor> lod_datas;
- // first insert tensor for each slots
- for (auto& slot : slots) {
+ // first insert tensor for each sparse_slots
+ for (auto& slot : data_desc.sparse_slot_ids_) {
std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign;
...@@ -226,11 +230,167 @@ void ReadThread(const std::vector<std::string>& file_list,
lod_datas.push_back(label_tensor);
queue->Push(lod_datas);
- VLOG(40) << "push one data, queue_size=" << queue->Size();
+ VLOG(4) << "push one data, queue_size=" << queue->Size();
}
}
// label dense_fea,dense_fea sparse_fea,sparse_fea
static inline void parse_csv_line(
const std::string& line, const DataDesc& data_desc, int64_t* label,
std::vector<std::vector<float>>* dense_datas,
std::vector<std::vector<int64_t>>* sparse_datas) {
std::vector<std::string> ret;
string_split(line, ' ', &ret);
*label = std::stol(ret[0]);
dense_datas->resize(data_desc.dense_slot_index_.size());
for (size_t i = 0; i < data_desc.dense_slot_index_.size(); ++i) {
int slot_idx = data_desc.dense_slot_index_[i];
auto& slot_data = ret[slot_idx];
std::vector<std::string> data_in_slot_str;
string_split(slot_data, ',', &data_in_slot_str);
std::vector<float> data_in_slot;
for (auto& data_str : data_in_slot_str) {
(*dense_datas)[i].push_back(std::stof(data_str));
}
}
sparse_datas->resize(data_desc.sparse_slot_index_.size());
for (size_t i = 0; i < data_desc.sparse_slot_index_.size(); ++i) {
int slot_idx = data_desc.sparse_slot_index_[i];
auto& slot_data = ret[slot_idx];
std::vector<std::string> data_in_slot_str;
string_split(slot_data, ',', &data_in_slot_str);
std::vector<int64_t> data_in_slot;
for (auto& data_str : data_in_slot_str) {
auto id = std::stol(data_str);
(*sparse_datas)[i].push_back(id);
}
}
}
void ReadCsvData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
std::string line;
while (reader->HasNext()) {
std::vector<int64_t> batch_label;
batch_label.reserve(data_desc.batch_size_);
std::vector<std::vector<std::vector<float>>> batch_dense_data;
batch_dense_data.reserve(data_desc.batch_size_);
std::vector<std::vector<std::vector<int64_t>>> batch_sparse_data;
batch_sparse_data.reserve(data_desc.batch_size_);
// read batch_size data
for (int i = 0; i < data_desc.batch_size_; ++i) {
if (reader->HasNext()) {
reader->NextLine(&line);
int64_t label;
std::vector<std::vector<float>> dense_datas;
std::vector<std::vector<int64_t>> sparse_datas;
parse_csv_line(line, data_desc, &label, &dense_datas, &sparse_datas);
batch_label.push_back(label);
if (!batch_dense_data.empty()) {
PADDLE_ENFORCE_EQ(batch_dense_data[0].size(), dense_datas.size(),
"dense data should have the same shape");
}
batch_dense_data.push_back(dense_datas);
batch_sparse_data.push_back(sparse_datas);
} else {
break;
}
}
// the order of output data is label, dense_datas, sparse_datas
std::vector<framework::LoDTensor> lod_datas;
// insert label tensor
framework::LoDTensor label_tensor;
auto* label_tensor_data = label_tensor.mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(batch_label.size()), 1}),
platform::CPUPlace());
memcpy(label_tensor_data, batch_label.data(),
batch_label.size() * sizeof(int64_t));
lod_datas.push_back(label_tensor);
// insert tensor for each dense_slots
for (size_t i = 0; i < data_desc.dense_slot_index_.size(); ++i) {
framework::LoDTensor lod_tensor;
size_t width = batch_dense_data[0][i].size();
auto* tensor_data = lod_tensor.mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(batch_dense_data.size()), // batch_size
static_cast<int64_t>(width)}),
platform::CPUPlace());
for (size_t j = 0; j < batch_dense_data.size(); ++j) {
auto& dense_data_row = batch_dense_data[j][i];
memcpy(tensor_data + j * width, dense_data_row.data(),
width * sizeof(float));
}
lod_datas.push_back(lod_tensor);
}
// insert tensor for each sparse_slots
for (size_t i = 0; i < data_desc.sparse_slot_index_.size(); ++i) {
std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign;
for (size_t row_idx = 0; row_idx < batch_sparse_data.size(); ++row_idx) {
auto& sparse_ids = batch_sparse_data[row_idx][i];
lod_data.push_back(lod_data.back() + sparse_ids.size());
batch_feasign.insert(batch_feasign.end(), sparse_ids.begin(),
sparse_ids.end());
}
framework::LoDTensor lod_tensor;
framework::LoD lod{lod_data};
lod_tensor.set_lod(lod);
int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(batch_feasign.size()), 1}),
platform::CPUPlace());
memcpy(tensor_data, batch_feasign.data(),
batch_feasign.size() * sizeof(int64_t));
lod_datas.push_back(lod_tensor);
}
queue->Push(lod_datas);
VLOG(4) << "push one data, queue_size=" << queue->Size();
}
}
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
std::shared_ptr<Reader> reader;
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
if (data_desc.file_format_ == "svm") {
ReadSvmData(data_desc, reader, queue);
} else if (data_desc.file_format_ == "csv") {
ReadCsvData(data_desc, reader, queue);
}
(*thread_status)[thread_id] = Stopped;
- VLOG(30) << "set status to stopped, thread " << thread_id << " exited";
+ VLOG(3) << "set status to stopped, thread " << thread_id << " exited";
}
} // namespace reader
......
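For reference, the csv format consumed by parse_csv_line above is "label dense_fea,dense_fea sparse_fea,sparse_fea": space-separated columns, comma-separated values inside a column. A standalone sketch of that parsing with plain STL calls (std::getline in place of Paddle's string_split helper); the sample line is taken from the unit test further below:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const std::string line = "1 1.1,1.1 1,1,1,1";  // one row from the test data
  std::istringstream fields(line);
  std::string column;
  std::vector<std::string> columns;
  while (std::getline(fields, column, ' ')) columns.push_back(column);

  int64_t label = std::stol(columns[0]);
  std::vector<float> dense;
  std::istringstream dense_ss(columns[1]);
  while (std::getline(dense_ss, column, ',')) dense.push_back(std::stof(column));
  std::vector<int64_t> sparse;
  std::istringstream sparse_ss(columns[2]);
  while (std::getline(sparse_ss, column, ',')) sparse.push_back(std::stol(column));

  std::cout << "label=" << label << " dense=" << dense.size()
            << " sparse=" << sparse.size() << "\n";
  return 0;
}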
...@@ -36,9 +36,63 @@ namespace reader {
enum ReaderThreadStatus { Running, Stopped };
struct DataDesc {
DataDesc(int batch_size, const std::vector<std::string>& file_names,
const std::string& file_type, const std::string& file_format,
const std::vector<int>& dense_slot_index,
const std::vector<int>& sparse_slot_index,
const std::vector<std::string>& sparse_slot_ids)
: batch_size_(batch_size),
file_names_(file_names),
file_type_(file_type),
file_format_(file_format),
dense_slot_index_(dense_slot_index),
sparse_slot_index_(sparse_slot_index),
sparse_slot_ids_(sparse_slot_ids) {}
const int batch_size_;
const std::vector<std::string> file_names_;
const std::string file_type_; // gzip or plain
const std::string file_format_; // csv or svm
// used for csv data format
const std::vector<int> dense_slot_index_;
const std::vector<int> sparse_slot_index_;
// used for svm data format
const std::vector<std::string> sparse_slot_ids_;
};
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
os << "data_desc:\n";
os << "\tbatch_size -> " << data_desc.batch_size_ << "\n";
os << "\tfile_type -> " << data_desc.file_type_ << "\n";
os << "\tfile_format -> " << data_desc.file_format_ << "\n";
os << "\tfile_names -> {";
for (auto& file_name : data_desc.file_names_) {
os << file_name << ",";
}
os << "}\n";
os << "\tdense_slot_index -> {";
for (auto& slot : data_desc.dense_slot_index_) {
os << slot << ",";
}
os << "}\n";
os << "\tsparse_slot_index_ -> {";
for (auto& slot : data_desc.sparse_slot_index_) {
os << slot << ",";
}
os << "}\n";
os << "\tsparse_slot_ids_ -> {";
for (auto& slot : data_desc.sparse_slot_ids_) {
os << slot << ",";
}
os << "}\n";
return os;
}
void ReadThread(const std::vector<std::string>& file_list,
- const std::vector<std::string>& slots, int batch_size,
- int thread_id, std::vector<ReaderThreadStatus>* thread_status,
+ const DataDesc& data_desc, int thread_id,
+ std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue);
// monitor all running thread, if they are all stopped,
...@@ -48,15 +102,15 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class CTRReader : public framework::FileReader {
public:
- explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
- int batch_size, size_t thread_num,
- const std::vector<std::string>& slots,
- const std::vector<std::string>& file_list)
- : batch_size_(batch_size), slots_(slots), file_list_(file_list) {
+ CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
+ int thread_num, const DataDesc& data_desc)
+ : data_desc_(data_desc) {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
- PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
- thread_num_ = std::min<size_t>(file_list_.size(), thread_num);
+ PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0,
+ "file list should not be empty");
+ thread_num_ = std::min<size_t>(data_desc_.file_names_.size(), thread_num);
queue_ = queue;
SplitFiles();
for (size_t i = 0; i < thread_num_; ++i) {
...@@ -64,7 +118,7 @@ class CTRReader : public framework::FileReader {
}
}
- ~CTRReader() {}
+ ~CTRReader() { Shutdown(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success;
...@@ -81,7 +135,10 @@ class CTRReader : public framework::FileReader {
for (auto& read_thread : read_threads_) {
read_thread->join();
}
- monitor_thread_->join();
+ if (monitor_thread_) {
+ monitor_thread_->join();
+ }
read_threads_.clear();
monitor_thread_.reset(nullptr);
...@@ -95,9 +152,9 @@ class CTRReader : public framework::FileReader {
queue_->ReOpen();
VLOG(3) << "reopen success";
VLOG(3) << "thread_num " << thread_num_;
- for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
+ for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(std::bind(
- &ReadThread, file_groups_[thread_id], slots_, batch_size_,
+ &ReadThread, file_groups_[thread_id], data_desc_,
static_cast<int>(thread_id), &read_thread_status_, queue_)));
}
monitor_thread_.reset(new std::thread(
...@@ -108,8 +165,8 @@ class CTRReader : public framework::FileReader {
private:
void SplitFiles() {
file_groups_.resize(thread_num_);
- for (size_t i = 0; i < file_list_.size(); ++i) {
- auto& file_name = file_list_[i];
+ for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
+ auto& file_name = data_desc_.file_names_[i];
std::ifstream f(file_name.c_str());
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
file_groups_[i % thread_num_].push_back(file_name);
...@@ -118,9 +175,7 @@ class CTRReader : public framework::FileReader {
private:
size_t thread_num_;
- const int batch_size_;
- const std::vector<std::string> slots_;
- const std::vector<std::string> file_list_;
+ const DataDesc data_desc_;
std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_;
std::unique_ptr<std::thread> monitor_thread_;
......
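The sparse-slot tensors assembled in ReadSvmData and ReadCsvData use a LoD of cumulative row offsets over one flattened id vector. A standalone sketch of that bookkeeping with plain vectors:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Three rows of sparse ids with different lengths, as ReadCsvData collects them.
  std::vector<std::vector<int64_t>> rows = {{1, 1}, {2}, {3, 3, 3}};
  std::vector<size_t> lod{0};  // cumulative offsets, starting at 0
  std::vector<int64_t> flat;   // flattened feature signs
  for (const auto& row : rows) {
    lod.push_back(lod.back() + row.size());
    flat.insert(flat.end(), row.begin(), row.end());
  }
  // lod = {0, 2, 3, 6}: row i occupies flat[lod[i], lod[i+1]).
  for (size_t i = 0; i + 1 < lod.size(); ++i)
    std::cout << "row " << i << " has " << lod[i + 1] - lod[i] << " ids\n";
  return 0;
}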
...@@ -36,6 +36,7 @@ using paddle::framework::LoD;
using paddle::framework::DDim;
using paddle::platform::CPUPlace;
using paddle::framework::make_ddim;
+ using paddle::operators::reader::DataDesc;
static void generatedata(const std::vector<std::string>& data,
const std::string& file_name) {
...@@ -126,30 +127,103 @@ TEST(CTR_READER, read_data) {
LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64;
- queue_holder.InitOnce(capacity, {}, false);
+ queue_holder.InitOnce(capacity, false);
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
int batch_size = 3;
int thread_num = 1;
- std::vector<std::string> slots = {"6002", "6003"};
+ std::vector<std::string> sparse_slots = {"6002", "6003"};
std::vector<std::string> file_list;
for (int i = 0; i < thread_num; ++i) {
file_list.push_back(gz_file_name);
}
- CTRReader reader(queue, batch_size, thread_num, slots, file_list);
+ DataDesc data_desc(batch_size, file_list, "gzip", "svm", {}, {},
+ sparse_slots);
+ CTRReader reader(queue, thread_num, data_desc);
reader.Start();
size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
- check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
- data_slot_6003, batch_num, batch_size, queue, &reader);
+ check_all_data(ctr_data, sparse_slots, label_dims, label_value,
+ data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
+ &reader);
reader.Shutdown();
reader.Start();
- check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
- data_slot_6003, batch_num, batch_size, queue, &reader);
+ check_all_data(ctr_data, sparse_slots, label_dims, label_value,
+ data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
+ &reader);
reader.Shutdown();
}
static void GenereteCsvData(const std::string& file_name,
const std::vector<std::string>& data) {
std::ofstream out(file_name.c_str());
PADDLE_ENFORCE(out.good(), "open file %s failed!", file_name);
for (auto& c : data) {
out << c;
}
out.close();
PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name);
}
static void CheckReadCsvOut(const std::vector<LoDTensor>& out) {
ASSERT_EQ(out.size(), 3);
ASSERT_EQ(out[0].dims()[1], 1);
ASSERT_EQ(out[1].dims()[1], 2);
ASSERT_EQ(out[2].dims()[1], 1);
for (size_t i = 0; i < out[0].numel(); ++i) {
int64_t label = out[0].data<int64_t>()[i];
auto& dense_dim = out[1].dims();
for (size_t j = 0; j < dense_dim[1]; ++j) {
ASSERT_EQ(out[1].data<float>()[i * dense_dim[1] + j],
static_cast<float>(label + 0.1));
}
auto& sparse_lod = out[2].lod();
for (size_t j = sparse_lod[0][i]; j < sparse_lod[0][i + 1]; ++j) {
ASSERT_EQ(out[2].data<int64_t>()[j], label);
}
}
}
TEST(CTR_READER, read_csv_data) {
std::string file_name = "test_ctr_reader_data.csv";
const std::vector<std::string> csv_data = {
"0 0.1,0.1 0,0,0,0\n", "1 1.1,1.1 1,1,1,1\n", "2 2.1,2.1 2,2,2,2\n",
"3 3.1,3.1 3,3,3,3\n",
};
GenereteCsvData(file_name, csv_data);
LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64;
queue_holder.InitOnce(capacity, false);
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
int batch_size = 3;
int thread_num = 1;
std::vector<std::string> file_list;
for (int i = 0; i < thread_num; ++i) {
file_list.push_back(file_name);
}
DataDesc data_desc(batch_size, file_list, "plain", "csv", {1}, {2}, {});
CTRReader reader(queue, thread_num, data_desc);
for (size_t i = 0; i < 2; ++i) {
reader.Start();
std::vector<LoDTensor> out;
while (true) {
reader.ReadNext(&out);
if (out.empty()) {
break;
}
CheckReadCsvOut(out);
}
reader.Shutdown();
}
}
...@@ -32,10 +32,8 @@ class LoDTensorBlockingQueue {
friend class LoDTensorBlockingQueueHolder;
private:
- LoDTensorBlockingQueue(size_t capacity,
- const std::vector<framework::DDim>& dims,
- bool speed_test_mode = false)
- : queue_(capacity, speed_test_mode), dims_(dims) {}
+ explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
+ : queue_(capacity, speed_test_mode) {}
public:
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
...@@ -65,17 +63,15 @@ class LoDTensorBlockingQueue {
private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_;
- std::vector<framework::DDim> dims_;
};

class LoDTensorBlockingQueueHolder {
public:
- void InitOnce(size_t capacity, const std::vector<framework::DDim>& dims,
- bool speed_test_mode = false) {
+ void InitOnce(size_t capacity, bool speed_test_mode = false) {
PADDLE_ENFORCE(
queue_ == nullptr,
"LoDTensorBlockingQueueHolder::InitOnce() can only be called once");
- queue_.reset(new LoDTensorBlockingQueue(capacity, dims, speed_test_mode));
+ queue_.reset(new LoDTensorBlockingQueue(capacity, speed_test_mode));
}
inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const {
......
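MonitorThread now calls queue->Close() instead of pushing an empty batch, so consumers learn about end-of-data from the queue state rather than from a sentinel element. A standalone sketch of a close-aware blocking queue with that behaviour (the names are illustrative, not Paddle's BlockingQueue API):

#include <condition_variable>
#include <deque>
#include <mutex>

template <typename T>
class ClosableQueue {
 public:
  void Push(T v) {
    std::lock_guard<std::mutex> lock(mu_);
    buf_.push_back(std::move(v));
    cv_.notify_one();
  }
  void Close() {
    std::lock_guard<std::mutex> lock(mu_);
    closed_ = true;
    cv_.notify_all();
  }
  // Returns false once the queue is closed and drained; otherwise pops one element.
  bool Receive(T* out) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return closed_ || !buf_.empty(); });
    if (buf_.empty()) return false;
    *out = std::move(buf_.front());
    buf_.pop_front();
    return true;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> buf_;
  bool closed_ = false;
};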
...@@ -27,13 +27,13 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"The ReadOp should be assigned with output.");
- std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
- std::vector<std::string> out_names = ctx->Outputs("Out");
- PADDLE_ENFORCE_EQ(
- reader_dims.size(), out_names.size(),
- "The reader's dim number doesn't match the output number.");
- ctx->SetOutputsDim("Out", reader_dims);
- if (!ctx->IsRuntime()) {
+ if (!ctx->IsRuntime() && ctx->Attrs().Get<bool>("infer_out")) {
+ std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
+ std::vector<std::string> out_names = ctx->Outputs("Out");
+ PADDLE_ENFORCE_EQ(
+ reader_dims.size(), out_names.size(),
+ "The reader's dim number doesn't match the output number.");
+ ctx->SetOutputsDim("Out", reader_dims);
auto in_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Reader")[0]);
auto in_lod_levels = in_desc->GetLoDLevels();
...@@ -53,15 +53,18 @@ class ReadInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
- std::string reader_name = op_desc.Input("Reader")[0];
- std::vector<std::string> out_names = op_desc.Output("Out");
- framework::VarDesc* reader = block->FindVarRecursive(reader_name);
- auto dtypes = reader->GetDataTypes();
- PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
- for (size_t i = 0; i < dtypes.size(); ++i) {
- framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
- out.SetType(framework::proto::VarType::LOD_TENSOR);
- out.SetDataType(dtypes[i]);
+ bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
+ if (infer_out) {
+ std::string reader_name = op_desc.Input("Reader")[0];
+ std::vector<std::string> out_names = op_desc.Output("Out");
+ framework::VarDesc* reader = block->FindVarRecursive(reader_name);
+ auto dtypes = reader->GetDataTypes();
+ PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
+ for (size_t i = 0; i < dtypes.size(); ++i) {
+ framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
+ out.SetType(framework::proto::VarType::LOD_TENSOR);
+ out.SetDataType(dtypes[i]);
+ }
}
}
};
...@@ -73,6 +76,7 @@ class ReadOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
+ VLOG(3) << "read op in";
framework::ReaderHolder* reader =
detail::Ref(scope.FindVar(Input("Reader")),
"Cannot find reader variable %s", Input("Reader"))
...@@ -87,7 +91,9 @@ class ReadOp : public framework::OperatorBase {
reader->ReadNext(&ins);
if (ins.empty()) {
+ VLOG(3) << "read empty data in";
if (Attr<bool>("throw_eof_exp")) {
+ VLOG(3) << "throw_eof_exp";
PADDLE_THROW_EOF();
} else {
ins.resize(out_arg_names.size());
...@@ -96,6 +102,7 @@ class ReadOp : public framework::OperatorBase {
tensor.mutable_data<float>(framework::make_ddim({0}), dev_place);
}
}
+ VLOG(3) << "read empty data out";
}
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < out_arg_names.size(); ++i) {
...@@ -120,6 +127,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users.")
.SetDefault(true);
+ AddAttr<bool>("infer_out", "").SetDefault(true);
AddComment(R"DOC(
Read Operator
......
...@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
+ AddAttr<bool>(
+ "use_data_config",
+ "Use the config of all datas like shape_concat/ranks/lod_levels")
+ .SetDefault(true);
Apply();
}
...@@ -75,19 +79,23 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
- const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
- const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
- std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
- ctx->SetReaderDims("Out", shapes);
- const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
- PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
- "The number of 'lod_levels'(%d) doesn't match the number "
- "of 'shapes'(%d).",
- lod_levels.size(), shapes.size());
- framework::VarDesc* reader =
- boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
- reader->SetLoDLevels(lod_levels);
+ bool use_data_config = ctx->Attrs().Get<bool>("use_data_config");
+ if (use_data_config) {
+ const auto shape_concat =
+ ctx->Attrs().Get<std::vector<int>>("shape_concat");
+ const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
+ std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
+ ctx->SetReaderDims("Out", shapes);
+ const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
+ PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
+ "The number of 'lod_levels'(%d) doesn't match the number "
+ "of 'shapes'(%d).",
+ lod_levels.size(), shapes.size());
+ framework::VarDesc* reader =
+ boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+ reader->SetLoDLevels(lod_levels);
+ }
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h"
namespace paddle {
namespace operators {
class ShuffleChannelOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ShuffleChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ShuffleChannelOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
ctx->SetOutputDim("Out", input_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of ShuffleChannelOp, the layout is NCHW.");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"ShuffleChannelOp. The layout is NCHW.");
AddAttr<int>("group", "the number of groups.")
.SetDefault(1)
.AddCustomChecker([](const int& group) {
PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
});
AddComment(R"DOC(
Shuffle Channel operator
This operator shuffles the channels of input x.
It divides the input channels in each group into several subgroups,
and obtains a new order by selecting elements from every subgroup one by one.
The shuffle channel operation makes it possible to build more powerful structures
with multiple group convolutional layers.
Please refer to the following paper for more information:
https://arxiv.org/pdf/1707.01083.pdf
)DOC");
}
};
class ShuffleChannelGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) should not be null");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
ops::ShuffleChannelOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
REGISTER_OP_CPU_KERNEL(
shuffle_channel,
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
shuffle_channel_grad,
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
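A standalone illustration of the shuffle performed by the kernels above and below: the channel axis is viewed as a group_row x group_column grid and transposed, so input channel i*group_column + j ends up at output channel j*group_row + i.

#include <cstdio>
#include <vector>

int main() {
  const int channels = 6, group_row = 2, group_column = channels / group_row;
  std::vector<int> in = {0, 1, 2, 3, 4, 5};  // one pixel, six channels
  std::vector<int> out(channels);
  for (int i = 0; i < group_row; ++i)
    for (int j = 0; j < group_column; ++j)
      out[j * group_row + i] = in[i * group_column + j];
  for (int c : out) std::printf("%d ", c);  // prints: 0 3 1 4 2 5
  std::printf("\n");
  return 0;
}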
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
T* output, const T* input, int group_row,
int group_column, int len) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t ii = index; ii < nthreads; ii += offset) {
// use the strided index ii so that the grid-stride loop actually advances
const int n = ii / group_row / group_column / len;
const int i = (ii / group_column / len) % group_row;
const int j = ii / len % group_column;
const int k = ii - (n * feature_map_size + (i * group_column + j) * len);
T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
p_o[k] = input[ii];
}
}
template <typename DeviceContext, typename T>
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
auto weight = input_dims[3];
auto feature_map_size = channel * height * weight;
auto sp_sz = height * weight;
int group_row = group;
int group_column = channel / group_row;
// count is the product of NCHW same as numel()
int count = num * group_column * group_row * sp_sz;
int blocks = NumBlocks(output->numel());
int threads = kNumCUDAThreads;
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
ShuffleChannel<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
count, feature_map_size, output_data, input_data, group_row,
group_column, sp_sz);
}
};
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
auto weight = input_dims[3];
auto feature_map_size = channel * height * weight;
auto sp_sz = height * weight;
int group_row = group;
int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>();
int blocks = NumBlocks(output_grad->numel());
int threads = kNumCUDAThreads;
int count = num * group_column * group_row * sp_sz;
ShuffleChannel<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
count, feature_map_size, input_grad_data, output_grad_data, group_row,
group_column, sp_sz);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
shuffle_channel,
ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_CUDA_KERNEL(
shuffle_channel_grad,
ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ShuffleChannelOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
    auto width = input_dims[3];
    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;
int group_row = group;
int group_column = channel / group_row;
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
for (int n = 0; n < num; ++n) {
for (int i = 0; i < group_row; ++i) {
for (int j = 0; j < group_column; ++j) {
const T* p_i = input_data + n * feature_map_size +
(i * group_column + j) * sp_sz;
T* p_o =
output_data + n * feature_map_size + (j * group_row + i) * sp_sz;
          memcpy(p_o, p_i, sizeof(T) * sp_sz);
}
}
}
}
};
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
    auto width = input_dims[3];
    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;
int group_row = group;
int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>();
for (int n = 0; n < num; ++n) {
for (int i = 0; i < group_row; ++i) {
for (int j = 0; j < group_column; ++j) {
const T* p_i = output_grad_data + n * feature_map_size +
(i * group_column + j) * sp_sz;
T* p_o = input_grad_data + n * feature_map_size +
(j * group_row + i) * sp_sz;
          memcpy(p_o, p_i, sizeof(T) * sp_sz);
}
}
}
}
};
} // namespace operators
} // namespace paddle
...@@ -485,6 +485,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -485,6 +485,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll); .def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
...@@ -505,19 +506,12 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -505,19 +506,12 @@ All parameter, weight, gradient are variables in Paddle.
.def("is_closed", &LoDTensorBlockingQueue::IsClosed); .def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue", m.def("init_lod_tensor_blocking_queue",
[](Variable &var, size_t capacity, [](Variable &var,
const std::vector<std::vector<int64_t>> &shapes) size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
-> std::shared_ptr<LoDTensorBlockingQueue> { auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
std::vector<DDim> dims(shapes.size()); holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
std::transform(shapes.begin(), shapes.end(), dims.begin(), return holder->GetQueue();
[](const std::vector<int64_t> &shape) { },
return make_ddim(shape);
});
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, dims,
FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
py::return_value_policy::copy); py::return_value_policy::copy);
py::class_<Scope>(m, "_Scope", R"DOC( py::class_<Scope>(m, "_Scope", R"DOC(
......
...@@ -22,6 +22,8 @@ from . import op_frequence ...@@ -22,6 +22,8 @@ from . import op_frequence
from .op_frequence import * from .op_frequence import *
from . import quantize from . import quantize
from .quantize import * from .quantize import *
from . import reader
from .reader import *
from . import slim from . import slim
from .slim import * from .slim import *
from . import utils from . import utils
...@@ -32,5 +34,6 @@ __all__ += decoder.__all__ ...@@ -32,5 +34,6 @@ __all__ += decoder.__all__
__all__ += memory_usage_calc.__all__ __all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__ __all__ += op_frequence.__all__
__all__ += quantize.__all__ __all__ += quantize.__all__
__all__ += reader.__all__
__all__ += slim.__all__ __all__ += slim.__all__
__all__ += utils.__all__ __all__ += utils.__all__
## CTR READER
A multi-threaded C++ reader that has the same interface as py_reader. It
uses C++ threads to read files and is much faster than the Python reading
thread used by py_reader.
Currently, it supports two file types:
- gzip
- plain text file
and two data formats (sample lines are shown below):
- csv data format:
  * label dense_fea,dense_fea sparse_fea,sparse_fea
- svm data format:
  * label slot1:fea_sign slot2:fea_sign slot1:fea_sign
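For illustration, here is one hypothetical line of each format and a minimal sketch of how its fields split into label, dense and sparse parts (the concrete values are made up, not taken from a real dataset):

```python
# Hypothetical sample lines; the field layout follows the description above,
# the values are invented for illustration only.
csv_line = "1 0.5,0.3,0.2 10073,3183,3233"  # label  dense_fea,...  sparse_fea,...
svm_line = "0 1:8731 2:190 1:66"            # label  slot:fea_sign ...

# csv: space-separated groups, comma-separated values within each group.
label, dense, sparse = csv_line.split(" ")
print(int(label), [float(v) for v in dense.split(",")],
      [int(v) for v in sparse.split(",")])

# svm: every field after the label is a "slot:fea_sign" pair; a slot id may repeat.
fields = svm_line.split(" ")
pairs = [f.split(":") for f in fields[1:]]
print(int(fields[0]), [(int(slot), int(sign)) for slot, sign in pairs])
```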
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import ctr_reader
__all__ = ctr_reader.__all__
...@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \ ...@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \
default_startup_program, Variable default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name from paddle.fluid.unique_name import generate as unique_name
__all__ = ['ctr_reader']
def monkey_patch_reader_methods(reader): def monkey_patch_reader_methods(reader):
def __get_reader__(): def __get_reader__():
...@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader): ...@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader):
def reset(): def reset():
return __get_reader__().reset() return __get_reader__().reset()
def start():
return __get_reader__().start()
reader.reset = reset reader.reset = reset
reader.start = start
reader.stop_gradient = True reader.stop_gradient = True
reader.persistable = True reader.persistable = True
return reader return reader
...@@ -44,13 +50,18 @@ def _copy_reader_var_(block, var): ...@@ -44,13 +50,18 @@ def _copy_reader_var_(block, var):
return new_var return new_var
def ctr_reader(feed_data, def ctr_reader(
capacity, feed_dict,
thread_num, file_type, # gzip or plain
batch_size, file_format, # csv or svm
file_list, dense_slot_index,
slots, sparse_slot_index,
name=None): capacity,
thread_num,
batch_size,
file_list,
slots,
name=None):
""" """
Create a CTR reader for data feeding in Python Create a CTR reader for data feeding in Python
...@@ -67,12 +78,21 @@ def ctr_reader(feed_data, ...@@ -67,12 +78,21 @@ def ctr_reader(feed_data,
Note that :code:`Program.clone()` method cannot clone :code:`py_reader`. Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.
Args: Args:
feed_dict(list(variable)): a list of data variables.
file_type('gzip'|'plain'): the type of the data file
file_format('csv'|'svm'): csv data or svm data format.
csv data format is:
label dense_fea,dense_fea sparse_fea,sparse_fea
svm data format is:
label slot1:fea_sign slot2:fea_sign slot1:fea_sign
dense_slot_index(list(int)): the index of dense slots
sparse_slot_index(list(int)): the index of sparse slots
capacity(int): The buffer capacity maintained by :code:`py_reader`. capacity(int): The buffer capacity maintained by :code:`py_reader`.
thread_num(list|tuple): List of tuples which declaring data shapes. thread_num(int): the number of threads the C++ reader uses to read files.
batch_size(list|tuple): List of strs which declaring data type. batch_size(int): the batch size of the data.
file_list(list|tuple): List of ints which declaring data lod_level. file_list(list(str)): list of file names to read.
slots(bool): Whether use double buffer or not. slots(list(int64)): list of slot ids.
name(basestring): The prefix Python queue name and Reader name. None will name(string): The prefix Python queue name and Reader name. None will
be generated automatically. be generated automatically.
Returns: Returns:
...@@ -80,7 +100,15 @@ def ctr_reader(feed_data, ...@@ -80,7 +100,15 @@ def ctr_reader(feed_data,
Examples: Examples:
1. The basic usage of :code:`py_reader` is as follows: 1. The basic usage of :code:`ctr_reader` is as follows:
.. code-block:: python
reader = fluid.contrib.ctr_reader.ctr_reader(
feed_dict=datas, file_type='plain', file_format='csv',
file_list=file_list, dense_slot_index=[1, 2, 3, 4], sparse_slot_index=[],
capacity=64, thread_num=20, batch_size=1000, slots=[], name='ctr_reader')
""" """
if name is None: if name is None:
queue_name = unique_name('lod_tensor_blocking_queue') queue_name = unique_name('lod_tensor_blocking_queue')
...@@ -90,7 +118,7 @@ def ctr_reader(feed_data, ...@@ -90,7 +118,7 @@ def ctr_reader(feed_data,
reader_name = "_".join([name, "reader"]) reader_name = "_".join([name, "reader"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
reader_var = startup_blk.create_var(name=reader_name) reader_var = startup_blk.create_var(name=reader_name)
...@@ -99,12 +127,22 @@ def ctr_reader(feed_data, ...@@ -99,12 +127,22 @@ def ctr_reader(feed_data,
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [reader_var]}, outputs={'Out': [reader_var]},
attrs={ attrs={
'use_data_config': False,
'thread_num': thread_num, 'thread_num': thread_num,
'batch_size': batch_size, 'batch_size': batch_size,
'file_list': file_list, 'file_list': file_list,
'slots': slots, 'file_type': file_type,
'file_format': file_format,
'dense_slot_index': dense_slot_index,
'sparse_slot_index': sparse_slot_index,
'sparse_slots': slots,
'ranks': [],
'lod_levels': [],
'shape_concat': []
}) })
dtypes = [data.dtype for data in feed_dict]
reader_var.desc.set_dtypes(dtypes)
reader_var.persistable = True reader_var.persistable = True
main_prog_reader_var = _copy_reader_var_( main_prog_reader_var = _copy_reader_var_(
...@@ -118,6 +156,9 @@ def ctr_reader(feed_data, ...@@ -118,6 +156,9 @@ def ctr_reader(feed_data,
main_blk = default_main_program().current_block() main_blk = default_main_program().current_block()
main_blk.append_op( main_blk.append_op(
type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data}) type='read',
inputs={'Reader': [reader]},
attrs={'infer_out': False},
outputs={'Out': feed_dict})
return reader return reader
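As a usage sketch, the reader created above can drive a training loop through the monkey-patched start()/reset() methods. Names such as datas, loss and file_list are placeholders, and the sketch assumes the reader signals end-of-data the same way as py_reader, i.e. by raising fluid.core.EOFException from the read op:

import paddle.fluid as fluid

reader = fluid.contrib.ctr_reader.ctr_reader(
    feed_dict=datas, file_type='plain', file_format='csv',
    file_list=file_list, dense_slot_index=[1, 2, 3, 4], sparse_slot_index=[],
    capacity=64, thread_num=20, batch_size=1000, slots=[], name='ctr_reader')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

reader.start()                       # launch the background reading threads
try:
    while True:
        exe.run(fetch_list=[loss])   # data is fed by the read op; no feed= needed
except fluid.core.EOFException:
    reader.reset()                   # all files consumed; reset before reuse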
...@@ -523,7 +523,7 @@ def _py_reader(capacity, ...@@ -523,7 +523,7 @@ def _py_reader(capacity,
double_buffer_name = "_".join([name, "double_buffer"]) double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
......
...@@ -179,6 +179,7 @@ __all__ = [ ...@@ -179,6 +179,7 @@ __all__ = [
'merge_selected_rows', 'merge_selected_rows',
'get_tensor_from_selected_rows', 'get_tensor_from_selected_rows',
'lstm', 'lstm',
'shuffle_channel',
'py_func', 'py_func',
'psroi_pool', 'psroi_pool',
'teacher_student_sigmoid_loss', 'teacher_student_sigmoid_loss',
...@@ -9646,6 +9647,79 @@ def get_tensor_from_selected_rows(x, name=None): ...@@ -9646,6 +9647,79 @@ def get_tensor_from_selected_rows(x, name=None):
return out return out
def shuffle_channel(x, group, name=None):
"""
**Shuffle Channel Operator**
This operator shuffles the channels of input x.
It divides the input channels into :attr:`group` subgroups
and obtains a new order by selecting one element from each subgroup in turn.
Please refer to the ShuffleNet paper:
https://arxiv.org/pdf/1707.01083.pdf
.. code-block:: text
Given a 4-D tensor input with the shape (N, C, H, W):
input.shape = (1, 4, 2, 2)
input.data =[[[[0.1, 0.2],
[0.2, 0.3]],
[[0.3, 0.4],
[0.4, 0.5]],
[[0.5, 0.6],
[0.6, 0.7]],
[[0.7, 0.8],
[0.8, 0.9]]]]
Given group: 2
then we get a 4-D tensor out with the same shape as the input:
out.shape = (1, 4, 2, 2)
out.data = [[[[0.1, 0.2],
[0.2, 0.3]],
[[0.5, 0.6],
[0.6, 0.7]],
[[0.3, 0.4],
[0.4, 0.5]],
[[0.7, 0.8],
[0.8, 0.9]]]]
Args:
x(Variable): The input tensor variable. It should be a 4-D tensor with shape [N, C, H, W]
group(int): The number of subgroups; it must evenly divide the number of channels.
Returns:
out(Variable): the channel-shuffled result, a tensor variable with the
same shape and type as the input.
Raises:
TypeError: If group is not an int.
Examples:
.. code-block:: python
input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32')
out = fluid.layers.shuffle_channel(x=input, group=2)
"""
helper = LayerHelper("shuffle_channel", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if not isinstance(group, int):
raise TypeError("group must be int type")
helper.append_op(
type="shuffle_channel",
inputs={"X": x},
outputs={"Out": out},
attrs={"group": group})
return out
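For intuition, the shuffle performed by this layer is equivalent to a reshape/transpose/reshape in NumPy; the unit test further below uses the same formulation as its reference result. A minimal sketch with an arbitrary input (illustrative only, not part of the layer code):

import numpy as np

x = np.random.random((1, 4, 2, 2)).astype('float32')  # (N, C, H, W)
group = 2
n, c, h, w = x.shape
# split channels into `group` rows, swap the two channel axes, flatten back
shuffled = (x.reshape(n, group, c // group, h, w)
            .transpose(0, 2, 1, 3, 4)
            .reshape(n, c, h, w))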
class PyFuncRegistry(object): class PyFuncRegistry(object):
_register_funcs = [] _register_funcs = []
......
...@@ -1023,6 +1023,14 @@ class TestBook(unittest.TestCase): ...@@ -1023,6 +1023,14 @@ class TestBook(unittest.TestCase):
print(str(program)) print(str(program))
def test_shuffle_channel(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
out = layers.shuffle_channel(x, group=4)
self.assertIsNotNone(out)
print(str(program))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
import paddle.fluid.core as core
class TestShuffleChannelOp(OpTest):
def setUp(self):
self.op_type = "shuffle_channel"
self.batch_size = 10
self.input_channels = 16
self.layer_h = 4
self.layer_w = 4
self.group = 4
self.x = np.random.random(
(self.batch_size, self.input_channels, self.layer_h,
self.layer_w)).astype('float32')
self.inputs = {'X': self.x}
self.attrs = {'group': self.group}
n, c, h, w = self.x.shape
input_reshaped = np.reshape(self.x,
(-1, self.group, c // self.group, h, w))
input_transposed = np.transpose(input_reshaped, (0, 2, 1, 3, 4))
self.outputs = {'Out': np.reshape(input_transposed, (-1, c, h, w))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
...@@ -109,6 +109,7 @@ packages=['paddle', ...@@ -109,6 +109,7 @@ packages=['paddle',
'paddle.fluid.contrib', 'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize', 'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.reader',
'paddle.fluid.contrib.slim', 'paddle.fluid.contrib.slim',
'paddle.fluid.contrib.slim.core', 'paddle.fluid.contrib.slim.core',
'paddle.fluid.contrib.slim.graph', 'paddle.fluid.contrib.slim.graph',
......