未验证 提交 88af4ab6 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add new data layer (#19916)

The new "fluid.data" changes old "fluid.layers.data":

1. Add shape and dtype check.
2. Remove "append_batch_size" parameter. We won't offer this in the new data layer because other deep learning platforms don't have this kind of data layer pre-processing. It may confuse users.
3. Remove "stop gradient" parameter because the data layer doesn't do back-propagation

TODO:
Now data layer feeded by executor is checked, will we want to check the feed data of readers in the future?
上级 1b7de894
......@@ -556,6 +556,7 @@ paddle.fluid.contrib.BasicLSTMUnit.sublayers (ArgSpec(args=['self', 'include_sub
paddle.fluid.contrib.BasicLSTMUnit.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.basic_lstm (ArgSpec(args=['input', 'init_hidden', 'init_cell', 'hidden_size', 'num_layers', 'sequence_length', 'dropout_prob', 'bidirectional', 'batch_first', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, 0.0, False, True, None, None, None, None, 1.0, 'float32', 'basic_lstm')), ('document', 'fe4d0c3c55a162b8cfe10b05fabb7ce4'))
paddle.fluid.contrib.ctr_metric_bundle (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'b68d12366896c41065fc3738393da2aa'))
paddle.fluid.data (ArgSpec(args=['name', 'shape', 'dtype', 'type'], varargs=None, keywords=None, defaults=('float32', VarType.LOD_TENSOR)), ('document', '4e96c3d52ab30b07157f7588ba61d3d1'))
paddle.fluid.dygraph.Layer ('paddle.fluid.dygraph.layers.Layer', ('document', 'a889d5affd734ede273e94d4257163ab'))
paddle.fluid.dygraph.Layer.__init__ (ArgSpec(args=['self', 'name_scope', 'dtype'], varargs=None, keywords=None, defaults=(VarType.FP32,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.dygraph.Layer.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1'))
......
......@@ -166,6 +166,9 @@ message VarDesc {
required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
// True if the variable is an input data and
// have to check the feed data shape and dtype
optional bool need_check_feed = 4 [ default = false ];
}
message BlockDesc {
......
......@@ -110,6 +110,12 @@ class VarDesc {
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
bool NeedCheckFeed() const { return desc_.need_check_feed(); }
void SetNeedCheckFeed(bool need_check_feed) {
desc_.set_need_check_feed(need_check_feed);
}
private:
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
......
......@@ -162,7 +162,9 @@ void BindVarDsec(pybind11::module *m) {
.def("set_type", &pd::VarDesc::SetType)
.def("serialize_to_string", SerializeMessage<pd::VarDesc>)
.def("persistable", &pd::VarDesc::Persistable)
.def("set_persistable", &pd::VarDesc::SetPersistable);
.def("set_persistable", &pd::VarDesc::SetPersistable)
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", pd::proto::VarType::BOOL)
......
......@@ -44,6 +44,8 @@ from .data_feed_desc import *
from . import dataset
from .dataset import *
from .data import *
from . import trainer_desc
from . import inferencer
......@@ -97,6 +99,7 @@ __all__ = framework.__all__ + executor.__all__ + \
'one_hot',
'layers',
'contrib',
'data',
'dygraph',
'transpiler',
'nets',
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from . import core
from .layer_helper import LayerHelper
__all__ = ['data']
def data(name, shape, dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR):
"""
**Data Layer**
This function creates a variable on the global scope. The global variables
can be accessed by all the following operators in the graph.
Note:
`paddle.fluid.layers.data` is deprecated. It will be removed in a future
version. Please use this `paddle.fluid.data`.
The `paddle.fluid.layers.data` set shape at compile time but does NOT
check the shape of feeded data, this `paddle.fluid.data` checks the
shape of data feeded by Executor/ParallelExecutor during run time.
Args:
name (str): The name/alias of the variable
shape (list|tuple): List|Tuple of integers declaring the shape.
dtype (np.dtype|VarType|str): The type of the data. Supported dtype:
float16, float32, float64, int8, int16, int32, int64, uint8, bool.
type (VarType): The output type. Supported type: VarType.LOD_TENSOR,
VarType.SELECTED_ROWS, VarType.NCCL_ID. Default: VarType.LOD_TENSOR.
Returns:
Variable: The global variable that gives access to the data.
Examples:
.. code-block:: python
import paddle.fluid as fluid
# Creates a variable with fixed size [1, 2, 3]
# User can only feed data of the same shape to x
x = fluid.data(name='x', shape=[1, 2, 3], dtype='int64')
# Creates a variable with changable batch size -1.
# Users can feed data of any batch size into y,
# but size of each data sample has to be [3, 224, 224]
y = fluid.data(name='y', shape=[-1, 3, 224, 224], dtype='float32')
"""
helper = LayerHelper('data', **locals())
return helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=type,
stop_gradient=True,
lod_level=0,
is_data=True,
need_check_feed=True)
......@@ -22,7 +22,7 @@ import warnings
import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
import six
from .framework import Program, default_main_program, Variable
from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_
from . import core
from . import compiler
from .. import compat as cpt
......@@ -128,6 +128,91 @@ def as_numpy(tensor):
return None
def dtype_is_compatible_with(first, second):
"""
Returns True if the first dtype can be compatible the second one.
Currently, we require the two dtype's have to be same.
Args:
dtype (np.dtype|VarType|str): The type of data: float32, int64, etc.
Returns:
True if the two types are same.
"""
if not isinstance(first, core.VarDesc.VarType):
first = convert_np_dtype_to_dtype_(first)
if not isinstance(second, core.VarDesc.VarType):
second = convert_np_dtype_to_dtype_(second)
return first == second
def dimension_is_compatible_with(first, second):
"""
Returns True if the two dimensions are compatible.
A dimension is compatible with the other if:
1. The length of the dimensions are same.
2. Each non-negative number of the two dimentions are same.
3. For negative number or 'None' in a dimention, it means unknown so it
is compatible with any number.
Args:
first (list/tuple): integers representing shape. "None" or negative
number means unknown.
second (list/tuple): integers representing shape. "None" or negative
number means unknown.
Returns:
True if the two dimensions are compatible.
"""
dim_len = len(first)
if dim_len != len(second):
return False
for i in range(dim_len):
if first[i] is None or first[i] < 0:
continue
if second[i] is None or second[i] < 0:
continue
if first[i] != second[i]:
return False
return True
def check_feed_shape_type(var, feed):
"""
Returns True if the variable doesn't require feed check or it is compatible
with the shape and have same dtype as the feeded value.
A dimension is compatible with the other if:
1. The length of the dimensions are same.
2. Each non-negative number of the two dimentions are same.
3. For negative number or 'None' in a dimention, it means unknown so it
is compatible with any number.
Args:
var (Variable): the Variable object
feed (LoDTensor): the feeded value, which must be a LoDTensor
Returns:
True if the shape and dtype of variable is compatible with the feed value
Raises:
ValueError: if the shape or dtype of the variable is not compatible with
the feed value
"""
if var.desc.need_check_feed():
if not dimension_is_compatible_with(feed.shape(), var.shape):
raise ValueError('Cannot feed value of shape %r for Variable %r, '
'which has shape %r' %
(feed.shape, var.name, var.shape))
if not dtype_is_compatible_with(feed._dtype(), var.dtype):
raise ValueError('Cannot feed value of type %r for Variable %r, '
'which has type %r' %
(feed._dtype(), var.name, var.dtype))
return True
def has_feed_operators(block, feed_targets, feed_holder_name):
""" Check whether the block already has feed operators.
......@@ -443,12 +528,15 @@ class Executor(object):
def _feed_data(self, program, feed, feed_var_name, scope):
# feed var to framework
for op in program.global_block().ops:
global_block = program.global_block()
for op in global_block.ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = _as_lodtensor(cur_feed, self.place)
var = global_block.var(feed_target_name)
check_feed_shape_type(var, cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
......@@ -492,6 +580,11 @@ class Executor(object):
def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
return_numpy):
exe = program._executor
# TODO(zhenghuihuang): quantization uses Graph in CompiledProgram
# instead of program. We will add support for checking Vars in Graph
need_check_feed = program._program is not None
if need_check_feed:
global_block = program._program.global_block()
if isinstance(feed, dict):
feed_tensor_dict = dict()
for feed_name in feed:
......@@ -504,6 +597,9 @@ class Executor(object):
"The input({}) should be numpy.array, but not {}.".format(
feed_name, type(feed[feed_name]))
feed_tensor.set(feed[feed_name], core.CPUPlace())
if need_check_feed:
var = global_block.var(feed_name)
check_feed_shape_type(var, feed_tensor)
feed_tensor_dict[feed_name] = feed_tensor
exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
......@@ -528,6 +624,9 @@ class Executor(object):
feed_name, type(each[feed_name]))
tmp.set(tensor, program._places[i])
tensor = tmp
if need_check_feed:
var = global_block.var(feed_name)
check_feed_shape_type(var, tensor)
res_dict[feed_name] = tensor
res.append(res_dict)
exe.feed_tensors_into_local_scopes(res)
......@@ -645,6 +744,7 @@ class Executor(object):
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly.
if not compiled:
return self._run_program(
......
......@@ -416,6 +416,8 @@ class Variable(object):
stop_gradient (bool): True if the variable will stop to calculate its
gradients when backward. Default: False.
is_data (bool): True if the variable is an input data. Default: False
need_check_feed (bool): True if the variable is an input data and have
to check the feed data shape and dtype. Default: False
Notes:
The constructor of Variable should not be invoked directly. Please
......@@ -444,6 +446,7 @@ class Variable(object):
error_clip=None,
stop_gradient=False,
is_data=False,
need_check_feed=False,
**kwargs):
self.block = block
if name is None:
......@@ -532,6 +535,9 @@ class Variable(object):
"persistable is {2}. They are not matched".format(
self.name, self.persistable, persistable))
if need_check_feed and is_new_var:
self.desc.set_need_check_feed(need_check_feed)
if capacity is not None:
if is_new_var:
self.desc.set_capacity(capacity)
......@@ -2109,7 +2115,8 @@ class Block(object):
dtype=var.dtype,
type=var.type,
persistable=True if force_persistable else var.persistable,
is_data=var.is_data)
is_data=var.is_data,
need_check_feed=var.desc.need_check_feed())
else:
ret_var = self.create_var(
name=var.name,
......@@ -2118,7 +2125,8 @@ class Block(object):
type=var.type,
lod_level=var.lod_level,
persistable=True if force_persistable else var.persistable,
is_data=var.is_data)
is_data=var.is_data,
need_check_feed=var.desc.need_check_feed())
return ret_var
......@@ -3730,6 +3738,8 @@ class Program(object):
for var in list(other.global_block().vars.values()):
if var.is_data:
self.global_block().var(var.name).is_data = True
if var.desc.need_check_feed():
self.global_block().var(var.name).desc.set_need_check_feed(True)
def list_vars(self):
"""
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import multiprocessing
import numpy as np
import os
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler
import paddle.fluid.core as core
import unittest
os.environ['CPU_NUM'] = str(4)
np.random.seed(123)
class TestFeedData(unittest.TestCase):
'''
Test paddle.fluid.data feeds with different shape and types.
Note: paddle.fluid.data is not paddle.fluid.layers.data.
'''
def setUp(self):
self.hidden_sizes = [25, 20, 15]
self.base_batch_size = 10
self.class_num = 10
self.iterations = 5
def _get_batch_size(self, use_cuda, use_parallel_executor):
batch_size_times = 1
if use_parallel_executor:
batch_size_times = core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
return self.base_batch_size * batch_size_times
def _simple_fc_net(self, in_size, label_size, class_num, hidden_sizes):
in_data = fluid.data(name="data", dtype='float32', shape=in_size)
label = fluid.data(name='label', dtype='int64', shape=label_size)
hidden = in_data
for hidden_size in hidden_sizes:
hidden = fluid.layers.fc(hidden, size=hidden_size)
predict_label = fluid.layers.fc(hidden, size=class_num, act='softmax')
loss = fluid.layers.mean(
fluid.layers.cross_entropy(
input=predict_label, label=label))
optimizer = fluid.optimizer.Adam()
optimizer.minimize(loss)
return in_data, label, loss
def test(self):
for use_cuda in [True, False] if core.is_compiled_with_cuda(
) else [False]:
for use_parallel_executor in [False, True]:
print('Test Parameters:'),
print({
'use_cuda': use_cuda,
'use_parallel_executor': use_parallel_executor,
})
self._test_feed_data_match_shape_type(use_cuda,
use_parallel_executor)
self._test_feed_data_contains_neg_one(use_cuda,
use_parallel_executor)
with self.assertRaises(ValueError):
self._test_feed_data_shape_mismatch(use_cuda,
use_parallel_executor)
def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor):
batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
in_size = [-1, 3, 4, 8]
feed_in_data = np.random.uniform(
size=[batch_size, 3, 4, 5]).astype(np.float32)
label_size = [-1, 1]
feed_label = np.random.randint(
low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64)
self._feed_data_in_executor(in_size, label_size, feed_in_data,
feed_label, use_cuda, use_parallel_executor)
def _test_feed_data_contains_neg_one(self, use_cuda, use_parallel_executor):
batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
in_size = [-1, 3, 4, 5]
feed_in_data = np.random.uniform(
size=[batch_size, 3, 4, 5]).astype(np.float32)
label_size = (-1, 1)
feed_label = np.random.randint(
low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64)
self._feed_data_in_executor(in_size, label_size, feed_in_data,
feed_label, use_cuda, use_parallel_executor)
def _test_feed_data_match_shape_type(self, use_cuda, use_parallel_executor):
batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
in_size = [batch_size, 3, 4, 5]
feed_in_data = np.random.uniform(size=in_size).astype(np.float32)
label_size = [batch_size, 1]
feed_label = np.random.randint(
low=0, high=self.class_num, size=label_size).astype(np.int64)
self._feed_data_in_executor(in_size, label_size, feed_in_data,
feed_label, use_cuda, use_parallel_executor)
def _feed_data_in_executor(self, in_size, label_size, feed_in_data,
feed_label, use_cuda, use_parallel_executor):
startup_program = fluid.Program()
main_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
in_data, label, loss = self._simple_fc_net(
in_size, label_size, self.class_num, self.hidden_sizes)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
train_program = main_program
if use_parallel_executor:
train_program = compiler.CompiledProgram(
main_program).with_data_parallel(loss_name=loss.name)
for i in range(self.iterations):
fetches = exe.run(
train_program,
feed={in_data.name: feed_in_data,
label.name: feed_label},
fetch_list=[loss.name])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册