From ebc5f997894bcb632ea770d121557267f400215e Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Tue, 1 Sep 2020 11:37:35 +0800
Subject: [PATCH] add embedding 2.0 (#26649)

* add embedding 2.0

* add embedding support input int32
---
 paddle/fluid/operators/lookup_table_v2_op.cc  |  13 +-
 paddle/fluid/operators/lookup_table_v2_op.cu  |  73 ++++++--
 paddle/fluid/operators/lookup_table_v2_op.h   | 173 +++++++++---------
 python/paddle/fluid/input.py                  |   1 +
 python/paddle/fluid/layers/nn.py              |   1 +
 .../fluid/tests/unittests/test_adam_op.py     |   2 +-
 .../unittests/test_lookup_table_v2_op.py      |   4 +-
 .../test_nn_functional_embedding_dygraph.py   |  36 ++++
 .../test_nn_functional_embedding_static.py    |  82 +++++++++
 python/paddle/nn/functional/__init__.py       |   1 +
 python/paddle/nn/functional/input.py          | 117 +++++++++++-
 python/paddle/nn/layer/common.py              | 170 ++++++++++++++---
 12 files changed, 548 insertions(+), 125 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_nn_functional_embedding_static.py

diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc
index 122e01f146c..4a6680d76c4 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op.cc
@@ -15,8 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/lookup_table_v2_op.h"
 
 #include
-
 #include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/var_type_inference.h"
 
 namespace paddle {
@@ -196,3 +196,14 @@ REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel,
 REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad, ops::LookupTableV2GradKernel,
                        ops::LookupTableV2GradKernel);
+
+/* ========================== register checkpoint ===========================*/
+REGISTER_OP_VERSION(lookup_table_v2)
+    .AddCheckpoint(
+        R"ROC(fix lookup_table_v2, add input type `int32`)ROC",
+        paddle::framework::compatible::OpVersionDesc()
+            .BugfixWithBehaviorChanged("lookup_table_v2 support input type "
+                                       "`int64`; after support input type "
+                                       "`int32/int64`"));
+
+/* ========================================================================== */
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu
index b3b0f8f1960..551f0d3c641 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cu
+++ b/paddle/fluid/operators/lookup_table_v2_op.cu
@@ -85,6 +85,14 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
   }
 }
+template
+__global__ void InputTypeCovert(const T *in_ids, const int64_t K,
+                                int64_t *out_ids) {
+  for (int i = 0; i < K; i++) {
+    out_ids[i] = (int64_t)(in_ids[i]);
+  }
+}
+
 template
 class LookupTableV2CUDAKernel : public framework::OpKernel {
  public:
@@ -101,23 +109,37 @@ class LookupTableV2CUDAKernel : public framework::OpKernel {
     size_t D = table_t->dims()[1];
     size_t K = ids_t->numel();
-    auto *ids = ids_t->data();
-    auto *table = table_t->data();
-    auto *output = output_t->mutable_data(context.GetPlace());
-
     dim3 threads(256, 4);
     dim3 grids(80, 1);
+    // copy GPU memory to CPU pinned memory
+    framework::Vector ids;
+    ids.resize(K);
+
+    const int64_t *ids_p = nullptr;
+
+    if (ids_t->type() == framework::proto::VarType::INT32) {
+      InputTypeCovert<
+          int><<>>(
+          ids_t->data(), K, ids.MutableData(context.GetPlace()));
+      ids_p = ids.MutableData(context.GetPlace());
+    } else {
+      ids_p =
ids_t->data(); + } + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + if (padding_idx == -1) LookupTableV2< T, 256, 4, 80, false><<>>( - output, table, ids, N, K, D, padding_idx); + output, table, ids_p, N, K, D, padding_idx); else LookupTableV2< T, 256, 4, 80, true><<>>( - output, table, ids, N, K, D, padding_idx); + output, table, ids_p, N, K, D, padding_idx); } }; @@ -139,16 +161,24 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel { auto *ids_data = ids->data(); int64_t ids_num = ids->numel(); - + dim3 threads(128, 8); + dim3 grids(8, 1); auto stream = dev_ctx.stream(); // copy GPU memory to CPU pinned memory framework::Vector new_rows; new_rows.resize(ids_num); auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); - // TODO(yuyang18): Strange code here. - memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()), - gpu_place, ids_data, ids_num * sizeof(int64_t), stream); + if (ids->type() == framework::proto::VarType::INT32) { + InputTypeCovert< + int><<>>( + ids->data(), ids_num, + new_rows.MutableData(context.GetPlace())); + } else { + memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()), + gpu_place, ids_data, ids_num * sizeof(int64_t), stream); + } + d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); @@ -177,17 +207,32 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel { int N = d_table_t->dims()[0]; int D = d_table_t->dims()[1]; int K = ids_t->numel(); - const int64_t *ids = ids_t->data(); + + dim3 threads(128, 8); + dim3 grids(8, 1); + // copy GPU memory to CPU pinned memory + framework::Vector ids; + ids.resize(K); + + const int64_t *ids_p = nullptr; + + if (ids_t->type() == framework::proto::VarType::INT32) { + InputTypeCovert< + int><<>>( + ids_t->data(), K, ids.MutableData(context.GetPlace())); + ids_p = ids.MutableData(context.GetPlace()); + } else { + ids_p = ids_t->data(); + } + const T *d_output = d_output_t->data(); T *d_table = d_table_t->mutable_data(context.GetPlace()); auto t = framework::EigenVector::Flatten(*d_table_t); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); - dim3 threads(128, 8); - dim3 grids(8, 1); LookupTableV2Grad<<>>( - d_table, d_output, ids, N, K, D); + d_table, d_output, ids_p, N, K, D); } } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h index 9aab90d8479..092c5f3b033 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.h +++ b/paddle/fluid/operators/lookup_table_v2_op.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include #include #include @@ -45,84 +46,70 @@ class LookupTableV2Kernel : public framework::OpKernel { auto *output_t = context.Output("Out"); // float tensor auto *table_var = context.InputVar("W"); - auto id_name = context.InputNames("Ids").front(); - auto embedding_name = context.InputNames("W").front(); - auto out_name = context.OutputNames("Out").front(); - - // for remote prefetch - auto epmap = context.Attr>("epmap"); - auto remote_prefetch = context.Attr("remote_prefetch"); - auto table_names = context.Attr>("table_names"); + int64_t padding_idx = context.Attr("padding_idx"); + int64_t ids_numel = ids_t->numel(); - if (remote_prefetch && !epmap.empty()) { -// if epmap is not empty, then the parameter will be fetched from remote -// parameter server + std::vector ids; + ids.reserve(ids_numel); -#ifdef PADDLE_WITH_DISTRIBUTE - operators::distributed::prefetch(id_name, out_name, embedding_name, false, - table_names, epmap, context, - context.scope()); -#else - PADDLE_THROW( - "paddle is not compiled with distribute support, can not do " - "parameter prefetch!"); -#endif + if (ids_t->type() == framework::proto::VarType::INT32) { + std::transform(ids_t->data(), ids_t->data() + ids_numel, + std::back_inserter(ids), + [&](int id) { return static_cast(id); }); } else { - int64_t padding_idx = context.Attr("padding_idx"); - int64_t *ids = const_cast(ids_t->data()); - int64_t ids_numel = ids_t->numel(); - - if (table_var->IsType()) { - auto *table_t = context.Input("W"); - int64_t row_number = table_t->dims()[0]; - int64_t row_width = table_t->dims()[1]; - - auto *table = table_t->data(); - auto *output = output_t->mutable_data(context.GetPlace()); - - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx != kNoPadding && ids[i] == padding_idx) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { - PADDLE_ENFORCE_LT( - ids[i], row_number, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - row_number, ids[i]); - PADDLE_ENFORCE_GE( - ids[i], 0, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - row_number, ids[i]); - memcpy(output + i * row_width, table + ids[i] * row_width, - row_width * sizeof(T)); - } + framework::TensorToVector(*ids_t, &ids); + } + + if (table_var->IsType()) { + auto *table_t = context.Input("W"); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_LT( + ids[i], row_number, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, ids[i]); + PADDLE_ENFORCE_GE( + ids[i], 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. 
Please check input " + "value.", + row_number, ids[i]); + memcpy(output + i * row_width, table + ids[i] * row_width, + row_width * sizeof(T)); } - } else if (table_var->IsType()) { - const auto &table_t = table_var->Get(); - int64_t row_width = table_t.value().dims()[1]; - const auto *table = table_t.value().data(); - auto *output = output_t->mutable_data(context.GetPlace()); - - auto blas = math::GetBlas(context); - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx != kNoPadding && ids[i] == padding_idx) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { - PADDLE_ENFORCE_GE( - ids[i], 0, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0. But received %ld", - ids[i]); - auto id_index = table_t.Index(ids[i]); - PADDLE_ENFORCE_GE( - id_index, 0, "the input key should be exists. But received %d.", - id_index); - blas.VCOPY(row_width, table + id_index * row_width, - output + i * row_width); - } + } + } else if (table_var->IsType()) { + const auto &table_t = table_var->Get(); + int64_t row_width = table_t.value().dims()[1]; + const auto *table = table_t.value().data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_GE( + ids[i], 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0. But received %ld", + ids[i]); + auto id_index = table_t.Index(ids[i]); + PADDLE_ENFORCE_GE(id_index, 0, + "the input key should be exists. But received %d.", + id_index); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); } } } @@ -151,17 +138,23 @@ class LookupTableV2GradKernel : public framework::OpKernel { // Since paddings are not trainable and fixed in forward, the gradient of // paddings makes no sense and we don't deal with it in backward. 
if (is_sparse) { - auto *ids = context.Input("Ids"); + auto *ids_t = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); auto *d_table = context.Output(framework::GradVarName("W")); + int64_t ids_num = ids_t->numel(); + + std::vector ids; + ids.reserve(ids_num); - auto *ids_data = ids->data(); - int64_t ids_num = ids->numel(); + if (ids_t->type() == framework::proto::VarType::INT32) { + std::transform(ids_t->data(), ids_t->data() + ids_num, + std::back_inserter(ids), + [&](int id) { return static_cast(id); }); + } else { + framework::TensorToVector(*ids_t, &ids); + } - std::vector new_rows; - new_rows.resize(ids_num); - std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); - d_table->set_rows(new_rows); + d_table->set_rows(ids); auto *d_table_value = d_table->mutable_value(); d_table_value->Resize({ids_num, table_dim[1]}); @@ -185,11 +178,23 @@ class LookupTableV2GradKernel : public framework::OpKernel { memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); } else { - auto *ids = context.Input("Ids"); + auto *ids_t = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); auto *d_table = context.Output(framework::GradVarName("W")); + int64_t ids_num = ids_t->numel(); + + std::vector ids; + ids.reserve(ids_num); + + if (ids_t->type() == framework::proto::VarType::INT32) { + std::transform(ids_t->data(), ids_t->data() + ids_num, + std::back_inserter(ids), + [&](int id) { return static_cast(id); }); + } else { + framework::TensorToVector(*ids_t, &ids); + } - auto *ids_data = ids->data(); + auto *ids_data = ids.data(); int64_t N = table_dim[0]; int64_t D = table_dim[1]; @@ -199,7 +204,7 @@ class LookupTableV2GradKernel : public framework::OpKernel { memset(d_table_data, 0, d_table->numel() * sizeof(T)); - for (int64_t i = 0; i < ids->numel(); ++i) { + for (int64_t i = 0; i < ids_num; ++i) { if (padding_idx != kNoPadding && ids_data[i] == padding_idx) { // the gradient of padding_idx should be 0, already done by memset, so // do nothing. 
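Summary of the C++ changes above: lookup_table_v2 and its gradient op now accept an Ids tensor of type int32 in addition to int64. The CUDA kernels convert int32 ids on the device via InputTypeCovert, the CPU kernels copy them into an int64 vector, and an op-version checkpoint records the behavior change. The sketch below is illustrative only; it assumes a Paddle 2.0 dygraph session, uses the nn.Embedding layer added later in this patch, and the paddle.to_tensor call is an assumption that is not part of this diff.

    import numpy as np
    import paddle

    paddle.disable_static()

    # int32 ids are now accepted by the lookup_table_v2 kernels;
    # before this patch only int64 ids were supported.
    ids = paddle.to_tensor(np.array([[1, 3], [2, 4]]).astype('int32'))

    # vocabulary of 17 entries, embedding size 31 (matches the unit test above)
    emb = paddle.nn.Embedding(17, 31)
    out = emb(ids)  # shape [2, 2, 31], dtype follows the embedding weight
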
diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py index 15a3022f932..529588c0846 100644 --- a/python/paddle/fluid/input.py +++ b/python/paddle/fluid/input.py @@ -129,6 +129,7 @@ def one_hot(input, depth, allow_out_of_range=False): return one_hot_out +@deprecated(since='2.0.0', update_to='paddle.nn.functional.embedding') def embedding(input, size, is_sparse=False, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 19c46fd21b1..e77f58d31f7 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -367,6 +367,7 @@ def fc(input, return helper.append_activation(pre_activation) +@deprecated(since="2.0.0", update_to="paddle.nn.functional.embedding") def embedding(input, size, is_sparse=False, diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index b145c8a6fb3..14e83fccd65 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -450,7 +450,7 @@ class TestAdamOpV2(unittest.TestCase): import paddle paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) state_dict = adam.state_dict() diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 98d8b7f9f88..44a653521a9 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -59,7 +59,7 @@ class TestLookupTableOpWithTensorIds(OpTest): def setUp(self): self.op_type = "lookup_table_v2" table = np.random.random((17, 31)).astype("float64") - ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype("int64") + ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype("int32") self.inputs = {'W': table, 'Ids': ids} self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))} @@ -100,7 +100,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): class TestLookupTableWIsSelectedRows(unittest.TestCase): def prepare_ids(self, scope, place): ids_tensor = scope.var('Ids').get_tensor() - ids_array = np.array([0, 4, 3, 5]).astype("int64") + ids_array = np.array([0, 4, 3, 5]).astype("int32") ids_tensor.set(ids_array, place) return ids_array diff --git a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py new file mode 100644 index 00000000000..e0edf901935 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py @@ -0,0 +1,36 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest + + +class EmbeddingDygraph(unittest.TestCase): + def test_1(self): + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + # example 1 + inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64') + inp_word.shape # [2, 3] + dict_size = 20 + + emb = nn.Embedding(dict_size, 32, weight_attr='emb.w', sparse=False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_static.py b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_static.py new file mode 100644 index 00000000000..c9c91ceb39d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_static.py @@ -0,0 +1,82 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle.nn.functional as functional + + +class EmbeddingStatic(unittest.TestCase): + def test_1(self): + prog = fluid.Program() + with fluid.program_guard(prog): + + def test_bad_x(): + initializer = fluid.initializer.NumpyArrayInitializer( + np.random.random(size=(128, 100))) + + param_attr = fluid.ParamAttr( + name="emb_weight", + learning_rate=0.5, + initializer=initializer, + trainable=True) + + weight = prog.global_block().create_parameter( + (128, 100), attr=param_attr, dtype="float32") + + label = fluid.layers.data( + name="label", + shape=[4], + append_batch_size=False, + dtype="int64") + + emb = functional.embedding( + x=label, weight=weight, sparse=True, name="embedding") + + test_bad_x() + + def test_2(self): + prog = fluid.Program() + with fluid.program_guard(prog): + + def test_bad_x(): + initializer = fluid.initializer.NumpyArrayInitializer( + np.random.random(size=(128, 100))) + + param_attr = fluid.ParamAttr( + name="emb_weight", + learning_rate=0.5, + initializer=initializer, + trainable=True) + + weight = prog.global_block().create_parameter( + (128, 100), attr=param_attr, dtype="float32") + + label = fluid.layers.data( + name="label", + shape=[4], + append_batch_size=False, + dtype="int32") + + emb = functional.embedding( + x=label, weight=weight, sparse=True, name="embedding") + + test_bad_x() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 3c0aa9c5c99..325eaa64d5c 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -233,3 +233,4 @@ from .vision import space_to_depth #DEFINE_ALIAS from .vision import yolo_box #DEFINE_ALIAS from .vision import yolov3_loss #DEFINE_ALIAS from .input import one_hot #DEFINE_ALIAS +from .input import embedding #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index e77bf0e3967..bc48cc21c29 100644 --- a/python/paddle/nn/functional/input.py +++ 
b/python/paddle/nn/functional/input.py @@ -19,7 +19,7 @@ from ...fluid.layer_helper import LayerHelper from ...fluid.layers import core from ...fluid.data_feeder import check_variable_and_dtype, check_dtype -__all__ = ['one_hot'] +__all__ = ['one_hot', 'embedding'] def one_hot(x, num_classes, name=None): @@ -83,6 +83,7 @@ def one_hot(x, num_classes, name=None): # [0., 1., 0., 0.], # [0., 0., 0., 1.], # [1., 0., 0., 0.]] + """ if in_dygraph_mode(): @@ -94,7 +95,7 @@ def one_hot(x, num_classes, name=None): one_hot_out = helper.create_variable_for_type_inference(dtype='float32') if not isinstance(num_classes, Variable): - # user attribute + # user attribute inputs = {'X': x} attrs = {'depth': num_classes, 'allow_out_of_range': False} else: @@ -108,3 +109,115 @@ def one_hot(x, num_classes, name=None): outputs={'Out': one_hot_out}, stop_gradient=True) return one_hot_out + + +def embedding(x, weight, padding_idx=None, sparse=False, name=None): + """ + The operator is used to lookup embeddings vector of ids provided by :attr:`input` . + + The shape of output Tensor is generated by appending the last dimension of the input Tensor shape + with embedding size. + **Note:** The id in :attr:`input` must satisfy :math:`0 =< id < weight.shape[0]` , + otherwise the program will throw an exception and exit. + + .. code-block:: text + + Case 1: + input is a Tensor. + padding_idx = -1 + x.data = [[1, 3], [2, 4], [4, 127]] + x.shape = [3, 2] + weight.shape = [128, 16] + output is a Tensor: + out.shape = [3, 2, 16] + out.data = [[[0.129435295, 0.244512452, ..., 0.436322452], + [0.345421456, 0.524563927, ..., 0.144534654]], + [[0.345249859, 0.124939536, ..., 0.194353745], + [0.945345345, 0.435394634, ..., 0.435345365]], + [[0.945345345, 0.435394634, ..., 0.435345365], + [0.0, 0.0, ..., 0.0 ]]] # padding data + + The input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127 + It will pad all-zero data when ids is 127. + + Args: + x(Tensor): A Tensor with type int32/int64, which contains the id information. The value of the input id should + satisfy :math:`0<= id < weight.shape[0]` . + weight (Tensor): The weight. A Tensor with shape of lookup table parameter. It should have two elements which + indicates the size of the dictionary of embeddings and the size of each embedding vector respectively. + sparse(bool): The flag indicating whether to use sparse update. This parameter only + affects the performance of the backwards gradient update. It is recommended to set + True because sparse update is faster. But some optimizers does not support sparse update, + such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` , + :ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` , + :ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` . + In these cases, is_sparse must be False. Default: False. + padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size). + If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted + to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup + encounters :math:`padding\_idx` in id. And the padding data will not be updated while training. + If set None, it makes no effect to output. Default: None. + name(str|None): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. 
+ + Returns: + Tensor: Embedding Tensor mapped by input. The data type is the same as :attr:`weight`. + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn as nn + + weight = prog.global_block().create_parameter( + attr=self._param_attr, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0)) + + prog = paddle.static.Program() + + weight = prog.global_block().create_parameter( + (128, 100), dtype="float32", default_initializer=Constant(1.0)) + + label = paddle.data( + name="label", + shape=[4], + append_batch_size=False, + dtype="int64") + + emb = nn.embedding( + x=label, weight=weight, sparse=True, name="embedding") + + """ + if in_dygraph_mode(): + return core.ops.lookup_table_v2( + weight, x, 'is_sparse', sparse, 'is_distributed', False, + 'remote_prefetch', False, 'padding_idx', padding_idx) + else: + helper = LayerHelper('embedding', **locals()) + dtype = helper.input_dtype() + + check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'embedding') + + is_distributed = False + remote_prefetch = sparse and (not is_distributed) + + tmp = helper.create_variable_for_type_inference(dtype) + padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( + weight.shape[0] + padding_idx) + + helper.append_op( + type='lookup_table_v2', + inputs={'Ids': x, + 'W': weight}, + outputs={'Out': tmp}, + attrs={ + 'is_sparse': sparse, + 'is_distributed': is_distributed, + 'remote_prefetch': remote_prefetch, + 'padding_idx': padding_idx + }) + return tmp diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 8641e28e37b..d8e1d03b028 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -15,7 +15,7 @@ # TODO: define the common classes to build a neural network from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS from ...fluid.dygraph import Pool2D #DEFINE_ALIAS -from ...fluid.dygraph import Embedding #DEFINE_ALIAS +from ...fluid.dygraph import Linear #DEFINE_ALIAS from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers from .. import functional as F @@ -146,9 +146,9 @@ class UpSample(layers.Layer): 'nearest' : Nearest neighbor interpolation 'bicubic' : Bicubic interpolation - Linear interpolation is the method of using a line connecting two known quantities - to determine the value of an unknown quantity between the two known quantities. - + Linear interpolation is the method of using a line connecting two known quantities + to determine the value of an unknown quantity between the two known quantities. + Nearest neighbor interpolation is to perform nearest neighbor interpolation in both the 3rd dimension(in height direction) and the 4th dimension(in width direction) on input tensor. @@ -158,7 +158,7 @@ class UpSample(layers.Layer): W-direction in this op) on a rectilinear 2D grid. The key idea is to perform linear interpolation first in one direction, and then again in the other direction. - + Bicubic interpolation is an extension of cubic interpolation for interpolating data points on a two-dimensional regular grid. 
The interpolated surface is smoother than corresponding surfaces obtained by bilinear interpolation or @@ -205,7 +205,7 @@ class UpSample(layers.Layer): output: (N,C,H_out,W_out) where: H_out = round(H_{in} * scale_{factor}) W_out = round(W_{in} * scale_{factor}) - + Bilinear interpolation: if: align_corners = False , align_mode = 0 @@ -252,19 +252,19 @@ class UpSample(layers.Layer): https://en.wikipedia.org/wiki/Linear_interpolation. For details of linear interpolation, please refer to Wikipedia: - + For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. - + For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation. - + For details of bicubic interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bicubic_interpolation - + For details of trilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Trilinear_interpolation. - + Parameters: x (Tensor): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8, its data format is specified by :attr:`data_format`. @@ -537,8 +537,8 @@ class Pad2D(layers.Layer): If mode is 'reflect', paddings[0] and paddings[1] must be no greater than height-1. And the width dimension has the same condition. Parameters: - paddings (int | List[int32]): The padding size. If padding is a int, uses the same - padding in all boundaries, if padding is a List, it must contain four integers, + paddings (int | List[int32]): The padding size. If padding is a int, uses the same + padding in all boundaries, if padding is a List, it must contain four integers, (padding_top, padding_bottom, padding_left, padding_right). Default is [0, 0, 0, 0]. mode (str): Three modes: 'constant' (default), 'reflect', 'edge' . @@ -550,7 +550,7 @@ class Pad2D(layers.Layer): data_format (str): An string from: "NHWC", "NCHW". Specify the data format of the input data. Default is "NCHW" - Returns: + Returns: None Examples: .. code-block:: text @@ -631,11 +631,11 @@ class Bilinear(layers.Layer): in1_features (int): The dimension of each first input(`x1`). in2_features (int): The dimension of each second input(`x2`). out_features (int): The dimension of output of this layer. - weight_attr (ParamAttr, optional): The parameter attribute for the learnable w, parameters/weights of + weight_attr (ParamAttr, optional): The parameter attribute for the learnable w, parameters/weights of this layer. The default value is None. bias_attr (ParamAttr, optional): The parameter attribute for the bias of this layer. If it is set to False, no bias will be added to the output units. - If it is set to None, the bias is initialized zero. The default value is None. + If it is set to None, the bias is initialized zero. The default value is None. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. @@ -702,7 +702,7 @@ class Dropout(layers.Layer): """ Dropout is a regularization technique for reducing overfitting by preventing neuron co-adaption during training as described in the paper: - `Improving neural networks by preventing co-adaptation of feature detectors `_ + `Improving neural networks by preventing co-adaptation of feature detectors `_ The dropout operator randomly sets the outputs of some units to zero, while upscale others according to the given dropout probability. 
@@ -771,8 +771,8 @@ class Dropout2d(layers.Layer): Randomly zero out entire channels (in the batched input 4d tensor with the shape `NCHW` , a channel is a 2D feature map with the shape `HW`). Each channel will be zeroed out independently on every forward call with probability `p` using samples from a Bernoulli distribution. - Dropout2d will help promote independence between feature maps as described in the paper: - `Efficient Object Localization Using Convolutional Networks `_ + Dropout2d will help promote independence between feature maps as described in the paper: + `Efficient Object Localization Using Convolutional Networks `_ See ``paddle.nn.functional.dropout2d`` for more details. @@ -829,8 +829,8 @@ class Dropout3d(layers.Layer): Randomly zero out entire channels (in the batched input 5d tensor with the shape `NCDHW` , a channel is a 3D feature map with the shape `DHW` ). Each channel will be zeroed out independently on every forward call with probability `p` using samples from a Bernoulli distribution. - Dropout3d will help promote independence between feature maps as described in the paper: - `Efficient Object Localization Using Convolutional Networks `_ + Dropout3d will help promote independence between feature maps as described in the paper: + `Efficient Object Localization Using Convolutional Networks `_ See ``paddle.nn.functional.dropout3d`` for more details. @@ -1547,3 +1547,131 @@ class CosineSimilarity(layers.Layer): def forward(self, x1, x2): return F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps) + + +class Embedding(layers.Layer): + """ + :alias_main: paddle.nn.Embedding + :alias: paddle.nn.Embedding,paddle.nn.layer.Embedding,paddle.nn.layer.common.Embedding + :old_api: paddle.fluid.dygraph.Embedding + + **Embedding Layer** + + This interface is used to construct a callable object of the ``Embedding`` class. + For specific usage, refer to code examples. It implements the function of the Embedding Layer. + This layer is used to lookup embeddings vector of ids provided by :attr:`input` . + It automatically constructs a 2D embedding matrix based on the + input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` . + + The shape of output Tensor is generated by appending an emb_size dimension to the + last dimension of the input Tensor shape. + + **Note:** The id in :attr:`input` must satisfy :math:`0 =< id < size[0]` , + otherwise the program will throw an exception and exit. + + .. code-block:: text + + Case 1: + + input is a Tensor. padding_idx = -1 + input.data = [[1, 3], [2, 4], [4, 127] + input.shape = [3, 2] + Given size = [128, 16] + output is a Tensor: + out.shape = [3, 2, 16] + out.data = [[[0.129435295, 0.244512452, ..., 0.436322452], + [0.345421456, 0.524563927, ..., 0.144534654]], + + [[0.345249859, 0.124939536, ..., 0.194353745], + [0.945345345, 0.435394634, ..., 0.435345365]], + + [[0.945345345, 0.435394634, ..., 0.435345365], + [0.0, 0.0, ..., 0.0 ]]] # padding data + The input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127 + It will pad all-zero data when ids is 127. + + Parameters: + num_embeddings (int): Just one element which indicate the size + of the dictionary of embeddings. + embedding_dim: Just one element which indicate the size of each embedding vector respectively. + padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size). + If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted + to :math:`vocab\_size + padding\_idx` . 
It will output all-zero padding data whenever lookup
+            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
+            If set None, it makes no effect to output. Default: None.
+        sparse(bool): The flag indicating whether to use sparse update. This parameter only
+            affects the performance of the backwards gradient update. It is recommended to set
+            True because sparse update is faster. But some optimizer does not support sparse update,
+            such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
+            :ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
+            :ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
+            In these case, is_sparse must be False. Default: False.
+        weight_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
+            default weight parameter property is used. See usage for details in :ref:`api_fluid_ParamAttr` . In addition,
+            user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter.
+            The local word vector needs to be transformed into numpy format, and the shape of local word
+            vector should be consistent with :attr:`size` . Then :ref:`api_fluid_initializer_NumpyArrayInitializer`
+            is used to load custom or pre-trained word vectors. See code example 2 for details.
+        name(str|None): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually name is no need to set and
+            None by default.
+
+    Attribute:
+        **weight** (Parameter): the learnable weights of this layer.
+
+    Returns:
+        None
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn as nn
+            import numpy as np
+            paddle.disable_static()
+
+            # example 1
+            inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64')
+            inp_word.shape  # [2, 3]
+            dict_size = 20
+
+            emb = nn.Embedding(
+                dict_size,
+                32,
+                sparse=False)
+    """
+
+    def __init__(self,
+                 num_embeddings,
+                 embedding_dim,
+                 padding_idx=None,
+                 sparse=False,
+                 weight_attr=None,
+                 name=None):
+        super(Embedding, self).__init__()
+        self._num_embeddings = num_embeddings
+        self._embedding_dim = embedding_dim
+        self._sparse = sparse
+        self._is_distributed = False
+        self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
+            num_embeddings + padding_idx)
+        self._dtype = self._helper.get_default_dtype()
+        self._size = [self._num_embeddings, self._embedding_dim]
+
+        self._weight_attr = weight_attr
+        self._remote_prefetch = False
+        self._name = name
+        self._weight = self.create_parameter(
+            attr=self._weight_attr,
+            shape=self._size,
+            dtype=self._dtype,
+            is_bias=False)
+
+    def forward(self, x):
+        return F.embedding(
+            x,
+            weight=self._weight,
+            padding_idx=self._padding_idx,
+            sparse=self._sparse,
+            name=self._name)
--
GitLab
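Taken together, the Python-side changes add two user-facing entry points: a functional form, paddle.nn.functional.embedding(x, weight, padding_idx=None, sparse=False, name=None), which looks ids up in an existing weight tensor, and a layer form, paddle.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, sparse=False, weight_attr=None, name=None), which creates its own weight and forwards to the functional form. The sketch below is based only on those signatures; the paddle.to_tensor call and the _weight attribute access are assumptions for illustration (_weight is the parameter name used by the layer in this patch).

    import numpy as np
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    paddle.disable_static()

    ids = paddle.to_tensor(np.array([[2, 3, 5], [4, 2, 1]]).astype('int64'))

    # Layer form: creates a [20, 32] weight and looks the ids up in it.
    emb = nn.Embedding(20, 32, sparse=False)
    out_layer = emb(ids)  # shape [2, 3, 32]

    # Functional form: the same lookup against an explicitly supplied weight.
    out_func = F.embedding(ids, weight=emb._weight, sparse=False)

Passing the weight tensor directly to the functional form, rather than a size plus param_attr as in fluid.embedding, is what lets pre-trained embedding tables be plugged in without going through an initializer.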