提交 aff54ef7 编写于 作者: Q Qiao Longfei

add ctr data

上级 40d65a13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name
def monkey_patch_reader_methods(reader):
def __get_reader__():
scope = global_scope()
var = scope.find_var(reader.name)
return var.get_reader()
def reset():
return __get_reader__().reset()
reader.reset = reset
reader.stop_gradient = True
reader.persistable = True
return reader
def _copy_reader_var_(block, var):
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True
return new_var
def ctr_reader(feed_data,
capacity,
thread_num,
batch_size,
file_list,
slots,
name=None):
"""
Create a CTR reader for data feeding in Python
This layer returns a Reader Variable.
The Reader provides :code:`decorate_paddle_reader()` and
:code:`decorate_tensor_provider()` to set a Python generator as the data
source in Python side. When :code:`Executor::Run()` is invoked in C++
side, the data from the generator would be read automatically. Unlike
:code:`DataFeeder.feed()`, the data reading process and
:code:`Executor::Run()` process can run in parallel using
:code:`py_reader`. The :code:`start()` method of the Reader should be
called when each pass begins, while the :code:`reset()` method should be
called when the pass ends and :code:`fluid.core.EOFException` raises.
Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.
Args:
capacity(int): The buffer capacity maintained by :code:`py_reader`.
thread_num(list|tuple): List of tuples which declaring data shapes.
batch_size(list|tuple): List of strs which declaring data type.
file_list(list|tuple): List of ints which declaring data lod_level.
slots(bool): Whether use double buffer or not.
name(basestring): The prefix Python queue name and Reader name. None will
be generated automatically.
Returns:
Variable: A Reader from which we can get feeding data.
Examples:
1. The basic usage of :code:`py_reader` is as follows:
"""
if name is None:
queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_ctr_reader')
else:
queue_name = "_".join([name, "queue"])
reader_name = "_".join([name, "reader"])
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
startup_blk = default_startup_program().current_block()
reader_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op(
type='create_ctr_reader',
inputs={'blocking_queue': [queue_name]},
outputs={'Out': [reader_var]},
attrs={
'thread_num': thread_num,
'batch_size': batch_size,
'file_list': file_list,
'slots': slots,
})
reader_var.persistable = True
main_prog_reader_var = _copy_reader_var_(
default_main_program().current_block(), reader_var)
reader = monkey_patch_reader_methods(main_prog_reader_var)
# monkey patch py_reader special methods
reader.queue = feed_queue
reader.exited = False
main_blk = default_main_program().current_block()
main_blk.append_op(
type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data})
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册