From 88af4ab650674dfe9323b16e31e844bb2deb3546 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 26 Sep 2019 07:42:53 +0800 Subject: [PATCH] Add new data layer (#19916) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The new "fluid.data" changes the old "fluid.layers.data" as follows: 1. Add shape and dtype checks. 2. Remove the "append_batch_size" parameter. We won't offer it in the new data layer because other deep learning platforms don't have this kind of data layer pre-processing and it may confuse users. 3. Remove the "stop_gradient" parameter because the data layer doesn't do back-propagation. TODO: For now, only data fed to the data layer through the executor is checked; do we also want to check the feed data of readers in the future? --- paddle/fluid/API.spec | 1 + paddle/fluid/framework/framework.proto | 3 + paddle/fluid/framework/var_desc.h | 6 + paddle/fluid/pybind/protobuf.cc | 4 +- python/paddle/fluid/__init__.py | 3 + python/paddle/fluid/data.py | 73 +++++++++ python/paddle/fluid/executor.py | 104 ++++++++++++- python/paddle/fluid/framework.py | 14 +- .../test_feed_data_check_shape_type.py | 145 ++++++++++++++++++ 9 files changed, 348 insertions(+), 5 deletions(-) create mode 100644 python/paddle/fluid/data.py create mode 100644 python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 77ad4f5415..47b3319cad 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -556,6 +556,7 @@ paddle.fluid.contrib.BasicLSTMUnit.sublayers (ArgSpec(args=['self', 'include_sub paddle.fluid.contrib.BasicLSTMUnit.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.basic_lstm (ArgSpec(args=['input', 'init_hidden', 'init_cell', 'hidden_size', 'num_layers', 'sequence_length', 'dropout_prob', 'bidirectional', 'batch_first', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, 0.0, False, True, None, None, None, None, 1.0, 'float32', 'basic_lstm')), ('document', 'fe4d0c3c55a162b8cfe10b05fabb7ce4')) paddle.fluid.contrib.ctr_metric_bundle (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'b68d12366896c41065fc3738393da2aa')) +paddle.fluid.data (ArgSpec(args=['name', 'shape', 'dtype', 'type'], varargs=None, keywords=None, defaults=('float32', VarType.LOD_TENSOR)), ('document', '4e96c3d52ab30b07157f7588ba61d3d1')) paddle.fluid.dygraph.Layer ('paddle.fluid.dygraph.layers.Layer', ('document', 'a889d5affd734ede273e94d4257163ab')) paddle.fluid.dygraph.Layer.__init__ (ArgSpec(args=['self', 'name_scope', 'dtype'], varargs=None, keywords=None, defaults=(VarType.FP32,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.dygraph.Layer.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1')) diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index efdabffb9b..2c1296d5ca 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -166,6 +166,9 @@ message VarDesc { required string name = 1; required VarType type = 2; optional bool persistable = 3 [ default = false ]; + // True if the variable is an input data whose + // feed data shape and dtype have to be checked + optional bool need_check_feed = 4 [ default = false ]; } message
BlockDesc { diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h index 7c82e1d68f..6e8be0fdd4 100644 --- a/paddle/fluid/framework/var_desc.h +++ b/paddle/fluid/framework/var_desc.h @@ -110,6 +110,12 @@ class VarDesc { void SetPersistable(bool persistable) { desc_.set_persistable(persistable); } + bool NeedCheckFeed() const { return desc_.need_check_feed(); } + + void SetNeedCheckFeed(bool need_check_feed) { + desc_.set_need_check_feed(need_check_feed); + } + private: const proto::VarType::TensorDesc &tensor_desc() const; std::vector tensor_descs() const; diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 31b5dd5d7c..7c6e2f2414 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -162,7 +162,9 @@ void BindVarDsec(pybind11::module *m) { .def("set_type", &pd::VarDesc::SetType) .def("serialize_to_string", SerializeMessage) .def("persistable", &pd::VarDesc::Persistable) - .def("set_persistable", &pd::VarDesc::SetPersistable); + .def("set_persistable", &pd::VarDesc::SetPersistable) + .def("need_check_feed", &pd::VarDesc::NeedCheckFeed) + .def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed); pybind11::enum_(var_desc, "VarType", "") .value("BOOL", pd::proto::VarType::BOOL) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 180fae6631..262c7a5ea8 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -44,6 +44,8 @@ from .data_feed_desc import * from . import dataset from .dataset import * +from .data import * + from . import trainer_desc from . import inferencer @@ -97,6 +99,7 @@ __all__ = framework.__all__ + executor.__all__ + \ 'one_hot', 'layers', 'contrib', + 'data', 'dygraph', 'transpiler', 'nets', diff --git a/python/paddle/fluid/data.py b/python/paddle/fluid/data.py new file mode 100644 index 0000000000..4a67bab544 --- /dev/null +++ b/python/paddle/fluid/data.py @@ -0,0 +1,73 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from . import core +from .layer_helper import LayerHelper + +__all__ = ['data'] + + +def data(name, shape, dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR): + """ + **Data Layer** + + This function creates a variable in the global scope. The global variable + can be accessed by all the following operators in the graph. + + Note: + `paddle.fluid.layers.data` is deprecated. It will be removed in a future + version. Please use `paddle.fluid.data` instead. + + `paddle.fluid.layers.data` sets the shape at compile time but does NOT + check the shape of the fed data; this `paddle.fluid.data` checks the + shape of the data fed by Executor/ParallelExecutor during run time. + + Args: + name (str): The name/alias of the variable. + shape (list|tuple): List|Tuple of integers declaring the shape. + dtype (np.dtype|VarType|str): The type of the data.
Supported dtype: float16, float32, float64, int8, int16, int32, int64, uint8, bool. + type (VarType): The output type. Supported type: VarType.LOD_TENSOR, + VarType.SELECTED_ROWS, VarType.NCCL_ID. Default: VarType.LOD_TENSOR. + + Returns: + Variable: The global variable that gives access to the data. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + # Creates a variable with fixed size [1, 2, 3] + # Users can only feed data of the same shape to x + x = fluid.data(name='x', shape=[1, 2, 3], dtype='int64') + + # Creates a variable with changeable batch size -1. + # Users can feed data of any batch size into y, + # but the size of each data sample has to be [3, 224, 224] + y = fluid.data(name='y', shape=[-1, 3, 224, 224], dtype='float32') + + """ + helper = LayerHelper('data', **locals()) + return helper.create_global_variable( + name=name, + shape=shape, + dtype=dtype, + type=type, + stop_gradient=True, + lod_level=0, + is_data=True, + need_check_feed=True) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index ed0479be84..27117585be 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -22,7 +22,7 @@ import warnings import numpy as np from .wrapped_decorator import signature_safe_contextmanager import six -from .framework import Program, default_main_program, Variable +from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_ from . import core from . import compiler from .. import compat as cpt @@ -128,6 +128,91 @@ def as_numpy(tensor): return None + +def dtype_is_compatible_with(first, second): + """ + Returns True if the first dtype is compatible with the second one. + Currently, we require the two dtypes to be exactly the same. + + Args: + first, second (np.dtype|VarType|str): The two dtypes to compare: float32, int64, etc. + + Returns: + True if the two dtypes are the same. + """ + if not isinstance(first, core.VarDesc.VarType): + first = convert_np_dtype_to_dtype_(first) + if not isinstance(second, core.VarDesc.VarType): + second = convert_np_dtype_to_dtype_(second) + return first == second + + +def dimension_is_compatible_with(first, second): + """ + Returns True if the two dimensions are compatible. + + A dimension is compatible with the other if: + 1. The lengths of the dimensions are the same. + 2. Each non-negative number in the two dimensions is the same. + 3. A negative number or 'None' in a dimension means unknown, so it + is compatible with any number. + + Args: + first (list/tuple): integers representing shape. "None" or negative + number means unknown. + second (list/tuple): integers representing shape. "None" or negative + number means unknown. + + Returns: + True if the two dimensions are compatible. + """ + + dim_len = len(first) + if dim_len != len(second): + return False + + for i in range(dim_len): + if first[i] is None or first[i] < 0: + continue + if second[i] is None or second[i] < 0: + continue + if first[i] != second[i]: + return False + + return True + + +def check_feed_shape_type(var, feed): + """ + Returns True if the variable doesn't require a feed check, or if its shape + is compatible with and its dtype is the same as the fed value. + + A dimension is compatible with the other if: + 1. The lengths of the dimensions are the same. + 2. Each non-negative number in the two dimensions is the same. + 3. A negative number or 'None' in a dimension means unknown, so it + is compatible with any number.
+ + Args: var (Variable): the Variable object feed (LoDTensor): the fed value, which must be a LoDTensor Returns: True if the shape and dtype of the variable are compatible with the feed value Raises: ValueError: if the shape or dtype of the variable is not compatible with the feed value """ if var.desc.need_check_feed(): if not dimension_is_compatible_with(feed.shape(), var.shape): raise ValueError('Cannot feed value of shape %r for Variable %r, ' 'which has shape %r' % (feed.shape, var.name, var.shape)) if not dtype_is_compatible_with(feed._dtype(), var.dtype): raise ValueError('Cannot feed value of type %r for Variable %r, ' 'which has type %r' % (feed._dtype(), var.name, var.dtype)) return True + + def has_feed_operators(block, feed_targets, feed_holder_name): """ Check whether the block already has feed operators. @@ -443,12 +528,15 @@ class Executor(object): def _feed_data(self, program, feed, feed_var_name, scope): # feed var to framework - for op in program.global_block().ops: + global_block = program.global_block() + for op in global_block.ops: if op.desc.type() == 'feed': feed_target_name = op.desc.output('Out')[0] cur_feed = feed[feed_target_name] if not isinstance(cur_feed, core.LoDTensor): cur_feed = _as_lodtensor(cur_feed, self.place) + var = global_block.var(feed_target_name) + check_feed_shape_type(var, cur_feed) idx = op.desc.attr('col') core.set_feed_variable(scope, cur_feed, feed_var_name, idx) else: @@ -492,6 +580,11 @@ class Executor(object): def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, return_numpy): exe = program._executor + # TODO(zhenghuihuang): quantization uses Graph in CompiledProgram + # instead of program. We will add support for checking Vars in Graph. + need_check_feed = program._program is not None + if need_check_feed: + global_block = program._program.global_block() if isinstance(feed, dict): feed_tensor_dict = dict() for feed_name in feed: @@ -504,6 +597,9 @@ "The input({}) should be numpy.array, but not {}.".format( feed_name, type(feed[feed_name])) feed_tensor.set(feed[feed_name], core.CPUPlace()) + if need_check_feed: + var = global_block.var(feed_name) + check_feed_shape_type(var, feed_tensor) feed_tensor_dict[feed_name] = feed_tensor exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) @@ -528,6 +624,9 @@ feed_name, type(each[feed_name])) tmp.set(tensor, program._places[i]) tensor = tmp + if need_check_feed: + var = global_block.var(feed_name) + check_feed_shape_type(var, tensor) res_dict[feed_name] = tensor res.append(res_dict) exe.feed_tensors_into_local_scopes(res) @@ -645,6 +744,7 @@ fetch_list = [] compiled = isinstance(program, compiler.CompiledProgram) + # For backward compatibility, run directly. if not compiled: return self._run_program( diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c9fb957656..a13e1f2a31 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -416,6 +416,8 @@ class Variable(object): stop_gradient (bool): True if the variable will stop to calculate its gradients when backward. Default: False. is_data (bool): True if the variable is an input data. Default: False + need_check_feed (bool): True if the variable is an input data whose + feed data shape and dtype have to be checked. Default: False Notes: The constructor of Variable should not be invoked directly.
Please @@ -444,6 +446,7 @@ class Variable(object): error_clip=None, stop_gradient=False, is_data=False, + need_check_feed=False, **kwargs): self.block = block if name is None: @@ -532,6 +535,9 @@ class Variable(object): "persistable is {2}. They are not matched".format( self.name, self.persistable, persistable)) + if need_check_feed and is_new_var: + self.desc.set_need_check_feed(need_check_feed) + if capacity is not None: if is_new_var: self.desc.set_capacity(capacity) @@ -2109,7 +2115,8 @@ class Block(object): dtype=var.dtype, type=var.type, persistable=True if force_persistable else var.persistable, - is_data=var.is_data) + is_data=var.is_data, + need_check_feed=var.desc.need_check_feed()) else: ret_var = self.create_var( name=var.name, @@ -2118,7 +2125,8 @@ class Block(object): type=var.type, lod_level=var.lod_level, persistable=True if force_persistable else var.persistable, - is_data=var.is_data) + is_data=var.is_data, + need_check_feed=var.desc.need_check_feed()) return ret_var @@ -3730,6 +3738,8 @@ class Program(object): for var in list(other.global_block().vars.values()): if var.is_data: self.global_block().var(var.name).is_data = True + if var.desc.need_check_feed(): + self.global_block().var(var.name).desc.set_need_check_feed(True) def list_vars(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py b/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py new file mode 100644 index 0000000000..2489ae8e26 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py @@ -0,0 +1,145 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import multiprocessing +import numpy as np +import os +import paddle +import paddle.fluid as fluid +import paddle.fluid.compiler as compiler +import paddle.fluid.core as core +import unittest + +os.environ['CPU_NUM'] = str(4) +np.random.seed(123) + + +class TestFeedData(unittest.TestCase): + ''' + Test paddle.fluid.data feeds with different shape and types. + Note: paddle.fluid.data is not paddle.fluid.layers.data. 
+ ''' + + def setUp(self): + self.hidden_sizes = [25, 20, 15] + self.base_batch_size = 10 + self.class_num = 10 + self.iterations = 5 + + def _get_batch_size(self, use_cuda, use_parallel_executor): + batch_size_times = 1 + if use_parallel_executor: + batch_size_times = core.get_cuda_device_count( + ) if use_cuda else int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + return self.base_batch_size * batch_size_times + + def _simple_fc_net(self, in_size, label_size, class_num, hidden_sizes): + in_data = fluid.data(name="data", dtype='float32', shape=in_size) + label = fluid.data(name='label', dtype='int64', shape=label_size) + + hidden = in_data + for hidden_size in hidden_sizes: + hidden = fluid.layers.fc(hidden, size=hidden_size) + + predict_label = fluid.layers.fc(hidden, size=class_num, act='softmax') + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict_label, label=label)) + + optimizer = fluid.optimizer.Adam() + optimizer.minimize(loss) + return in_data, label, loss + + def test(self): + for use_cuda in [True, False] if core.is_compiled_with_cuda( + ) else [False]: + for use_parallel_executor in [False, True]: + print('Test Parameters:'), + print({ + 'use_cuda': use_cuda, + 'use_parallel_executor': use_parallel_executor, + }) + self._test_feed_data_match_shape_type(use_cuda, + use_parallel_executor) + self._test_feed_data_contains_neg_one(use_cuda, + use_parallel_executor) + with self.assertRaises(ValueError): + self._test_feed_data_shape_mismatch(use_cuda, + use_parallel_executor) + + def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor): + batch_size = self._get_batch_size(use_cuda, use_parallel_executor) + in_size = [-1, 3, 4, 8] + feed_in_data = np.random.uniform( + size=[batch_size, 3, 4, 5]).astype(np.float32) + label_size = [-1, 1] + feed_label = np.random.randint( + low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64) + self._feed_data_in_executor(in_size, label_size, feed_in_data, + feed_label, use_cuda, use_parallel_executor) + + def _test_feed_data_contains_neg_one(self, use_cuda, use_parallel_executor): + batch_size = self._get_batch_size(use_cuda, use_parallel_executor) + in_size = [-1, 3, 4, 5] + feed_in_data = np.random.uniform( + size=[batch_size, 3, 4, 5]).astype(np.float32) + label_size = (-1, 1) + feed_label = np.random.randint( + low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64) + self._feed_data_in_executor(in_size, label_size, feed_in_data, + feed_label, use_cuda, use_parallel_executor) + + def _test_feed_data_match_shape_type(self, use_cuda, use_parallel_executor): + batch_size = self._get_batch_size(use_cuda, use_parallel_executor) + in_size = [batch_size, 3, 4, 5] + feed_in_data = np.random.uniform(size=in_size).astype(np.float32) + label_size = [batch_size, 1] + feed_label = np.random.randint( + low=0, high=self.class_num, size=label_size).astype(np.int64) + self._feed_data_in_executor(in_size, label_size, feed_in_data, + feed_label, use_cuda, use_parallel_executor) + + def _feed_data_in_executor(self, in_size, label_size, feed_in_data, + feed_label, use_cuda, use_parallel_executor): + + startup_program = fluid.Program() + main_program = fluid.Program() + + with fluid.program_guard(main_program, startup_program): + in_data, label, loss = self._simple_fc_net( + in_size, label_size, self.class_num, self.hidden_sizes) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(startup_program) + + train_program = main_program + if 
use_parallel_executor: + train_program = compiler.CompiledProgram( + main_program).with_data_parallel(loss_name=loss.name) + + for i in range(self.iterations): + fetches = exe.run( + train_program, + feed={in_data.name: feed_in_data, + label.name: feed_label}, + fetch_list=[loss.name]) + + +if __name__ == '__main__': + unittest.main() -- GitLab
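A minimal sketch (not part of the patch) of the difference described in point 2 of the commit message, assuming the fluid 1.x API: fluid.layers.data prepends a -1 batch dimension by default (append_batch_size=True), while the new fluid.data uses the declared shape as written, so the batch dimension is stated explicitly. The variable names below are hypothetical.

import paddle.fluid as fluid

# Old layer: append_batch_size=True (the default) prepends a -1 batch
# dimension, so the declared shape [3] becomes [-1, 3].
old = fluid.layers.data(name='old', shape=[3], dtype='float32')
print(old.shape)  # the leading -1 was added automatically

# New layer: the shape is used exactly as declared; the batch dimension
# (-1 here, meaning "any size") must be written out by the user.
new = fluid.data(name='new', shape=[-1, 3], dtype='float32')
print(new.shape)  # (-1, 3), exactly as declared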
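The shape rule enforced by check_feed_shape_type() in executor.py can be restated in plain Python as below. This is an illustrative mirror of dimension_is_compatible_with(), not code added by the patch.

def shapes_compatible(fed, declared):
    # Ranks must match; a None or negative entry means "unknown" and is
    # compatible with anything; fixed entries must be equal.
    if len(fed) != len(declared):
        return False
    for f, d in zip(fed, declared):
        if f is None or f < 0:
            continue
        if d is None or d < 0:
            continue
        if f != d:
            return False
    return True

assert shapes_compatible([16, 3, 224, 224], [-1, 3, 224, 224])  # free batch dim
assert not shapes_compatible([10, 3, 4, 5], [-1, 3, 4, 8])      # a fixed dim differs
assert not shapes_compatible([1, 2], [1, 2, 3])                 # rank mismatch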
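An end-to-end sketch of the run-time check that Executor._feed_data() now performs, mirroring the shape-mismatch case in test_feed_data_check_shape_type.py. The small network and variable names are hypothetical; the ValueError is the one raised by check_feed_shape_type() in this patch.

import numpy as np
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[-1, 3], dtype='float32')
y = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Compatible feed: the -1 batch dimension accepts any batch size.
exe.run(feed={'x': np.ones((8, 3), dtype=np.float32)}, fetch_list=[y])

# Incompatible feed: shape [8, 4] does not match [-1, 3], so the executor
# raises ValueError at feed time instead of failing inside an operator.
try:
    exe.run(feed={'x': np.ones((8, 4), dtype=np.float32)}, fetch_list=[y])
except ValueError as e:
    print(e)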