Unverified commit 60f1461a, authored by sneaxiy, committed by GitHub

Make Embedding layer support more int id types (#39381)

* add more int id type support for embedding

* add unit test

* add more unit tests

* fix CI error
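The user-facing effect is easiest to see up front. Below is a minimal sketch, mirroring the static-graph unit test added in this commit, of building an `Embedding` lookup whose ids tensor uses each newly supported integral dtype:

```python
# Sketch of the new behavior, following the added static-graph test:
# paddle.nn.Embedding now accepts id tensors of any of these integral dtypes,
# not only int64.
import paddle

paddle.enable_static()
for dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
    with paddle.static.program_guard(paddle.static.Program(),
                                     paddle.static.Program()):
        ids = paddle.static.data(name='ids', shape=[-1, 7, 30], dtype=dtype)
        emb = paddle.nn.Embedding(10, 20)  # vocab size 10, embedding dim 20
        out = emb(ids)  # builds a lookup_table_v2 op with non-int64 ids
```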
Parent ccdcfa2d
......@@ -27,6 +27,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
template <typename T>
struct IsComplex : public std::false_type {};
......@@ -63,6 +66,13 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128);
#define _ForEachIntDataType_(callback) \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8);
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
......@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackSmall
}
template <typename Visitor>
inline void VisitIntDataType(proto::VarType::Type type, Visitor visitor) {
#define VisitIntDataTypeCallback(cpp_type, proto_type) \
do { \
if (type == proto_type) { \
visitor.template apply<cpp_type>(); \
return; \
} \
} while (0)
_ForEachIntDataType_(VisitIntDataTypeCallback);
PADDLE_THROW(platform::errors::Unimplemented(
"Expected integral data type, but got %s", DataTypeToString(type)));
#undef VisitIntDataTypeCallback
}
template <typename Visitor>
inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
#define VisitDataTypeCallbackTiny(cpp_type, proto_type) \
......@@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackHIP
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
......
......@@ -21,16 +21,16 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag>
__global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
__global__ void LookupTableV2(T *output, const T *table, const IdT *ids,
const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int64_t id = ids[idy];
auto id = static_cast<int64_t>(ids[idy]);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......@@ -47,15 +47,15 @@ __global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
}
}
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const IdT *ids,
const int64_t N, const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int64_t id = ids[idy];
auto id = static_cast<int64_t>(ids[idy]);
const T *out = output + idy * D;
T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......@@ -66,123 +66,107 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
}
template <typename T>
__global__ void InputTypeCovert(const T *in_ids, const int64_t K,
int64_t *out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = (int64_t)(in_ids[i]);
}
}
template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
struct LookupTableV2CUDAFunctor {
LookupTableV2CUDAFunctor(const framework::ExecutionContext &context,
const framework::Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}
auto id_name = context.InputNames("Ids").front();
auto out_name = context.OutputNames("Out").front();
template <typename IdT>
void apply() {
auto *table_t = context_.Input<framework::Tensor>("W");
auto *output_t = context_.Output<framework::Tensor>("Out");
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
size_t K = ids_t_->numel();
dim3 threads(256, 4);
dim3 grids(80, 1);
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);
const auto *table = table_t->template data<T>();
const auto *ids = ids_t_->template data<IdT>();
auto *output = output_t->template mutable_data<T>(context_.GetPlace());
auto stream = context_.cuda_device_context().stream();
const int64_t *ids_p = nullptr;
if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
if (padding_idx == -1) {
LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx);
} else {
ids_p = ids_t->data<int64_t>();
LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx);
}
for (int64_t i = 0; i < K; ++i) {
PADDLE_ENFORCE_GE(
ids[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(paddle.nn.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids[i]));
PADDLE_ENFORCE_LT(
ids[i], N,
platform::errors::InvalidArgument(
"Variable value (input) of OP(paddle.nn.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids[i]));
}
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1)
LookupTableV2<
T, 256, 4, 80,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids_p, N, K, D, padding_idx);
else
LookupTableV2<
T, 256, 4, 80,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids_p, N, K, D, padding_idx);
}
private:
const framework::ExecutionContext &context_;
const framework::Tensor *ids_t_;
};
template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids_t = context.Input<framework::Tensor>("Ids");
LookupTableV2CUDAFunctor<T> functor(context, ids_t);
framework::VisitIntDataType(ids_t->type(), functor);
}
};
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT *in_ids, const int64_t K,
OutT *out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = static_cast<OutT>(in_ids[i]);
}
}
template <typename T>
struct LookupTableV2GradCUDAFunctor {
LookupTableV2GradCUDAFunctor(const framework::ExecutionContext &context,
const framework::Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}
template <typename IdT>
void apply() {
auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse");
context_.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context_.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *table = context_.Input<framework::Tensor>("W");
auto *d_output =
context_.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
context_.Output<pten::SelectedRows>(framework::GradVarName("W"));
auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
const auto *ids_data = ids_t_->template data<IdT>();
int64_t ids_num = ids_t_->numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
auto stream = dev_ctx.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
auto gpu_place = context.GetPlace();
auto gpu_place = context_.GetPlace();
if (ids->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids->data<int>(), ids_num,
new_rows.MutableData(context.GetPlace()));
if (!std::is_same<IdT, int64_t>::value) {
InputTypeConvert<<<grids, threads, 0, stream>>>(
ids_data, ids_num, new_rows.MutableData(gpu_place));
} else {
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
memory::Copy(gpu_place, new_rows.CUDAMutableData(gpu_place), gpu_place,
ids_data, ids_num * sizeof(int64_t), stream);
}
d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());
d_table_value->template mutable_data<T>(gpu_place);
auto *d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->template data<T>();
auto *d_output_data = d_output->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
......@@ -197,41 +181,43 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
d_output->numel() * sizeof(T), stream);
} else {
auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
auto d_output_t =
context_.Input<framework::Tensor>(framework::GradVarName("Out"));
auto d_table_t =
context_.Output<framework::Tensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
int K = ids_t_->numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);
const int64_t *ids_p = nullptr;
if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}
const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
const T *d_output = d_output_t->template data<T>();
const auto *ids = ids_t_->template data<IdT>();
T *d_table = d_table_t->mutable_data<T>(context_.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids_p, N, K, D);
LookupTableV2Grad<T, IdT, 128, 8,
8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
}
}
private:
const framework::ExecutionContext &context_;
const framework::Tensor *ids_t_;
};
template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids_t = context.Input<framework::Tensor>("Ids");
LookupTableV2GradCUDAFunctor<T> functor(context, ids_t);
framework::VisitIntDataType(ids_t->type(), functor);
}
};
} // namespace operators
......
......@@ -34,35 +34,44 @@ using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1;
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const Tensor &ids) {
auto numel = ids.numel();
const auto *src = ids.data<InT>();
std::vector<OutT> ret(numel);
if (std::is_same<InT, OutT>::value) {
std::memcpy(ret.data(), src, numel * sizeof(InT));
} else {
for (decltype(numel) i = 0; i < numel; ++i) {
ret[i] = src[i];
}
}
return ret;
}
template <typename T>
class LookupTableV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
struct LookupTableV2CPUFunctor {
LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
const Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t ids_numel = ids_t->numel();
template <typename IdT>
void apply() {
auto *output_t = context_.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context_.InputVar("W");
std::vector<int64_t> ids;
ids.reserve(ids_numel);
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_numel,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
auto ids_numel = static_cast<int64_t>(ids.size());
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
if (table_var->template IsType<LoDTensor>()) {
const auto &table_t = table_var->template Get<LoDTensor>();
int64_t row_number = table_t.dims()[0];
int64_t row_width = table_t.dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto *table = table_t.template data<T>();
auto *output = output_t->template mutable_data<T>(context_.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
......@@ -86,11 +95,11 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
row_width * sizeof(T));
}
}
} else if (table_var->IsType<pten::SelectedRows>()) {
const auto &table_t = table_var->Get<pten::SelectedRows>();
} else if (table_var->template IsType<pten::SelectedRows>()) {
const auto &table_t = table_var->template Get<pten::SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
const auto *table = table_t.value().template data<T>();
auto *output = output_t->template mutable_data<T>(context_.GetPlace());
auto input_data_type = table_t.value().type();
for (int64_t i = 0; i < ids_numel; ++i) {
......@@ -114,7 +123,7 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
} else {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context_);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
......@@ -122,18 +131,36 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
}
}
}
private:
const framework::ExecutionContext &context_;
const Tensor *ids_t_;
};
template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
class LookupTableV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W");
const auto *ids = context.Input<Tensor>("Ids");
LookupTableV2CPUFunctor<T> functor(context, ids);
framework::VisitIntDataType(ids->type(), functor);
}
};
template <typename T>
struct LookupTableV2GradCPUFunctor {
LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
const Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}
template <typename IdT>
void apply() {
auto *table_var = context_.InputVar("W");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<pten::SelectedRows>()) {
auto *table_t = context.Input<pten::SelectedRows>("W");
if (table_var->template IsType<LoDTensor>()) {
table_dim = context_.Input<LoDTensor>("W")->dims();
} else if (table_var->template IsType<pten::SelectedRows>()) {
auto *table_t = context_.Input<pten::SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -141,39 +168,30 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
"must be either LoDTensor or SelectedRows"));
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
bool is_sparse = context.Attr<bool>("is_sparse");
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
bool is_sparse = context_.Attr<bool>("is_sparse");
auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids;
ids.reserve(ids_num);
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
context_.Output<pten::SelectedRows>(framework::GradVarName("W"));
d_table->set_rows(ids);
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->mutable_data<T>(context.GetPlace());
d_table_value->template mutable_data<T>(context_.GetPlace());
d_table->set_height(table_dim[0]);
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->template data<T>();
auto *d_table_data = d_table_value->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
......@@ -188,29 +206,16 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids;
ids.reserve(ids_num);
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
auto *ids_data = ids.data();
int64_t N = table_dim[0];
int64_t D = table_dim[1];
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
auto *d_output_data = d_output->template data<T>();
auto *d_table_data =
d_table->template mutable_data<T>(context_.GetPlace());
memset(d_table_data, 0, d_table->numel() * sizeof(T));
......@@ -240,6 +245,20 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
}
}
}
private:
const framework::ExecutionContext &context_;
const Tensor *ids_t_;
};
template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids = context.Input<Tensor>("Ids");
LookupTableV2GradCPUFunctor<T> functor(context, ids);
framework::VisitIntDataType(ids->type(), functor);
}
};
} // namespace operators
......
......@@ -1652,7 +1652,9 @@ class Embedding(layers.Layer):
'is_distributed', self._is_distributed, 'remote_prefetch',
self._remote_prefetch, 'padding_idx', self._padding_idx)
check_variable_and_dtype(input, 'input', ['int64'], 'Embedding')
check_variable_and_dtype(input, 'input',
['uint8', 'int8', 'int16', 'int32', 'int64'],
'Embedding')
attrs = {
'is_sparse': self._is_sparse,
'is_distributed': self._is_distributed,
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
......@@ -25,29 +26,36 @@ import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestDygraphEmbeddingAPIError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
dict_size = 20
layer = fluid.dygraph.nn.Embedding(
size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
# the input must be Variable.
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, layer, x0)
# the input dtype must be int64
data_t = fluid.data(name='word', shape=[1], dtype='int32')
self.assertRaises(TypeError, layer, data_t)
class TestStaticGraphSupportMultipleInt(unittest.TestCase):
def test_main(self):
dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
if paddle.in_dynamic_mode():
paddle.enable_static()
disable_static = True
else:
disable_static = False
for i, dtype in enumerate(dtypes):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
emb = paddle.nn.Embedding(10, 20)
y = emb(x)
if disable_static:
paddle.disable_static()
class TestLookupTableOp(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
table = np.random.random((17, 31)).astype("float64")
ids = np.random.randint(0, 17, 4).astype("int64")
ids = np.random.randint(0, 17, 4).astype(self.id_dtype())
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]}
def id_dtype(self):
return "int64"
def test_check_output(self):
self.check_output()
......@@ -55,6 +63,21 @@ class TestLookupTableOp(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
class TestLookupTableOpInt16(OpTest):
def id_dtype(self):
return "int16"
class TestLookupTableOpInt8(OpTest):
def id_dtype(self):
return "int8"
class TestLookupTableOpUInt8(OpTest):
def id_dtype(self):
return "uint8"
class TestLookupTableOpWithTensorIds(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
......@@ -256,4 +279,5 @@ class TestEmbedOpError(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
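For completeness, a hedged dynamic-graph check in the same spirit as the new `OpTest` variants above (not part of the patch; it assumes `paddle.nn.functional.embedding` accepts a plain weight tensor in dynamic mode):

```python
# Hypothetical standalone check: a dynamic-graph lookup with int16 ids should
# match a plain numpy gather once the kernel dispatches on the ids dtype.
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
table = np.random.random((17, 31)).astype('float64')
ids = np.random.randint(0, 17, size=4).astype('int16')
out = F.embedding(paddle.to_tensor(ids), paddle.to_tensor(table))
np.testing.assert_allclose(out.numpy(), table[ids])
```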
......@@ -30,21 +30,6 @@ from paddle.fluid import Program, program_guard
paddle.enable_static()
class TestDygraphEmbeddingAPIError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
dict_size = 20
layer = fluid.dygraph.nn.Embedding(
size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
# the input must be Variable
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], paddle.XPUPlace(0))
self.assertRaises(TypeError, layer, x0)
# the input dtype must be int64
data_t = fluid.data(name='word', shape=[1], dtype='int32')
self.assertRaises(TypeError, layer, data_t)
class TestLookupTableOp(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
......
......@@ -204,7 +204,9 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
helper = LayerHelper('embedding', **locals())
dtype = helper.input_dtype(input_param_name='weight')
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'embedding')
check_variable_and_dtype(x, 'input',
['uint8', 'int8', 'int16', 'int32', 'int64'],
'embedding')
is_distributed = False
remote_prefetch = sparse and (not is_distributed)
......
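A corresponding static-graph sketch for `paddle.nn.functional.embedding`; as in the new unit test, only the graph is built, and the weight created via `paddle.static.create_parameter` is illustrative rather than taken from the patch:

```python
# Hypothetical sketch: with the relaxed check, an int16 ids placeholder passes
# the dtype validation of paddle.nn.functional.embedding in static-graph mode.
import paddle
import paddle.nn.functional as F

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
                                 paddle.static.Program()):
    ids = paddle.static.data(name='ids', shape=[-1, 4], dtype='int16')
    weight = paddle.static.create_parameter(shape=[10, 3], dtype='float32')
    out = F.embedding(ids, weight=weight, sparse=False)
```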