提交 c2436f2c 编写于 作者: F fengjiayi

expose random_data_generator

上级 8655904b
......@@ -21,14 +21,15 @@ namespace reader {
template <typename T>
class RandomDataGenerator : public framework::ReaderBase {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max)
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
float high)
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) {
PADDLE_ENFORCE_LE(low, high,
"'low' shouldn't be greater than 'high'.(%f vs %f)", low,
high);
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(min_, max_);
dist_ = std::uniform_real_distribution<float>(low_, high_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
......@@ -53,8 +54,8 @@ class RandomDataGenerator : public framework::ReaderBase {
void ReInit() override { return; }
private:
float min_;
float max_;
float low_;
float high_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
......@@ -78,22 +79,22 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"),
Attr<float>("high")));
}
};
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
protected:
void Apply() override {
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddAttr<float>("low", "The lower bound of reader's uniform distribution.");
AddAttr<float>("high", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
Generated data follow an uniform distribution between 'low' and 'high'.
)DOC");
}
};
......
......@@ -321,7 +321,7 @@ def open_recordio_file(filename,
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
image, label = fluid.layers.io.read_file(reader)
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
......@@ -359,6 +359,73 @@ def open_recordio_file(filename,
return monkey_patch_reader_methods(main_prog_var)
def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
"""
Create a uniform random data generator
This layer returns a Reader Variable.
Instead of opening a file and reading data from it, this
Reader Variable generates float uniform random data by itself.
It can be used as a dummy reader to test a network without
opening a real file.
Args:
low(float): The lower bound of data's uniform distribution.
high(float): The upper bound of data's uniform distribution.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable from which we can get random data.
Examples:
.. code-block:: python
reader = fluid.layers.io.random_data_generator(
low=0.0,
high=1.0,
shapes=[(3,224,224), (1)],
lod_levels=[0, 0])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
dtypes = [core.VarDesc.VarType.FP32] * len(shapes)
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('random_data_generator')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='create_random_data_generator',
outputs={'Out': [startup_var]},
attrs={
'low': low,
'high': high,
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames,
shapes,
lod_levels,
......
......@@ -44,8 +44,8 @@ create_random_data_generator_op = startup_block.append_op(
attrs={
"shape_concat": [1, 2, 1, 1],
"ranks": [2, 2],
"min": 0.0,
"max": 1.0,
"low": 0.0,
"high": 1.0,
'lod_levels': [0, 0]
})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册