提交 7f07dfa1 编写于 作者: Q Qiao Longfei

clean code

上级 618f7620
......@@ -126,7 +126,7 @@ TEST(CTR_READER, read_data) {
LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64;
queue_holder.InitOnce(capacity, {}, false);
queue_holder.InitOnce(capacity, false);
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
......
......@@ -32,10 +32,8 @@ class LoDTensorBlockingQueue {
friend class LoDTensorBlockingQueueHolder;
private:
LoDTensorBlockingQueue(size_t capacity,
const std::vector<framework::DDim>& dims,
bool speed_test_mode = false)
: queue_(capacity, speed_test_mode), dims_(dims) {}
explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
: queue_(capacity, speed_test_mode) {}
public:
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
......@@ -65,17 +63,15 @@ class LoDTensorBlockingQueue {
private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_;
std::vector<framework::DDim> dims_;
};
class LoDTensorBlockingQueueHolder {
public:
void InitOnce(size_t capacity, const std::vector<framework::DDim>& dims,
bool speed_test_mode = false) {
void InitOnce(size_t capacity, bool speed_test_mode = false) {
PADDLE_ENFORCE(
queue_ == nullptr,
"LoDTensorBlockingQueueHolder::InitOnce() can only be called once");
queue_.reset(new LoDTensorBlockingQueue(capacity, dims, speed_test_mode));
queue_.reset(new LoDTensorBlockingQueue(capacity, speed_test_mode));
}
inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const {
......
......@@ -384,19 +384,12 @@ All parameter, weight, gradient are variables in Paddle.
.def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
[](Variable &var, size_t capacity,
const std::vector<std::vector<int64_t>> &shapes)
-> std::shared_ptr<LoDTensorBlockingQueue> {
std::vector<DDim> dims(shapes.size());
std::transform(shapes.begin(), shapes.end(), dims.begin(),
[](const std::vector<int64_t> &shape) {
return make_ddim(shape);
});
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, dims,
FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
[](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
py::return_value_policy::copy);
py::class_<Scope>(m, "Scope", R"DOC(
......
......@@ -523,7 +523,7 @@ def _py_reader(capacity,
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册