提交 224bd8f0 编写于 作者: K Kexin Zhao 提交者: Yi Wang

Add lod_tensor.py for ease of creating lod tensor in book examples (#10817)

* add lod_tensor utility python module

* add lod_tensor test code

* add more lod tensor tests

* modify word2vec example code using new api

* add comment
上级 0d598cf9
......@@ -48,6 +48,7 @@ from transpiler import DistributeTranspiler, SimpleDistributeTranspiler, \
InferenceTranspiler, memory_optimize, release_memory
from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close, Select)
from lod_tensor import create_lod_tensor, create_random_int_lodtensor
import clip
import profiler
import unique_name
......@@ -59,7 +60,7 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + [
parallel_executor.__all__ + lod_tensor.__all__ + [
'io',
'initializer',
'layers',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import core
import numpy as np
__all__ = ['create_lod_tensor', 'create_random_int_lodtensor']
def _validate_lod(lod, tensor_height=-1):
"""Check whether the input length-based lod info is valid.
There are several things to check:
1. lod should be a list of lists. Empty list is fine.
2. The length of each sublist (a lod level) should be at least one.
3. Each element in each lod level should be an integer greater than 0.
4. The sum of one lod level should be equal to the length of the next lod level.
5. The sum of the last lod level should be equal to the tensor height.
Bypass this check if user does not provide tensor_height as input.
Args:
lod: the length-based lod info, e.g., [[2, 3], [2, 1, 2, 3, 4]].
tensor_height: the outermost dimension of the tensor with which the input
lod is associated with.
Returns:
A boolean indicating whether the input lod is valid or not.
"""
assert isinstance(lod, list), "lod should be a list"
# Empty lod is fine
if len(lod) == 0:
return True
lod_sum = []
for level in lod:
assert isinstance(level, list), "each item in lod should be a list"
# Each level of lod should have at least one length info
if len(level) < 1:
return False
level_sum = 0
for lod_len in level:
# Each length in a level should be > 0
if lod_len <= 0:
return False
level_sum += lod_len
lod_sum.append(level_sum)
for idx, val in enumerate(lod_sum[:-1]):
# Each level's sum should be equal to
# the number of items in the next level
if val != len(lod[idx + 1]):
return False
if tensor_height == -1:
return True
else:
# Last level's sum should be equal to the tensor height
return lod_sum[-1] == tensor_height
def _convert_lod(lod):
"""Convert a length-based lod to a offset-based lod.
If the length-based lod is [[2, 3], [2, 1, 2, 3, 4]],
then the offset-based lod is [[0, 2, 5], [0, 2, 3, 5, 8, 12]].
Args:
lod: a length-based lod info.
Returns:
A list of lists as the offset-based lod converted to from the input lod.
"""
new_lod = []
for level in lod:
cur_len = 0
new_level = [cur_len]
for lod_len in level:
cur_len += lod_len
new_level.append(cur_len)
new_lod.append(new_level)
return new_lod
def create_lod_tensor(data, lod, place):
"""Create a lod tensor from a numpy array or an existing lod tensor.
Create a lod tensor by doing the following:
1. Check that the length-based input lod is valid.
2. Convert the length-based lod to a offset-based LoD.
3. Copy the data from a numpy array or a existing lod tensor to
CPU or GPU device (based on input place).
4. Set the level of detail (LoD) using the offset-based LoD.
Use example:
Suppose we want LoDTensor to hold data for sequences of word, where each word is
represented by an integer. If we want to create a LoDTensor to represent two
sentences, one of 2 words, and one of 3 words.
Then 'data' can be a numpy array of integers with shape (5, 1).
'lod' will be [[2, 3]], indicating the length(# of words) in each sentence.
This length-based input lod [[2, 3]] will be converted to offset-based lod [[0, 2, 5]]
inside the function call.
Please refer to
github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/lod_tensor.md
for more details regarding LoD.
Args:
data: a numpy array or a LoDTensor holding the data to be copied.
lod: a list of lists indicating the length-based LoD info specified by the user.
place: CPU or GPU place indicating where the data in the new LoDTensor will be stored.
Returns:
A fluid LoDTensor object with tensor data and lod info.
"""
if isinstance(data, core.LoDTensor):
return create_lod_tensor(np.array(data), lod, place)
elif isinstance(data, np.ndarray):
assert _validate_lod(lod,
data.shape[0]), "the provided lod info is invalid"
tensor = core.LoDTensor()
tensor.set(data, place)
tensor.set_lod(_convert_lod(lod))
return tensor
else:
raise Exception(
"data should be either a LoDTensor or a Numpy array, but you pass type %s instead"
% (type(data)))
def create_random_int_lodtensor(lod, base_shape, place, low, high):
"""Create a LoDTensor containing random integers.
This function is frequently used in the book examples. So we revised it based on
the new create_lod_tensor API and put it here in the lod_tensor module to simplify
the code.
The function does the following:
1. Calculate the overall shape of the LoDTensor based on the length-based 'lod' input
and the shape of the basic element in 'base_shape'.
2. Create a numpy array of this shape.
3. Create the LoDTensor using create_lod_tensor API.
Suppose we want LoDTensor to hold data for sequences of word, where each word is
represented by an integer. If we want to create a LoDTensor to represent two
sentences, one of 2 words, and one of 3 words. Then 'base_shape' is [1], input
length-based 'lod' is [[2, 3]]. Then the overall shape of the LoDTensor would be
[5, 1], holding 5 words for two sentences.
Args:
data: a numpy array or a LoDTensor holding the data to be copied.
lod: a list of lists indicating the length-based LoD info specified by the user.
base_shape: the shape of the basic element to be held by the LoDTensor.
place: CPU or GPU place indicating where the data in the new LoDTensor will be stored.
low: the lower bound of the random integers.
high: the upper bound of the random integers.
Returns:
A fluid LoDTensor object with tensor data and lod info.
"""
assert isinstance(base_shape, list), "base_shape should be a list"
converted_lod = _convert_lod(lod)
# append the total number of basic elements to the front of its shape
overall_shape = [converted_lod[-1][-1]] + base_shape
# the range of integer data elements is [low, high]
data = np.random.random_integers(low, high, overall_shape).astype("int64")
return create_lod_tensor(data, lod, place)
......@@ -21,15 +21,6 @@ import math
import sys
def create_random_lodtensor(lod, place, low, high):
# The range of data elements is [low, high]
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res
def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
PASS_NUM = 100
EMBED_SIZE = 32
......@@ -175,16 +166,22 @@ def infer(use_cuda, save_dirname=None):
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
# Setup inputs, by creating 4 words, the lod of which should be [0, 1]
lod = [0, 1]
first_word = create_random_lodtensor(
lod, place, low=0, high=dict_size - 1)
second_word = create_random_lodtensor(
lod, place, low=0, high=dict_size - 1)
third_word = create_random_lodtensor(
lod, place, low=0, high=dict_size - 1)
fourth_word = create_random_lodtensor(
lod, place, low=0, high=dict_size - 1)
# Setup inputs by creating 4 LoDTensors representing 4 words. Here each word
# is simply an index to look up for the corresponding word vector and hence
# the shape of word (base_shape) should be [1]. The length-based level of
# detail (lod) info of each LoDtensor should be [[1]] meaning there is only
# one lod_level and there is only one sequence of one word on this level.
# Note that lod info should be a list of lists.
lod = [[1]]
base_shape = [1]
first_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
second_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
third_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
fourth_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
assert feed_target_names[0] == 'firstw'
assert feed_target_names[1] == 'secondw'
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.lod_tensor import create_lod_tensor, create_random_int_lodtensor, _validate_lod, _convert_lod
import numpy
import unittest
class TestLoDTensor(unittest.TestCase):
def test_validate_lod(self):
lod = (1, 2, 1)
self.assertRaises(AssertionError, _validate_lod, lod, -1)
lod = [[1, 2], (2, 3)]
self.assertRaises(AssertionError, _validate_lod, lod, -1)
lod = [1, 2, 3]
self.assertRaises(AssertionError, _validate_lod, lod, -1)
lod = []
self.assertTrue(_validate_lod(lod, -1))
lod = [[], [1], [3]]
self.assertFalse(_validate_lod(lod, -1))
lod = [[0], [-1], [3]]
self.assertFalse(_validate_lod(lod, -1))
# Each level's sum should be equal to the number of items in the next level
# Moreover, last level's sum should be equal to the tensor height
lod = [[2, 3], [1, 3, 1, 2, 1]]
self.assertTrue(_validate_lod(lod, tensor_height=8))
lod = [[1, 3], [2, 1, 3]]
self.assertFalse(_validate_lod(lod, tensor_height=6))
lod = [[1, 3], [2, 1, 3, 4]]
self.assertFalse(_validate_lod(lod, tensor_height=5))
def test_convert_lod(self):
lod = [[1, 2, 3]]
converted_lod = [[0, 1, 3, 6]]
self.assertEqual(_convert_lod(lod), converted_lod)
lod = [[2, 3], [1, 3, 1, 2, 1]]
converted_lod = [[0, 2, 5], [0, 1, 4, 5, 7, 8]]
self.assertEqual(_convert_lod(lod), converted_lod)
def test_create_lod_tensor(self):
# Only numpy array or a fluid LoDTensor is valid input to
# create_lod_tensor function, currently a list of lists is not.
data = [[1, 2], [3, 4]]
self.assertRaises(Exception, create_lod_tensor, data, [],
fluid.CPUPlace())
# Create LoDTensor from numpy array
data = numpy.random.random([10, 1])
lod = [[2, 1], [3, 3, 4]]
tensor = create_lod_tensor(data, lod, fluid.CPUPlace())
self.assertEqual(tensor.lod(), [[0, 2, 3], [0, 3, 6, 10]])
# Create LoDTensor from another LoDTensor, they are differnt instances
new_lod = [[2, 2, 1], [1, 2, 2, 3, 2]]
new_tensor = create_lod_tensor(tensor, new_lod, fluid.CPUPlace())
self.assertEqual(tensor.lod(), [[0, 2, 3], [0, 3, 6, 10]])
self.assertEqual(new_tensor.lod(), [[0, 2, 4, 5], [0, 1, 3, 5, 8, 10]])
def test_create_random_int_lodtensor(self):
# The shape of a word, commonly used in speech and NLP problem, is [1]
shape = [1]
lod = [[2, 3, 5]]
dict_size = 10000
low = 0
high = dict_size - 1
tensor = create_random_int_lodtensor(lod, shape,
fluid.CPUPlace(), low, high)
self.assertEqual(tensor.lod(), [[0, 2, 5, 10]])
self.assertEqual(tensor.shape(), [10, 1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册