Unverified commit df7cc3a0 authored by GGBond8488, committed by GitHub

Fluid clean: remove dygraph profiler and fluid.input.one_hot, move fluid.input.embedding to paddle.static.nn (#50141)

* remove dygraph.profiler

* remove fluid.input.one_hot and move embedding to paddle.static.nn

* fix unittest errors

* fix type error

* fix type error

* fix xpu test error

* fix sample code error

* fix sample code error

* fix sample code error

* remove test.py

* remove Variable from docstrings
Parent 5f5a2082
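For reference, the user-facing migration implied by this commit is small. The sketch below assumes a static-graph program and uses illustrative variable names; the replacement calls are the ones used in the updated tests in this diff:

import paddle

paddle.enable_static()

label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
ids = paddle.static.data(name="ids", shape=[None, 1], dtype="int64")

# before this commit (APIs removed here):
#   one_hot_label = fluid.one_hot(input=label, depth=4)
#   emb = fluid.embedding(input=ids, size=[128, 64])

# after this commit:
one_hot_label = paddle.nn.functional.one_hot(x=label, num_classes=4)
emb = paddle.static.nn.embedding(input=ids, size=[128, 64])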
@@ -66,7 +66,6 @@ from . import average
 from . import metrics
 from . import transpiler
 from . import incubate
-from .input import embedding, one_hot
 from .param_attr import ParamAttr, WeightNormParamAttr
 from .data_feeder import DataFeeder
@@ -129,8 +128,6 @@ __all__ = (
     + [
         'io',
         'initializer',
-        'embedding',
-        'one_hot',
         'layers',
         'contrib',
         'data',
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import core

__all__ = [
    'start_gperf_profiler',
    'stop_gperf_profiler',
]


def start_gperf_profiler():
    core.start_imperative_gperf_profiler()


def stop_gperf_profiler():
    core.stop_imperative_gperf_profiler()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from .framework import Variable, _non_static_mode, static_only
from .layer_helper import LayerHelper
from .data_feeder import check_variable_and_dtype, check_dtype
from ..utils import deprecated
__all__ = ['one_hot', 'embedding']
@deprecated(since='2.0.0', update_to='paddle.nn.functional.one_hot')
def one_hot(input, depth, allow_out_of_range=False):
"""
:alias_main: paddle.nn.functional.one_hot
:alias: paddle.nn.functional.one_hot,paddle.nn.functional.common.one_hot
:old_api: paddle.fluid.one_hot
The operator converts each id in the input to a one-hot vector of length
depth. The value in the vector dimension corresponding to the id is 1,
and the values in the remaining dimensions are 0.
The shape of the output Tensor or LoDTensor is generated by appending a depth dimension
after the last dimension of the input shape.
.. code-block:: text
Example 1 (allow_out_of_range=False):
input:
X.shape = [4]
X.data = [1, 1, 3, 0]
depth = 4
output:
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]]
Example 2 (allow_out_of_range=True):
input:
X.shape = [4]
X.data = [1, 1, 5, 0]
depth = 4
allow_out_of_range = True
output:
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 0.], # This id is 5, which goes beyond depth, so set it all-zeros data.
[1., 0., 0., 0.]]
Example 3 (allow_out_of_range=False):
input:
X.shape = [4]
X.data = [1, 1, 5, 0]
depth = 4
allow_out_of_range = False
output: Throws an exception for an illegal value.
The third element of X is 5, which is greater than or equal to depth.
allow_out_of_range=False means the word id is not allowed to exceed depth,
so an exception is thrown.
Args:
input(Variable): Tensor or LoDTensor with shape :math:`[N_1, N_2, ..., N_k]` ,
which contains at least one dimension. The data type is int32 or int64.
depth(int): An integer defining the depth of the one hot dimension. If input
is word id, depth is generally the dictionary size.
allow_out_of_range(bool): A bool value indicating whether the input
indices may be out of range :math:`[0, depth)` . When input indices are
out of range, an :code:`Illegal value` exception is raised if :attr:`allow_out_of_range`
is False, or a zero-filled representation is created if it is set to True.
Default: False.
Returns:
Variable: The one-hot representations of input. A Tensor or LoDTensor with type float32.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
# Corresponds to the first example above, where label.shape is [4] and one_hot_label.shape is [4, 4].
label = fluid.data(name="label", shape=[4], dtype="int64")
one_hot_label = fluid.one_hot(input=label, depth=4)
"""
check_variable_and_dtype(input, 'input', ['int32', 'int64'], 'one_hot_v2')
helper = LayerHelper("one_hot_v2", **locals())
one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
if _non_static_mode():
inputs = {'X': input}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
if not isinstance(depth, Variable):
# user attribute
inputs = {'X': input}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op(
type="one_hot_v2",
inputs=inputs,
attrs=attrs,
outputs={'Out': one_hot_out},
stop_gradient=True,
)
return one_hot_out
@static_only
@deprecated(since='2.0.0', update_to='paddle.nn.functional.embedding')
def embedding(
input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32',
):
r"""
:api_attr: Static Graph
The operator is used to look up the embedding vectors of the ids provided by :attr:`input` .
It automatically constructs a 2D embedding matrix based on the
input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` .
The shape of the output Tensor is generated by appending an emb_size dimension to the
last dimension of the input Tensor shape.
**Note:** The ids in :attr:`input` must satisfy :math:`0 <= id < size[0]` ,
otherwise the program will throw an exception and exit.
.. code-block:: text
Case 1:
input is a Tensor. padding_idx = -1
input.data = [[1, 3], [2, 4], [4, 127]]
input.shape = [3, 2]
Given size = [128, 16]
output is a Tensor:
out.shape = [3, 2, 16]
out.data = [[[0.129435295, 0.244512452, ..., 0.436322452],
[0.345421456, 0.524563927, ..., 0.144534654]],
[[0.345249859, 0.124939536, ..., 0.194353745],
[0.945345345, 0.435394634, ..., 0.435345365]],
[[0.945345345, 0.435394634, ..., 0.435345365],
[0.0, 0.0, ..., 0.0 ]]] # padding data
Since the input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127.
All-zero data is padded whenever an id is 127.
Case 2:
input is a LoDTensor with 1-level LoD. padding_idx = 0
input.lod = [[2, 3]]
input.data = [[1], [3], [2], [4], [0]]
input.shape = [5, 1]
Given size = [128, 16]
output is a LoDTensor:
out.lod = [[2, 3]]
out.shape = [5, 1, 16]
out.data = [[[0.129435295, 0.244512452, ..., 0.436322452]],
[[0.345421456, 0.524563927, ..., 0.144534654]],
[[0.345249859, 0.124939536, ..., 0.194353745]],
[[0.945345345, 0.435394634, ..., 0.435345365]],
[[0.0, 0.0, ..., 0.0 ]]] # padding data
It will pad all-zero data when ids is 0.
Args:
input(Variable): A Tensor or LoDTensor with type int64, which contains the id information.
The value of the input id should satisfy :math:`0<= id < size[0]` .
size(tuple|list): The shape of lookup table parameter. It should have two elements which
indicates the size of the dictionary of embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update. This parameter only
affects the performance of the backward gradient update. It is recommended to set it to
True because sparse update is faster. However, some optimizers do not support sparse update;
in that case, is_sparse must be False. Default: False.
is_distributed(bool): Whether to store the embedding matrix in a distributed manner. Only used
in multi-machine distributed CPU training. Default: False.
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size).
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
to :math:`vocab\_size + padding\_idx` . All-zero padding data is output whenever the lookup
encounters :math:`padding\_idx` in an id, and the padding data is not updated during training.
If set to None, it has no effect on the output. Default: None.
param_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
default weight parameter property is used. In addition,
user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter.
The local word vector needs to be transformed into numpy format, and the shape of local word
vector should be consistent with :attr:`size` .
dtype(str): It refers to the data type of output Tensor.
It must be float32 or float64. Default: float32.
Returns:
Variable: Embedding Tensor or LoDTensor mapped by input. The data type is the same as :attr:`dtype` .
Static Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.enable_static()
x = paddle.static.data(name="x", shape = [2, 4], dtype=np.int64)
embedding = paddle.nn.Embedding(10, 3,
weight_attr=paddle.nn.initializer.Constant(value=1.0))
sgd = paddle.optimizer.SGD(parameters=[embedding.weight], learning_rate=0.01)
output = embedding(x)
m_output = paddle.mean(output)
sgd.minimize(m_output)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
x = np.array([[7, 2, 4, 5],[4, 3, 2, 9]], dtype=np.int64)
# x is a Numpy.
# x.data = [[7, 2, 4, 5], [4, 3, 2, 9]]
# x.shape = [2, 4]
out, = exe.run(paddle.static.default_main_program(), feed={'x':x}, fetch_list=[output])
# out is a Numpy.
# out.data = [[1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.]],
#
# [[1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.],
# [0., 0., 0.]]]
# out.shape = [2, 4, 3]
Dygraph Examples:
.. code-block:: python
import paddle
import numpy as np
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
# x is a Tensor.
# x.data = [[3], [4], [5]]
# x.shape = [3, 1]
x = paddle.to_tensor(x_data, stop_gradient=False)
# embedding weight shape = [10, 3]
embedding = paddle.nn.Embedding(10, 3, sparse=True)
# new embedding weight data with shape [10, 3]
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
# embedding.weight.shape = [10, 3]
# embedding.weight.data =
# [[2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.],
# [2., 2., 2.]]
embedding.weight.set_value(w0)
adam = paddle.optimizer.Adam(
parameters=[embedding.weight], learning_rate=0.01)
adam.clear_grad()
# out is Tensor
# out.shape: [3, 1, 3]
# out.layout: NCHW
# out.dtype: float
# out.data: [2 2 2 2 2 2 2 2 2]
out = embedding(x)
out.backward()
adam.step()
"""
helper = LayerHelper('embedding', **locals())
check_variable_and_dtype(input, 'input', ['int64'], 'fluid.embedding')
check_dtype(
dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'fluid.embedding',
)
remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch:
assert is_sparse is True and is_distributed is False
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False
)
tmp = helper.create_variable_for_type_inference(dtype)
padding_idx = (
-1
if padding_idx is None
else padding_idx
if padding_idx >= 0
else (size[0] + padding_idx)
)
helper.append_op(
type='lookup_table_v2',
inputs={'Ids': input, 'W': w},
outputs={'Out': tmp},
attrs={
'is_sparse': is_sparse,
'is_distributed': is_distributed,
'remote_prefetch': remote_prefetch,
'padding_idx': padding_idx,
},
)
return tmp
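As a quick sanity check of Example 1 in the one_hot docstring above, a minimal dygraph sketch using the replacement API paddle.nn.functional.one_hot (not part of this diff) would be:

import paddle

label = paddle.to_tensor([1, 1, 3, 0], dtype='int64')
one_hot_label = paddle.nn.functional.one_hot(label, num_classes=4)
# one_hot_label is float32 with shape [4, 4]:
# [[0., 1., 0., 0.],
#  [0., 1., 0., 0.],
#  [0., 0., 0., 1.],
#  [1., 0., 0., 0.]]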
@@ -210,7 +210,7 @@ def embedding(
         data = fluid.data(name='x', shape=[None, 1], dtype='int64')
         # example 1
-        emb_1 = fluid.embedding(input=data, size=[128, 64])
+        emb_1 = paddle.static.nn.embedding(input=data, size=[128, 64])
         # example 2: load custom or pre-trained word vectors
         weight_data = np.random.random(size=(128, 100))  # word vectors with numpy format
...
@@ -119,7 +119,7 @@ def train_network(
     )
     # embedding
-    q_emb = fluid.embedding(
+    q_emb = paddle.static.nn.embedding(
         input=q,
         is_distributed=is_distributed,
         size=[dict_dim, emb_dim],
@@ -147,7 +147,7 @@ def train_network(
     )
     # embedding
-    pt_emb = fluid.embedding(
+    pt_emb = paddle.static.nn.embedding(
         input=pt,
         is_distributed=is_distributed,
         size=[dict_dim, emb_dim],
@@ -176,7 +176,7 @@ def train_network(
     )
     # embedding
-    nt_emb = fluid.embedding(
+    nt_emb = paddle.static.nn.embedding(
         input=nt,
         is_distributed=is_distributed,
         size=[dict_dim, emb_dim],
...
@@ -196,10 +196,6 @@ class TestOneHotOpApi(unittest.TestCase):
             [np.random.randint(0, depth - 1) for i in range(6)]
         ).reshape([6, 1])
         with fluid.dygraph.guard():
-            one_hot_label = fluid.one_hot(
-                input=fluid.dygraph.to_variable(label), depth=depth
-            )
             one_hot_label = paddle.nn.functional.one_hot(
                 fluid.dygraph.to_variable(label), depth
             )
@@ -208,7 +204,7 @@ class TestOneHotOpApi(unittest.TestCase):
     def _run(self, depth):
         label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
-        one_hot_label = fluid.one_hot(input=label, depth=depth)
+        one_hot_label = paddle.nn.functional.one_hot(x=label, num_classes=depth)
         label_data = np.array(
             [np.random.randint(0, 10 - 1) for i in range(6)]
@@ -239,7 +235,7 @@ class BadInputTestOnehotV2(unittest.TestCase):
                 shape=[4],
                 dtype="float32",
             )
-            one_hot_label = fluid.one_hot(input=label, depth=4)
+            one_hot_label = paddle.nn.functional.one_hot(x=label, num_classes=4)
         self.assertRaises(TypeError, test_bad_x)
...
@@ -219,13 +219,13 @@ class TestOneHotOpApi(unittest.TestCase):
             [np.random.randint(0, depth - 1) for i in range(6)]
         ).reshape([6, 1])
         with fluid.dygraph.guard(paddle.NPUPlace(0)):
-            one_hot_label = fluid.one_hot(
-                input=fluid.dygraph.to_variable(label), depth=depth
+            one_hot_label = paddle.nn.functional.one_hot(
+                x=fluid.dygraph.to_variable(label), num_classes=depth
             )
     def _run(self, depth):
         label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
-        one_hot_label = fluid.one_hot(input=label, depth=depth)
+        one_hot_label = paddle.nn.functional.one_hot(x=label, num_classes=depth)
         place = fluid.NPUPlace(0)
         label_data = np.array(
...
@@ -234,13 +234,13 @@ class SimpleNet(BackwardNet):
         )
         # shared layer, the grad of 'w2v' will be summed and renamed.
         # To test _addup_repetitive_outputs_
-        x_emb = fluid.embedding(
+        x_emb = paddle.static.nn.embedding(
             x, size=[100, 64], param_attr=fluid.ParamAttr(name='w2v')
         )
-        x2_emb = fluid.embedding(
+        x2_emb = paddle.static.nn.embedding(
             x2, size=[100, 64], param_attr=fluid.ParamAttr(name='w2v')
         )
-        x3_emb = fluid.embedding(
+        x3_emb = paddle.static.nn.embedding(
             x3, size=[100, 64], param_attr=fluid.ParamAttr(name='w2v')
         )
         # merge layers
@@ -331,7 +331,7 @@ class TestAppendBackwardWithError(unittest.TestCase):
     def build_net(self):
         x = fluid.data(name='x', shape=[None, 13], dtype='int64')
         y = fluid.data(name='y', shape=[None, 1], dtype='float32')
-        x_emb = fluid.embedding(x, size=[100, 256])
+        x_emb = paddle.static.nn.embedding(x, size=[100, 256])
         y_predict = paddle.static.nn.fc(x=x_emb, size=1, name='my_fc')
         loss = paddle.nn.functional.square_error_cost(input=y_predict, label=y)
         avg_loss = paddle.mean(loss)
...
@@ -58,7 +58,9 @@ class TestEmbeddingIdStopGradientBase(unittest.TestCase):
         x.stop_gradient = stop_gradient
-        emb = fluid.embedding(x, size=[10, 32], dtype='float32')
+        emb = paddle.static.nn.embedding(
+            x, size=[10, 32], dtype='float32'
+        )
         avg_cost = paddle.mean(emb, name='mean_loss')
         optim = fluid.optimizer.SGD(learning_rate=0.001)
         optim.minimize(avg_cost)
...
@@ -60,8 +60,8 @@ class TestDygraphLoadStatic(unittest.TestCase):
         batchnorm_out_2 = paddle.static.nn.batch_norm(batchnorm_in)
         emb_in = fluid.data(name='emb_in', shape=[None, 10], dtype='int64')
-        emb_out_1 = fluid.embedding(emb_in, [1000, 100])
-        emb_out_2 = fluid.embedding(emb_in, [2000, 200])
+        emb_out_1 = paddle.static.nn.embedding(emb_in, [1000, 100])
+        emb_out_2 = paddle.static.nn.embedding(emb_in, [2000, 200])
         layernorm = fluid.data(name="ln", shape=[None, 10], dtype='float32')
         layernorm_1 = paddle.static.nn.layer_norm(layernorm)
...
@@ -105,7 +105,7 @@ class TestEmbeddingLayerBF16ConstantInitializer(unittest.TestCase):
         x = paddle.static.data(
             name='x', shape=[-1] + self.ids_shape, dtype='int64'
         )
-        self.emb = fluid.input.embedding(
+        self.emb = paddle.static.nn.embedding(
             input=x,
             size=self.w_shape,
             param_attr=fluid.ParamAttr(
...
@@ -203,7 +203,7 @@ class TestLookupTableIsSparse(unittest.TestCase):
         with fluid.program_guard(main_program, fluid.Program()):
             x = paddle.static.data(name='x', shape=[-1, 5], dtype='int64')
             y_ = paddle.static.data(name='y_', shape=[-1, 5], dtype='float32')
-            emb = fluid.input.embedding(
+            emb = paddle.static.nn.embedding(
                 input=x,
                 size=[10, 16],
                 param_attr=fluid.ParamAttr(
@@ -246,7 +246,7 @@ class TestLookupTableIsSparse(unittest.TestCase):
 class TestLookupTableApi(unittest.TestCase):
     def test_api(self):
         x = paddle.static.data(name='x', shape=[-1, 20], dtype='int64')
-        emb = fluid.embedding(input=x, size=[128, 64])
+        emb = paddle.static.nn.embedding(input=x, size=[128, 64])
         place = fluid.CPUPlace()
         x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
@@ -269,25 +269,29 @@ class TestEmbedOpError(unittest.TestCase):
         def test_Variable():
             # the input type must be Variable
-            fluid.embedding(input=input_data, size=(10, 64))
+            paddle.static.nn.embedding(input=input_data, size=(10, 64))
         self.assertRaises(TypeError, test_Variable)
         def test_input_dtype():
             # the input dtype must be int64
             input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
-            fluid.embedding(input=input, size=(10, 64))
+            paddle.static.nn.embedding(input=input, size=(10, 64))
         self.assertRaises(TypeError, test_input_dtype)
         def test_param_dtype():
             # dtype must be float32 or float64
             input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
-            fluid.embedding(input=input2, size=(10, 64), dtype='int64')
+            paddle.static.nn.embedding(
+                input=input2, size=(10, 64), dtype='int64'
+            )
         self.assertRaises(TypeError, test_param_dtype)
         input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
-        fluid.embedding(input=input3, size=(10, 64), dtype='float16')
+        paddle.static.nn.embedding(
+            input=input3, size=(10, 64), dtype='float16'
+        )
 if __name__ == "__main__":
...
@@ -188,10 +188,6 @@ class TestOneHotOpApi(unittest.TestCase):
             [np.random.randint(0, depth - 1) for i in range(6)]
         ).reshape([6, 1])
         with fluid.dygraph.guard():
-            one_hot_label = fluid.one_hot(
-                input=fluid.dygraph.to_variable(label), depth=depth
-            )
             one_hot_label = paddle.nn.functional.one_hot(
                 fluid.dygraph.to_variable(label), depth
             )
@@ -202,7 +198,7 @@ class TestOneHotOpApi(unittest.TestCase):
     def _run(self, depth):
         label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
         label.desc.set_need_check_feed(False)
-        one_hot_label = fluid.one_hot(input=label, depth=depth)
+        one_hot_label = paddle.nn.functional.one_hot(x=label, num_classes=depth)
         place = fluid.CPUPlace()
         label_data = np.array(
@@ -231,7 +227,9 @@ class BadInputTestOnehotV2(unittest.TestCase):
                 dtype="float32",
             )
             label.desc.set_need_check_feed(False)
-            one_hot_label = fluid.one_hot(input=label, depth=4)
+            one_hot_label = paddle.nn.functional.one_hot(
+                x=label, num_classes=4
+            )
         self.assertRaises(TypeError, test_bad_x)
...
@@ -463,7 +463,7 @@ class TestRunProgramOpWithEmbedding(RunProgramOpTest):
         x = paddle.static.data(
             name=self.input_names['X'][0], shape=[-1, 5], dtype='int64'
         )
-        emb = fluid.input.embedding(
+        emb = paddle.static.nn.embedding(
             input=x,
             size=[10, 16],
             param_attr=fluid.ParamAttr(
...
@@ -200,7 +200,9 @@ class TestSGDOpWithLargeInput(unittest.TestCase):
         label = fluid.layers.fill_constant(
             shape=[1, 150], value=0.5, dtype='float32'
         )
-        emb = fluid.embedding(input=data, size=(10000000, 150), dtype='float32')
+        emb = paddle.static.nn.embedding(
+            input=data, size=(10000000, 150), dtype='float32'
+        )
         out = paddle.nn.functional.normalize(x=emb, axis=-1)
         cost = paddle.nn.functional.square_error_cost(input=out, label=label)
...
@@ -168,7 +168,7 @@ class TestLookupTableWithTensorIdsWIsSelectedRows(
 class TestLookupTableApi(unittest.TestCase):
     def test_api(self):
         x = paddle.static.data(name='x', shape=[-1, 20], dtype='int64')
-        emb = fluid.embedding(input=x, size=[128, 64])
+        emb = paddle.static.nn.embedding(input=x, size=[128, 64])
         place = paddle.XPUPlace(0)
         x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
@@ -191,25 +191,29 @@ class TestEmbedOpError(unittest.TestCase):
         def test_Variable():
             # the input type must be Variable
-            fluid.embedding(input=input_data, size=(10, 64))
+            paddle.static.nn.embedding(input=input_data, size=(10, 64))
         self.assertRaises(TypeError, test_Variable)
         def test_input_dtype():
             # the input dtype must be int64
             input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
-            fluid.embedding(input=input, size=(10, 64))
+            paddle.static.nn.embedding(input=input, size=(10, 64))
         self.assertRaises(TypeError, test_input_dtype)
         def test_param_dtype():
             # dtype must be float32 or float64
             input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
-            fluid.embedding(input=input2, size=(10, 64), dtype='int64')
+            paddle.static.nn.embedding(
+                input=input2, size=(10, 64), dtype='int64'
+            )
         self.assertRaises(TypeError, test_param_dtype)
         input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
-        fluid.embedding(input=input3, size=(10, 64), dtype='float16')
+        paddle.static.nn.embedding(
+            input=input3, size=(10, 64), dtype='float16'
+        )
 if __name__ == "__main__":
...
@@ -162,13 +162,13 @@ class TestOneHotOpApi(unittest.TestCase):
             [np.random.randint(0, depth - 1) for i in range(6)]
         ).reshape([6, 1])
         with fluid.dygraph.guard():
-            one_hot_label = fluid.one_hot(
-                input=fluid.dygraph.to_variable(label), depth=depth
+            one_hot_label = paddle.nn.functional.one_hot(
+                x=fluid.dygraph.to_variable(label), num_classes=depth
             )
     def _run(self, depth):
         label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
-        one_hot_label = fluid.one_hot(input=label, depth=depth)
+        one_hot_label = paddle.nn.functional.one_hot(x=label, num_classes=depth)
         place = fluid.XPUPlace(0)
         label_data = np.array(
@@ -196,7 +196,9 @@ class BadInputTestOnehotV2(unittest.TestCase):
                 shape=[4],
                 dtype="float32",
             )
-            one_hot_label = fluid.one_hot(input=label, depth=4)
+            one_hot_label = paddle.nn.functional.one_hot(
+                x=label, num_classes=4
+            )
         self.assertRaises(TypeError, test_bad_x)
...
@@ -72,7 +72,9 @@ class TestSGDOpWithLargeInput(unittest.TestCase):
         label = fluid.layers.fill_constant(
             shape=[1, 150], value=0.5, dtype='float32'
         )
-        emb = fluid.embedding(input=data, size=(10000, 150), dtype='float32')
+        emb = paddle.static.nn.embedding(
+            input=data, size=(10000, 150), dtype='float32'
+        )
         out = paddle.nn.functional.normalize(x=emb, axis=-1)
         cost = paddle.nn.functional.square_error_cost(input=out, label=label)
...
@@ -38,7 +38,7 @@ from .common import prelu  # noqa: F401
 from .common import layer_norm  # noqa: F401
-from ...fluid.input import embedding  # noqa: F401
+from .common import embedding  # noqa: F401
 from ...fluid.contrib.layers import sparse_embedding  # noqa: F401
 from ...fluid.layers import StaticRNN  # noqa: F401
...
@@ -3631,3 +3631,164 @@ def layer_norm(
     )
     return helper.append_activation(layer_norm_out)
@static_only
def embedding(
input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32',
):
r"""
:api_attr: Static Graph
The operator is used to look up the embedding vectors of the ids provided by :attr:`input` .
It automatically constructs a 2D embedding matrix based on the
input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` .
The shape of the output Tensor is generated by appending an emb_size dimension to the
last dimension of the input Tensor shape.
**Note:** The ids in :attr:`input` must satisfy :math:`0 <= id < size[0]` ,
otherwise the program will throw an exception and exit.
.. code-block:: text
Case 1:
input is a Tensor. padding_idx = -1
input.data = [[1, 3], [2, 4], [4, 127]]
input.shape = [3, 2]
Given size = [128, 16]
output is a Tensor:
out.shape = [3, 2, 16]
out.data = [[[0.129435295, 0.244512452, ..., 0.436322452],
[0.345421456, 0.524563927, ..., 0.144534654]],
[[0.345249859, 0.124939536, ..., 0.194353745],
[0.945345345, 0.435394634, ..., 0.435345365]],
[[0.945345345, 0.435394634, ..., 0.435345365],
[0.0, 0.0, ..., 0.0 ]]] # padding data
Since the input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127.
All-zero data is padded whenever an id is 127.
Case 2:
input is a LoDTensor with 1-level LoD. padding_idx = 0
input.lod = [[2, 3]]
input.data = [[1], [3], [2], [4], [0]]
input.shape = [5, 1]
Given size = [128, 16]
output is a LoDTensor:
out.lod = [[2, 3]]
out.shape = [5, 1, 16]
out.data = [[[0.129435295, 0.244512452, ..., 0.436322452]],
[[0.345421456, 0.524563927, ..., 0.144534654]],
[[0.345249859, 0.124939536, ..., 0.194353745]],
[[0.945345345, 0.435394634, ..., 0.435345365]],
[[0.0, 0.0, ..., 0.0 ]]] # padding data
It will pad all-zero data when ids is 0.
Args:
input(Tensor): A Tensor or LoDTensor with type int64, which contains the id information.
The value of the input id should satisfy :math:`0<= id < size[0]` .
size(tuple|list): The shape of lookup table parameter. It should have two elements which
indicates the size of the dictionary of embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update. This parameter only
affects the performance of the backward gradient update. It is recommended to set it to
True because sparse update is faster. However, some optimizers do not support sparse update;
in that case, is_sparse must be False. Default: False.
is_distributed(bool): Whether to store the embedding matrix in a distributed manner. Only used
in multi-machine distributed CPU training. Default: False.
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size).
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
to :math:`vocab\_size + padding\_idx` . All-zero padding data is output whenever the lookup
encounters :math:`padding\_idx` in an id, and the padding data is not updated during training.
If set to None, it has no effect on the output. Default: None.
param_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
default weight parameter property is used. In addition,
user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter.
The local word vector needs to be transformed into numpy format, and the shape of local word
vector should be consistent with :attr:`size` .
dtype(str): It refers to the data type of output Tensor.
It must be float32 or float64. Default: float32.
Returns:
Tensor: Embedding Tensor or LoDTensor mapped by input. The data type is the same as :attr:`dtype` .
Static Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.enable_static()
x = paddle.static.data(name="x", shape = [2, 4], dtype=np.int64)
output = paddle.static.nn.embedding(x, (10, 3),
param_attr=paddle.nn.initializer.Constant(value=1.0))
m_output=paddle.mean(output)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
x = np.array([[7, 2, 4, 5],[4, 3, 2, 9]], dtype=np.int64)
# x is a Numpy.
# x.data = [[7, 2, 4, 5], [4, 3, 2, 9]]
# x.shape = [2, 4]
out, = exe.run(paddle.static.default_main_program(), feed={'x':x}, fetch_list=[output])
# out is a Numpy.
# out.data = [[1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.]],
#
# [[1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.],
# [0., 0., 0.]]]
# out.shape = [2, 4, 3]
"""
helper = LayerHelper('embedding', **locals())
check_variable_and_dtype(input, 'input', ['int64'], 'embedding')
check_dtype(
dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'embedding',
)
remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch:
assert is_sparse is True and is_distributed is False
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False
)
tmp = helper.create_variable_for_type_inference(dtype)
padding_idx = (
-1
if padding_idx is None
else padding_idx
if padding_idx >= 0
else (size[0] + padding_idx)
)
helper.append_op(
type='lookup_table_v2',
inputs={'Ids': input, 'W': w},
outputs={'Out': tmp},
attrs={
'is_sparse': is_sparse,
'is_distributed': is_distributed,
'remote_prefetch': remote_prefetch,
'padding_idx': padding_idx,
},
)
return tmp
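A minimal end-to-end sketch of the new paddle.static.nn.embedding with a negative padding_idx, mirroring Case 1 of the docstring above (names and shapes here are illustrative, not part of this diff):

import numpy as np
import paddle

paddle.enable_static()

ids = paddle.static.data(name="ids", shape=[3, 2], dtype="int64")
# padding_idx=-1 with vocab_size=128 is converted to 127 internally
emb = paddle.static.nn.embedding(
    input=ids, size=[128, 16], padding_idx=-1, dtype="float32"
)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
feed_ids = np.array([[1, 3], [2, 4], [4, 127]], dtype=np.int64)
(out,) = exe.run(
    paddle.static.default_main_program(), feed={"ids": feed_ids}, fetch_list=[emb]
)
# out.shape == (3, 2, 16); out[2, 1] is all zeros because id 127 equals padding_idx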