diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index dbde9aa24ff02474a5f231e7f5d556d4af6e8836..20059e82a720ae4130d2963ecdff03fc37c5551d 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -27,6 +27,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+extern std::string DataTypeToString(const proto::VarType::Type type);
+extern size_t SizeOfType(proto::VarType::Type type);
+
 template <typename T>
 struct IsComplex : public std::false_type {};
 
@@ -63,6 +66,13 @@ struct DataTypeTrait {
   _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
                           COMPLEX128);
 
+#define _ForEachIntDataType_(callback)               \
+  _ForEachDataTypeHelper_(callback, int, INT32);     \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64); \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16); \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8);
+
 #define _ForEachDataTypeSmall_(callback)          \
   _ForEachDataTypeHelper_(callback, float, FP32); \
   _ForEachDataTypeHelper_(callback, double, FP64); \
@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
 #undef VisitDataTypeCallbackSmall
 }
 
+template <typename Visitor>
+inline void VisitIntDataType(proto::VarType::Type type, Visitor visitor) {
+#define VisitIntDataTypeCallback(cpp_type, proto_type) \
+  do {                                                 \
+    if (type == proto_type) {                          \
+      visitor.template apply<cpp_type>();              \
+      return;                                          \
+    }                                                  \
+  } while (0)
+
+  _ForEachIntDataType_(VisitIntDataTypeCallback);
+
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "Expected integral data type, but got %s", DataTypeToString(type)));
+
+#undef VisitIntDataTypeCallback
+}
+
 template <typename Visitor>
 inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
 #define VisitDataTypeCallbackTiny(cpp_type, proto_type) \
@@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
 #undef VisitDataTypeCallbackHIP
 }
 
-extern std::string DataTypeToString(const proto::VarType::Type type);
-extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
                                 const proto::VarType::Type& type) {
   out << DataTypeToString(type);
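Note: VisitIntDataType is the runtime-to-compile-time bridge the kernels below rely on. _ForEachIntDataType_ expands the callback into one equality test per integral proto type, and the first match invokes the visitor's templated apply<cpp_type>(). A self-contained sketch of that dispatch pattern (plain C++; the DType enum and PrintSize functor are invented here for illustration, not Paddle code):

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>

    enum class DType { INT32, INT64 };

    // Analogue of VisitIntDataType: map a runtime type tag to a
    // compile-time template instantiation of the visitor.
    template <typename Visitor>
    void VisitIntDType(DType type, Visitor visitor) {
      switch (type) {
        case DType::INT32: visitor.template apply<int32_t>(); return;
        case DType::INT64: visitor.template apply<int64_t>(); return;
      }
      throw std::runtime_error("expected integral data type");
    }

    struct PrintSize {
      template <typename T>
      void apply() { std::cout << sizeof(T) << "\n"; }
    };

    int main() {
      VisitIntDType(DType::INT32, PrintSize{});  // prints 4
      VisitIntDType(DType::INT64, PrintSize{});  // prints 8
    }

This is why the functors added in this patch expose a templated apply<IdT>() instead of putting the id type on the OpKernel itself: the op's registered value type T stays a class template parameter, while the id type is chosen per call from the tensor's runtime dtype.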
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu
index 74ad0e4978b4ec6b3aa5553fc0a6202286ea6ffd..d182385cce9d2779f0d34a93ca8298ac0d26ca5b 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cu
+++ b/paddle/fluid/operators/lookup_table_v2_op.cu
@@ -21,16 +21,16 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
-          bool PaddingFlag>
-__global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
+template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX,
+          bool PaddingFlag>
+__global__ void LookupTableV2(T *output, const T *table, const IdT *ids,
                               const int64_t N, const int64_t K, const int64_t D,
                               const int64_t padding_idx) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;
 
   while (idy < K) {
-    int64_t id = ids[idy];
+    auto id = static_cast<int64_t>(ids[idy]);
     T *out = output + idy * D;
     const T *tab = table + id * D;
     for (int i = idx; i < D; i += BlockDimX) {
@@ -47,15 +47,15 @@ __global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
   }
 }
 
-template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
-__global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
+template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void LookupTableV2Grad(T *table, const T *output, const IdT *ids,
                                   const int64_t N, const int64_t K,
                                   const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;
 
   while (idy < K) {
-    int64_t id = ids[idy];
+    auto id = static_cast<int64_t>(ids[idy]);
     const T *out = output + idy * D;
     T *tab = table + id * D;
     for (int i = idx; i < D; i += BlockDimX) {
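Note: both kernels now take the id type as a second template parameter IdT and widen each id with static_cast<int64_t> before computing the table offset, so the pointer arithmetic stays 64-bit even for int8/uint8/int16 ids. The PaddingFlag non-type parameter is the usual trick of turning a per-element runtime check into two compiled kernel variants. A minimal host-side analogue of that idiom (illustrative only, not Paddle code):

    #include <cstdint>

    // Two instantiations are generated; in CopyRow<T, false> the padding
    // comparison is dead code and the compiler drops it from the loop.
    template <typename T, bool PaddingFlag>
    void CopyRow(T *out, const T *tab, int64_t D, int64_t id,
                 int64_t padding_idx) {
      for (int64_t i = 0; i < D; ++i) {
        out[i] = (PaddingFlag && id == padding_idx) ? static_cast<T>(0)
                                                    : tab[i];
      }
    }

    // The caller branches once, outside the hot loop, exactly as the CUDA
    // functor further down does with its padding_idx == -1 test.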
@@ -66,123 +66,107 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
 }
 
 template <typename T>
-__global__ void InputTypeCovert(const T *in_ids, const int64_t K,
-                                int64_t *out_ids) {
-  for (int i = 0; i < K; i++) {
-    out_ids[i] = (int64_t)(in_ids[i]);
-  }
-}
-
-template <typename T>
-class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto *table_t = context.Input<LoDTensor>("W");
-    auto *ids_t = context.Input<LoDTensor>("Ids");
-    auto *output_t = context.Output<LoDTensor>("Out");
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+struct LookupTableV2CUDAFunctor {
+  LookupTableV2CUDAFunctor(const framework::ExecutionContext &context,
+                           const framework::Tensor *ids_t)
+      : context_(context), ids_t_(ids_t) {}
 
-    auto id_name = context.InputNames("Ids").front();
-    auto out_name = context.OutputNames("Out").front();
+  template <typename IdT>
+  void apply() {
+    auto *table_t = context_.Input<framework::LoDTensor>("W");
+    auto *output_t = context_.Output<framework::LoDTensor>("Out");
+    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
 
     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
-    size_t K = ids_t->numel();
+    size_t K = ids_t_->numel();
 
     dim3 threads(256, 4);
     dim3 grids(80, 1);
 
-    // copy GPU memory to CPU pinned memory
-    framework::Vector<int64_t> ids;
-    ids.resize(K);
+    const auto *table = table_t->template data<T>();
+    const auto *ids = ids_t_->template data<IdT>();
+    auto *output = output_t->template mutable_data<T>(context_.GetPlace());
+    auto stream = context_.cuda_device_context().stream();
 
-    const int64_t *ids_p = nullptr;
-
-    if (ids_t->type() == framework::proto::VarType::INT32) {
-      InputTypeCovert<
-          int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
-      ids_p = ids.MutableData(context.GetPlace());
+    if (padding_idx == -1) {
+      LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
+          output, table, ids, N, K, D, padding_idx);
     } else {
-      ids_p = ids_t->data<int64_t>();
+      LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
+          output, table, ids, N, K, D, padding_idx);
     }
-
-    for (int64_t i = 0; i < K; ++i) {
-      PADDLE_ENFORCE_GE(
-          ids[i], 0,
-          platform::errors::InvalidArgument(
-              "Variable value (input) of OP(paddle.nn.embedding) "
-              "expected >= 0 and < %ld, but got %ld. Please check input value.",
-              N, ids[i]));
-      PADDLE_ENFORCE_LT(
-          ids[i], N,
-          platform::errors::InvalidArgument(
-              "Variable value (input) of OP(paddle.nn.embedding) "
-              "expected >= 0 and < %ld, but got %ld. Please check input value.",
-              N, ids[i]));
-    }
-
-    auto *table = table_t->data<T>();
-    auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-    if (padding_idx == -1)
-      LookupTableV2<
-          T, 256, 4, 80,
-          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          output, table, ids_p, N, K, D, padding_idx);
-    else
-      LookupTableV2<
-          T, 256, 4, 80,
-          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          output, table, ids_p, N, K, D, padding_idx);
   }
+
+ private:
+  const framework::ExecutionContext &context_;
+  const framework::Tensor *ids_t_;
 };
 
 template <typename T>
-class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
+class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
+    const auto *ids_t = context.Input<framework::LoDTensor>("Ids");
+    LookupTableV2CUDAFunctor<T> functor(context, ids_t);
+    framework::VisitIntDataType(ids_t->type(), functor);
+  }
+};
+
+template <typename InT, typename OutT>
+__global__ void InputTypeConvert(const InT *in_ids, const int64_t K,
+                                 OutT *out_ids) {
+  for (int i = 0; i < K; i++) {
+    out_ids[i] = static_cast<OutT>(in_ids[i]);
+  }
+}
+
+template <typename T>
+struct LookupTableV2GradCUDAFunctor {
+  LookupTableV2GradCUDAFunctor(const framework::ExecutionContext &context,
+                               const framework::Tensor *ids_t)
+      : context_(context), ids_t_(ids_t) {}
+
+  template <typename IdT>
+  void apply() {
     auto &dev_ctx =
-        context.template device_context<platform::CUDADeviceContext>();
-    bool is_sparse = context.Attr<bool>("is_sparse");
+        context_.template device_context<platform::CUDADeviceContext>();
+    bool is_sparse = context_.Attr<bool>("is_sparse");
 
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
-      auto *ids = context.Input<LoDTensor>("Ids");
-      auto *table = context.Input<LoDTensor>("W");
-      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *table = context_.Input<framework::LoDTensor>("W");
+      auto *d_output =
+          context_.Input<framework::LoDTensor>(framework::GradVarName("Out"));
       auto *d_table =
-          context.Output<SelectedRows>(framework::GradVarName("W"));
+          context_.Output<SelectedRows>(framework::GradVarName("W"));
 
-      auto *ids_data = ids->data<int64_t>();
-      int64_t ids_num = ids->numel();
+      const auto *ids_data = ids_t_->template data<IdT>();
+      int64_t ids_num = ids_t_->numel();
 
       dim3 threads(128, 8);
       dim3 grids(8, 1);
       auto stream = dev_ctx.stream();
 
-      // copy GPU memory to CPU pinned memory
       framework::Vector<int64_t> new_rows;
       new_rows.resize(ids_num);
-      auto gpu_place = context.GetPlace();
+      auto gpu_place = context_.GetPlace();
 
-      if (ids->type() == framework::proto::VarType::INT32) {
-        InputTypeCovert<
-            int><<<grids, threads, 0, stream>>>(
-            ids->data<int>(), ids_num,
-            new_rows.MutableData(context.GetPlace()));
+      if (!std::is_same<IdT, int64_t>::value) {
+        InputTypeConvert<<<grids, threads, 0, stream>>>(
+            ids_data, ids_num, new_rows.MutableData(gpu_place));
       } else {
-        memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
-                     gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
+        memory::Copy(gpu_place, new_rows.CUDAMutableData(gpu_place), gpu_place,
+                     ids_data, ids_num * sizeof(int64_t), stream);
       }
 
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
       d_table_value->Resize({ids_num, table->dims()[1]});
-      d_table_value->mutable_data<T>(context.GetPlace());
+      d_table_value->template mutable_data<T>(gpu_place);
 
-      auto *d_table_data = d_table_value->data<T>();
-      auto *d_output_data = d_output->data<T>();
+      auto *d_table_data = d_table_value->template data<T>();
+      auto *d_output_data = d_output->template data<T>();
       auto d_output_dims = d_output->dims();
       auto d_output_dims_2d =
           framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
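Note: the sparse gradient path always stores its selected rows as framework::Vector<int64_t>, whatever the id dtype, so int64 ids can be copied bitwise while narrower ids must go through the InputTypeConvert kernel; the !std::is_same<IdT, int64_t>::value test resolves that choice per instantiation at compile time. A host-side sketch of the same two-way choice (assumed equivalent in spirit; not Paddle code):

    #include <cstdint>
    #include <cstring>
    #include <type_traits>
    #include <vector>

    template <typename IdT>
    std::vector<int64_t> ToRows(const IdT *ids, int64_t n) {
      std::vector<int64_t> rows(n);
      if (std::is_same<IdT, int64_t>::value) {
        // Identical representation: a bitwise copy suffices.
        std::memcpy(rows.data(), ids, n * sizeof(int64_t));
      } else {
        // Narrower integral ids are widened element by element.
        for (int64_t i = 0; i < n; ++i) rows[i] = static_cast<int64_t>(ids[i]);
      }
      return rows;
    }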
@@ -197,41 +181,43 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
                    d_output->numel() * sizeof(T), stream);
 
     } else {
-      auto ids_t = context.Input<LoDTensor>("Ids");
-      auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
-      auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
+      auto d_output_t =
+          context_.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+      auto d_table_t =
+          context_.Output<framework::LoDTensor>(framework::GradVarName("W"));
 
       int N = d_table_t->dims()[0];
       int D = d_table_t->dims()[1];
-      int K = ids_t->numel();
+      int K = ids_t_->numel();
 
       dim3 threads(128, 8);
       dim3 grids(8, 1);
 
-      // copy GPU memory to CPU pinned memory
-      framework::Vector<int64_t> ids;
-      ids.resize(K);
-
-      const int64_t *ids_p = nullptr;
-
-      if (ids_t->type() == framework::proto::VarType::INT32) {
-        InputTypeCovert<
-            int><<<grids, threads, 0, dev_ctx.stream()>>>(
-            ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
-        ids_p = ids.MutableData(context.GetPlace());
-      } else {
-        ids_p = ids_t->data<int64_t>();
-      }
-
-      const T *d_output = d_output_t->data<T>();
-      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
+      const T *d_output = d_output_t->template data<T>();
+      const auto *ids = ids_t_->template data<IdT>();
+      T *d_table = d_table_t->mutable_data<T>(context_.GetPlace());
 
       auto t = framework::EigenVector<T>::Flatten(*d_table_t);
       t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
 
-      LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
-          d_table, d_output, ids_p, N, K, D);
+      LookupTableV2Grad<T, IdT, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
+          d_table, d_output, ids, N, K, D);
     }
   }
+
+ private:
+  const framework::ExecutionContext &context_;
+  const framework::Tensor *ids_t_;
+};
+
+template <typename T>
+class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    const auto *ids_t = context.Input<framework::LoDTensor>("Ids");
+    LookupTableV2GradCUDAFunctor<T> functor(context, ids_t);
+    framework::VisitIntDataType(ids_t->type(), functor);
+  }
 };
 
 }  // namespace operators
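Note: InputTypeConvert fixes the spelling of the InputTypeCovert kernel it replaces and generalizes it to arbitrary integer in/out types, but it keeps the old execution shape: every launched thread runs the full serial loop over [0, K), so with grids(8, 1) and threads(128, 8) thousands of threads redundantly write the same K values. Harmless, since the writes are identical, but only one thread's work is useful. A hypothetical grid-stride rewrite that actually divides the range (not part of this patch; shown with a 1-D launch for simplicity):

    // Illustrative alternative only; the patch uses the serial form above.
    template <typename InT, typename OutT>
    __global__ void InputTypeConvertStrided(const InT *in_ids, const int64_t K,
                                            OutT *out_ids) {
      int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
      for (; i < K; i += stride) {
        out_ids[i] = static_cast<OutT>(in_ids[i]);
      }
    }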
diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h
index 6ea9e58198fbffff5729ed7799a38f5dfece4b35..bc433b1e10f3e698c82871380a306bdf4593d308 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.h
+++ b/paddle/fluid/operators/lookup_table_v2_op.h
@@ -34,35 +34,44 @@ using DDim = framework::DDim;
 
 constexpr int64_t kNoPadding = -1;
 
+template <typename InT, typename OutT>
+static std::vector<OutT> CopyIdsToVector(const Tensor &ids) {
+  auto numel = ids.numel();
+  const auto *src = ids.data<InT>();
+  std::vector<OutT> ret(numel);
+  if (std::is_same<InT, OutT>::value) {
+    std::memcpy(ret.data(), src, numel * sizeof(InT));
+  } else {
+    for (decltype(numel) i = 0; i < numel; ++i) {
+      ret[i] = src[i];
+    }
+  }
+  return ret;
+}
+
 template <typename T>
-class LookupTableV2Kernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
-    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
-    auto *table_var = context.InputVar("W");
+struct LookupTableV2CPUFunctor {
+  LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
+                          const Tensor *ids_t)
+      : context_(context), ids_t_(ids_t) {}
 
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    int64_t ids_numel = ids_t->numel();
+  template <typename IdT>
+  void apply() {
+    auto *output_t = context_.Output<LoDTensor>("Out");  // float tensor
+    auto *table_var = context_.InputVar("W");
 
-    std::vector<int64_t> ids;
-    ids.reserve(ids_numel);
+    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
 
-    if (ids_t->type() == framework::proto::VarType::INT32) {
-      std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_numel,
-                     std::back_inserter(ids),
-                     [&](int id) { return static_cast<int64_t>(id); });
-    } else {
-      framework::TensorToVector(*ids_t, &ids);
-    }
+    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
+    auto ids_numel = static_cast<int64_t>(ids.size());
 
-    if (table_var->IsType<LoDTensor>()) {
-      auto *table_t = context.Input<LoDTensor>("W");
-      int64_t row_number = table_t->dims()[0];
-      int64_t row_width = table_t->dims()[1];
+    if (table_var->template IsType<LoDTensor>()) {
+      const auto &table_t = table_var->template Get<LoDTensor>();
+      int64_t row_number = table_t.dims()[0];
+      int64_t row_width = table_t.dims()[1];
 
-      auto *table = table_t->data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
+      auto *table = table_t.template data<T>();
+      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
 
       for (int64_t i = 0; i < ids_numel; ++i) {
         if (padding_idx != kNoPadding && ids[i] == padding_idx) {
@@ -86,11 +95,11 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
                  row_width * sizeof(T));
         }
       }
-    } else if (table_var->IsType<SelectedRows>()) {
-      const auto &table_t = table_var->Get<SelectedRows>();
+    } else if (table_var->template IsType<SelectedRows>()) {
+      const auto &table_t = table_var->template Get<SelectedRows>();
       int64_t row_width = table_t.value().dims()[1];
-      const auto *table = table_t.value().data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
+      const auto *table = table_t.value().template data<T>();
+      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
       auto input_data_type = table_t.value().type();
 
       for (int64_t i = 0; i < ids_numel; ++i) {
@@ -114,7 +123,7 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
           memcpy(output + i * row_width, table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
-          auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+          auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context_);
           blas.VCOPY(row_width, table + id_index * row_width,
                      output + i * row_width);
         }
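Note: CopyIdsToVector collapses the old two-branch std::transform / TensorToVector code into one helper: ids of any supported integral dtype are materialized as std::vector<int64_t> up front, and the functor body stays dtype-agnostic from there on. The lookup it feeds copies one table row per id and emits a zero row where the id equals padding_idx. A tiny standalone example of that contract (plain C++, not Paddle code):

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
      const int64_t row_width = 3, padding_idx = 1;
      // 4 x 3 table; row r holds the value r in every column.
      std::vector<float> table = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};
      std::vector<int64_t> ids = {2, 1, 3};
      std::vector<float> out(ids.size() * row_width);
      for (size_t i = 0; i < ids.size(); ++i) {
        if (ids[i] == padding_idx) {
          std::memset(&out[i * row_width], 0, row_width * sizeof(float));
        } else {
          std::memcpy(&out[i * row_width], &table[ids[i] * row_width],
                      row_width * sizeof(float));
        }
      }
      for (float v : out) std::cout << v << ' ';  // 2 2 2 0 0 0 3 3 3
      std::cout << '\n';
    }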
@@ -122,18 +131,36 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
       }
     }
   }
+
+ private:
+  const framework::ExecutionContext &context_;
+  const Tensor *ids_t_;
 };
 
 template <typename T>
-class LookupTableV2GradKernel : public framework::OpKernel<T> {
+class LookupTableV2Kernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto *table_var = context.InputVar("W");
+    const auto *ids = context.Input<LoDTensor>("Ids");
+    LookupTableV2CPUFunctor<T> functor(context, ids);
+    framework::VisitIntDataType(ids->type(), functor);
+  }
+};
+
+template <typename T>
+struct LookupTableV2GradCPUFunctor {
+  LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
+                              const Tensor *ids_t)
+      : context_(context), ids_t_(ids_t) {}
+
+  template <typename IdT>
+  void apply() {
+    auto *table_var = context_.InputVar("W");
     DDim table_dim;
-    if (table_var->IsType<LoDTensor>()) {
-      table_dim = context.Input<LoDTensor>("W")->dims();
-    } else if (table_var->IsType<SelectedRows>()) {
-      auto *table_t = context.Input<SelectedRows>("W");
+    if (table_var->template IsType<LoDTensor>()) {
+      table_dim = context_.Input<LoDTensor>("W")->dims();
+    } else if (table_var->template IsType<SelectedRows>()) {
+      auto *table_t = context_.Input<SelectedRows>("W");
       table_dim = table_t->value().dims();
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
@@ -141,39 +168,30 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
           "must be either LoDTensor or SelectedRows"));
     }
 
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    bool is_sparse = context.Attr<bool>("is_sparse");
+    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
+    bool is_sparse = context_.Attr<bool>("is_sparse");
+
+    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
+    auto ids_num = static_cast<int64_t>(ids.size());
+
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
-      auto *ids_t = context.Input<LoDTensor>("Ids");
-      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
       auto *d_table =
-          context.Output<SelectedRows>(framework::GradVarName("W"));
-      int64_t ids_num = ids_t->numel();
-
-      std::vector<int64_t> ids;
-      ids.reserve(ids_num);
-
-      if (ids_t->type() == framework::proto::VarType::INT32) {
-        std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
-                       std::back_inserter(ids),
-                       [&](int id) { return static_cast<int64_t>(id); });
-      } else {
-        framework::TensorToVector(*ids_t, &ids);
-      }
+          context_.Output<SelectedRows>(framework::GradVarName("W"));
 
       d_table->set_rows(ids);
 
       auto *d_table_value = d_table->mutable_value();
       d_table_value->Resize({ids_num, table_dim[1]});
-      d_table_value->mutable_data<T>(context.GetPlace());
+      d_table_value->template mutable_data<T>(context_.GetPlace());
 
       d_table->set_height(table_dim[0]);
 
-      auto *d_output_data = d_output->data<T>();
-      auto *d_table_data = d_table_value->data<T>();
+      auto *d_output_data = d_output->template data<T>();
+      auto *d_table_data = d_table_value->template data<T>();
 
       auto d_output_dims = d_output->dims();
       auto d_output_dims_2d =
@@ -188,29 +206,16 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
       memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
 
     } else {
-      auto *ids_t = context.Input<LoDTensor>("Ids");
-      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
-      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
-      int64_t ids_num = ids_t->numel();
-
-      std::vector<int64_t> ids;
-      ids.reserve(ids_num);
-
-      if (ids_t->type() == framework::proto::VarType::INT32) {
-        std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
-                       std::back_inserter(ids),
-                       [&](int id) { return static_cast<int64_t>(id); });
-      } else {
-        framework::TensorToVector(*ids_t, &ids);
-      }
-
+      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
       auto *ids_data = ids.data();
 
       int64_t N = table_dim[0];
       int64_t D = table_dim[1];
 
-      auto *d_output_data = d_output->data<T>();
-      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
+      auto *d_output_data = d_output->template data<T>();
+      auto *d_table_data =
+          d_table->template mutable_data<T>(context_.GetPlace());
 
       memset(d_table_data, 0, d_table->numel() * sizeof(T));
 
@@ -240,6 +245,20 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
       }
     }
   }
+
+ private:
+  const framework::ExecutionContext &context_;
+  const Tensor *ids_t_;
+};
+
+template <typename T>
+class LookupTableV2GradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    const auto *ids = context.Input<LoDTensor>("Ids");
+    LookupTableV2GradCPUFunctor<T> functor(context, ids);
+    framework::VisitIntDataType(ids->type(), functor);
+  }
 };
 
 }  // namespace operators
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 574a56f3bcecb9e34a0b0fd08701e25fc62ae87c..418b80c6ee81620ac0beb94839f869e3334626f5 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -1652,7 +1652,9 @@ class Embedding(layers.Layer):
                 'is_distributed', self._is_distributed, 'remote_prefetch',
                 self._remote_prefetch, 'padding_idx', self._padding_idx)
 
-        check_variable_and_dtype(input, 'input', ['int64'], 'Embedding')
+        check_variable_and_dtype(input, 'input',
+                                 ['uint8', 'int8', 'int16', 'int32', 'int64'],
+                                 'Embedding')
         attrs = {
             'is_sparse': self._is_sparse,
             'is_distributed': self._is_distributed,
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
index 44a653521a9c4878f6135c7f78f4e779c929e7d3..cad6437d1d3e3e8aeedb29fab7691dfa810f065c 100644
--- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
@@ -17,6 +17,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest, skip_check_grad_ci
+import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid.op import Operator
@@ -25,29 +26,36 @@ import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
 
 
-class TestDygraphEmbeddingAPIError(unittest.TestCase):
-    def test_errors(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            dict_size = 20
-            layer = fluid.dygraph.nn.Embedding(
-                size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
-            # the input must be Variable.
-            x0 = fluid.create_lod_tensor(
-                np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
-            self.assertRaises(TypeError, layer, x0)
-            # the input dtype must be int64
-            data_t = fluid.data(name='word', shape=[1], dtype='int32')
-            self.assertRaises(TypeError, layer, data_t)
+class TestStaticGraphSupportMultipleInt(unittest.TestCase):
+    def test_main(self):
+        dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
+        if paddle.in_dynamic_mode():
+            paddle.enable_static()
+            disable_static = True
+        else:
+            disable_static = False
+        for dtype in dtypes:
+            with paddle.static.program_guard(paddle.static.Program(),
+                                             paddle.static.Program()):
+                x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
+                emb = paddle.nn.Embedding(10, 20)
+                y = emb(x)
+
+        if disable_static:
+            paddle.disable_static()
 
 
 class TestLookupTableOp(OpTest):
     def setUp(self):
         self.op_type = "lookup_table_v2"
         table = np.random.random((17, 31)).astype("float64")
-        ids = np.random.randint(0, 17, 4).astype("int64")
+        ids = np.random.randint(0, 17, 4).astype(self.id_dtype())
         self.inputs = {'W': table, 'Ids': ids}
         self.outputs = {'Out': table[ids]}
 
+    def id_dtype(self):
+        return "int64"
+
     def test_check_output(self):
         self.check_output()
 
@@ -55,6 +63,21 @@ class TestLookupTableOp(OpTest):
         self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
 
 
+class TestLookupTableOpInt16(TestLookupTableOp):
+    def id_dtype(self):
+        return "int16"
+
+
+class TestLookupTableOpInt8(TestLookupTableOp):
+    def id_dtype(self):
+        return "int8"
+
+
+class TestLookupTableOpUInt8(TestLookupTableOp):
+    def id_dtype(self):
+        return "uint8"
+
+
 class TestLookupTableOpWithTensorIds(OpTest):
     def setUp(self):
         self.op_type = "lookup_table_v2"
@@ -256,4 +279,5 @@ class TestEmbedOpError(unittest.TestCase):
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py
index 0a33c875bf30c1082d39bd91a8fb7901ce3d04ff..d29684b11b0706fb08e3b1a839849b8758ac4a27 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py
@@ -30,21 +30,6 @@ from paddle.fluid import Program, program_guard
 paddle.enable_static()
 
 
-class TestDygraphEmbeddingAPIError(unittest.TestCase):
-    def test_errors(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            dict_size = 20
-            layer = fluid.dygraph.nn.Embedding(
-                size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
-            # the input must be Variable
-            x0 = fluid.create_lod_tensor(
-                np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], paddle.XPUPlace(0))
-            self.assertRaises(TypeError, layer, x0)
-            # the input dtype must be int64
-            data_t = fluid.data(name='word', shape=[1], dtype='int32')
-            self.assertRaises(TypeError, layer, data_t)
-
-
 class TestLookupTableOp(OpTest):
     def setUp(self):
         self.op_type = "lookup_table_v2"
diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py
index d88ee530715b0c14c8d9c850302b08341d89a794..f71d3001f6f3b426e0db1fb36733beaceff3b849 100644
--- a/python/paddle/nn/functional/input.py
+++ b/python/paddle/nn/functional/input.py
@@ -204,7 +204,9 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
         helper = LayerHelper('embedding', **locals())
         dtype = helper.input_dtype(input_param_name='weight')
 
-        check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'embedding')
+        check_variable_and_dtype(x, 'input',
+                                 ['uint8', 'int8', 'int16', 'int32', 'int64'],
+                                 'embedding')
 
         is_distributed = False
         remote_prefetch = sparse and (not is_distributed)