未验证 提交 60f1461a 编写于 作者: S sneaxiy 提交者: GitHub

Make Embedding layer support more int ids type (#39381)

* add more int id type support for embedding

* add ut

* add more ut

* fix ci error
上级 ccdcfa2d
...@@ -27,6 +27,9 @@ limitations under the License. */ ...@@ -27,6 +27,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
template <typename T> template <typename T>
struct IsComplex : public std::false_type {}; struct IsComplex : public std::false_type {};
...@@ -63,6 +66,13 @@ struct DataTypeTrait<void> { ...@@ -63,6 +66,13 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \ _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); COMPLEX128);
// Expands `callback` once per supported integral id type, pairing each C++
// type with its proto enum name:
//   int -> INT32, int64_t -> INT64, uint8_t -> UINT8,
//   int16_t -> INT16, int8_t -> INT8.
// NOTE(review): the expansion goes through _ForEachDataTypeHelper_, which is
// defined outside this view; presumably it invokes
// `callback(cpp_type, <qualified proto enum>)` -- confirm against the helper.
#define _ForEachIntDataType_(callback) \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8);
#define _ForEachDataTypeSmall_(callback) \ #define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \ _ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \ _ForEachDataTypeHelper_(callback, double, FP64); \
...@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) { ...@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackSmall #undef VisitDataTypeCallbackSmall
} }
// Dispatches `visitor` on a runtime integral data type.
//
// For the first entry of _ForEachIntDataType_ whose proto enum equals `type`,
// calls `visitor.template apply<cpp_type>()` and returns. If no entry matches,
// throws platform::errors::Unimplemented naming the offending type.
//
// `visitor` is taken by value; it must expose `template <typename> apply()`.
template <typename Visitor>
inline void VisitIntDataType(proto::VarType::Type type, Visitor visitor) {
// One dispatch case: run the visitor for CppT when `type` is ProtoEnum.
// Wrapped in do/while(0) so each expansion behaves as a single statement.
#define VisitIntDataTypeCase_(CppT, ProtoEnum) \
  do {                                         \
    if (type == ProtoEnum) {                   \
      visitor.template apply<CppT>();          \
      return;                                  \
    }                                          \
  } while (0)

  _ForEachIntDataType_(VisitIntDataTypeCase_);

#undef VisitIntDataTypeCase_

  // Reached only when `type` matched none of the integral types above.
  PADDLE_THROW(platform::errors::Unimplemented(
      "Expected integral data type, but got %s", DataTypeToString(type)));
}
template <typename Visitor> template <typename Visitor>
inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) { inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
#define VisitDataTypeCallbackTiny(cpp_type, proto_type) \ #define VisitDataTypeCallbackTiny(cpp_type, proto_type) \
...@@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) { ...@@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackHIP #undef VisitDataTypeCallbackHIP
} }
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);
......
...@@ -21,16 +21,16 @@ limitations under the License. */ ...@@ -21,16 +21,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX, template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag> bool PaddingFlag>
__global__ void LookupTableV2(T *output, const T *table, const int64_t *ids, __global__ void LookupTableV2(T *output, const T *table, const IdT *ids,
const int64_t N, const int64_t K, const int64_t D, const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) { const int64_t padding_idx) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX; int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) { while (idy < K) {
int64_t id = ids[idy]; auto id = static_cast<int64_t>(ids[idy]);
T *out = output + idy * D; T *out = output + idy * D;
const T *tab = table + id * D; const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
...@@ -47,15 +47,15 @@ __global__ void LookupTableV2(T *output, const T *table, const int64_t *ids, ...@@ -47,15 +47,15 @@ __global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
} }
} }
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids, __global__ void LookupTableV2Grad(T *table, const T *output, const IdT *ids,
const int64_t N, const int64_t K, const int64_t N, const int64_t K,
const int64_t D) { const int64_t D) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX; int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) { while (idy < K) {
int64_t id = ids[idy]; auto id = static_cast<int64_t>(ids[idy]);
const T *out = output + idy * D; const T *out = output + idy * D;
T *tab = table + id * D; T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
...@@ -66,123 +66,107 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids, ...@@ -66,123 +66,107 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
} }
template <typename T> template <typename T>
__global__ void InputTypeCovert(const T *in_ids, const int64_t K, struct LookupTableV2CUDAFunctor {
int64_t *out_ids) { LookupTableV2CUDAFunctor(const framework::ExecutionContext &context,
for (int i = 0; i < K; i++) { const framework::Tensor *ids_t)
out_ids[i] = (int64_t)(in_ids[i]); : context_(context), ids_t_(ids_t) {}
}
}
template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto id_name = context.InputNames("Ids").front(); template <typename IdT>
auto out_name = context.OutputNames("Out").front(); void apply() {
auto *table_t = context_.Input<framework::Tensor>("W");
auto *output_t = context_.Output<framework::Tensor>("Out");
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = ids_t->numel(); size_t K = ids_t_->numel();
dim3 threads(256, 4); dim3 threads(256, 4);
dim3 grids(80, 1); dim3 grids(80, 1);
// copy GPU memory to CPU pinned memory const auto *table = table_t->template data<T>();
framework::Vector<int64_t> ids; const auto *ids = ids_t_->template data<IdT>();
ids.resize(K); auto *output = output_t->template mutable_data<T>(context_.GetPlace());
auto stream = context_.cuda_device_context().stream();
const int64_t *ids_p = nullptr; if (padding_idx == -1) {
LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
if (ids_t->type() == framework::proto::VarType::INT32) { output, table, ids, N, K, D, padding_idx);
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else { } else {
ids_p = ids_t->data<int64_t>(); LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx);
} }
for (int64_t i = 0; i < K; ++i) {
PADDLE_ENFORCE_GE(
ids[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(paddle.nn.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids[i]));
PADDLE_ENFORCE_LT(
ids[i], N,
platform::errors::InvalidArgument(
"Variable value (input) of OP(paddle.nn.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids[i]));
}
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1)
LookupTableV2<
T, 256, 4, 80,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids_p, N, K, D, padding_idx);
else
LookupTableV2<
T, 256, 4, 80,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids_p, N, K, D, padding_idx);
} }
private:
const framework::ExecutionContext &context_;
const framework::Tensor *ids_t_;
}; };
template <typename T> template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> { class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const auto *ids_t = context.Input<framework::Tensor>("Ids");
LookupTableV2CUDAFunctor<T> functor(context, ids_t);
framework::VisitIntDataType(ids_t->type(), functor);
}
};
// Element-wise cast kernel: out_ids[i] = static_cast<OutT>(in_ids[i]) for
// every i in [0, K).
//
// The work is distributed over all launched threads with a strided loop, so
// the kernel is correct for any grid/block shape (1-D, 2-D or 3-D). The
// previous form had every thread run the whole 0..K loop, repeating the same
// writes once per launched thread, and used an `int` counter against an
// `int64_t` bound.
//
// Params:
//   in_ids  - device pointer to K source ids of type InT.
//   K       - number of elements to convert.
//   out_ids - device pointer to K destination slots of type OutT.
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT *in_ids, const int64_t K,
                                 OutT *out_ids) {
  // Flatten the (possibly multi-dimensional) launch configuration into a
  // linear thread id; widen to 64 bits before multiplying to avoid overflow.
  const int64_t threads_per_block =
      static_cast<int64_t>(blockDim.x) * blockDim.y * blockDim.z;
  const int64_t thread_in_block =
      threadIdx.x +
      static_cast<int64_t>(blockDim.x) *
          (threadIdx.y + static_cast<int64_t>(blockDim.y) * threadIdx.z);
  const int64_t block_in_grid =
      blockIdx.x +
      static_cast<int64_t>(gridDim.x) *
          (blockIdx.y + static_cast<int64_t>(gridDim.y) * blockIdx.z);
  const int64_t total_threads =
      threads_per_block * gridDim.x * gridDim.y * gridDim.z;

  for (int64_t i = block_in_grid * threads_per_block + thread_in_block; i < K;
       i += total_threads) {
    out_ids[i] = static_cast<OutT>(in_ids[i]);
  }
}
template <typename T>
struct LookupTableV2GradCUDAFunctor {
LookupTableV2GradCUDAFunctor(const framework::ExecutionContext &context,
const framework::Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}
template <typename IdT>
void apply() {
auto &dev_ctx = auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>(); context_.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context_.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of // Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward. // paddings makes no sense and we don't deal with it in backward.
if (is_sparse) { if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids"); auto *table = context_.Input<framework::Tensor>("W");
auto *table = context.Input<LoDTensor>("W"); auto *d_output =
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); context_.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_table = auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W")); context_.Output<pten::SelectedRows>(framework::GradVarName("W"));
auto *ids_data = ids->data<int64_t>(); const auto *ids_data = ids_t_->template data<IdT>();
int64_t ids_num = ids->numel(); int64_t ids_num = ids_t_->numel();
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows; framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num); new_rows.resize(ids_num);
auto gpu_place = context.GetPlace(); auto gpu_place = context_.GetPlace();
if (ids->type() == framework::proto::VarType::INT32) { if (!std::is_same<IdT, int64_t>::value) {
InputTypeCovert< InputTypeConvert<<<grids, threads, 0, stream>>>(
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>( ids_data, ids_num, new_rows.MutableData(gpu_place));
ids->data<int>(), ids_num,
new_rows.MutableData(context.GetPlace()));
} else { } else {
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()), memory::Copy(gpu_place, new_rows.CUDAMutableData(gpu_place), gpu_place,
gpu_place, ids_data, ids_num * sizeof(int64_t), stream); ids_data, ids_num * sizeof(int64_t), stream);
} }
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value(); auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table->dims()[1]}); d_table_value->Resize({ids_num, table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace()); d_table_value->template mutable_data<T>(gpu_place);
auto *d_table_data = d_table_value->data<T>(); auto *d_table_data = d_table_value->template data<T>();
auto *d_output_data = d_output->data<T>(); auto *d_output_data = d_output->template data<T>();
auto d_output_dims = d_output->dims(); auto d_output_dims = d_output->dims();
auto d_output_dims_2d = auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
...@@ -197,41 +181,43 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> { ...@@ -197,41 +181,43 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
d_output->numel() * sizeof(T), stream); d_output->numel() * sizeof(T), stream);
} else { } else {
auto ids_t = context.Input<LoDTensor>("Ids"); auto d_output_t =
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out")); context_.Input<framework::Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W")); auto d_table_t =
context_.Output<framework::Tensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0]; int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1]; int D = d_table_t->dims()[1];
int K = ids_t->numel(); int K = ids_t_->numel();
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
// copy GPU memory to CPU pinned memory const T *d_output = d_output_t->template data<T>();
framework::Vector<int64_t> ids; const auto *ids = ids_t_->template data<IdT>();
ids.resize(K); T *d_table = d_table_t->mutable_data<T>(context_.GetPlace());
const int64_t *ids_p = nullptr;
if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}
const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t); auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0)); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>( LookupTableV2Grad<T, IdT, 128, 8,
d_table, d_output, ids_p, N, K, D); 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
} }
} }
private:
const framework::ExecutionContext &context_;
const framework::Tensor *ids_t_;
};
// OpKernel whose Compute() delegates to LookupTableV2GradCUDAFunctor<T>,
// selecting the functor's id element type from the dtype of the "Ids" input.
template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
public:
// Fetches the "Ids" input tensor and dispatches on its dtype.
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids_t = context.Input<framework::Tensor>("Ids");
LookupTableV2GradCUDAFunctor<T> functor(context, ids_t);
// VisitIntDataType calls functor.apply<IdT>() for the matching integral
// type and throws Unimplemented for any non-integral dtype.
framework::VisitIntDataType(ids_t->type(), functor);
}
}; };
} // namespace operators } // namespace operators
......
...@@ -34,35 +34,44 @@ using DDim = framework::DDim; ...@@ -34,35 +34,44 @@ using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
// Copies every element of `ids`, read as InT, into a freshly allocated
// std::vector<OutT>, converting each value from InT to OutT.
//
// Params:
//   ids - tensor whose elements are accessed through `ids.data<InT>()`.
// Returns:
//   A vector of `ids.numel()` elements holding the converted ids.
//
// Iterator-range construction covers both the same-type and the converting
// case, so no runtime branch on a compile-time property is needed, and no
// memcpy is issued on possibly-null pointers when the tensor is empty.
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const Tensor &ids) {
  const auto numel = ids.numel();
  const auto *src = ids.data<InT>();
  return std::vector<OutT>(src, src + numel);
}
template <typename T> template <typename T>
class LookupTableV2Kernel : public framework::OpKernel<T> { struct LookupTableV2CPUFunctor {
public: LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
void Compute(const framework::ExecutionContext &context) const override { const Tensor *ids_t)
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor : context_(context), ids_t_(ids_t) {}
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); template <typename IdT>
int64_t ids_numel = ids_t->numel(); void apply() {
auto *output_t = context_.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context_.InputVar("W");
std::vector<int64_t> ids; int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
ids.reserve(ids_numel);
if (ids_t->type() == framework::proto::VarType::INT32) { auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_numel, auto ids_numel = static_cast<int64_t>(ids.size());
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
if (table_var->IsType<LoDTensor>()) { if (table_var->template IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W"); const auto &table_t = table_var->template Get<LoDTensor>();
int64_t row_number = table_t->dims()[0]; int64_t row_number = table_t.dims()[0];
int64_t row_width = table_t->dims()[1]; int64_t row_width = table_t.dims()[1];
auto *table = table_t->data<T>(); auto *table = table_t.template data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->template mutable_data<T>(context_.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) { if (padding_idx != kNoPadding && ids[i] == padding_idx) {
...@@ -86,11 +95,11 @@ class LookupTableV2Kernel : public framework::OpKernel<T> { ...@@ -86,11 +95,11 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
} else if (table_var->IsType<pten::SelectedRows>()) { } else if (table_var->template IsType<pten::SelectedRows>()) {
const auto &table_t = table_var->Get<pten::SelectedRows>(); const auto &table_t = table_var->template Get<pten::SelectedRows>();
int64_t row_width = table_t.value().dims()[1]; int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>(); const auto *table = table_t.value().template data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->template mutable_data<T>(context_.GetPlace());
auto input_data_type = table_t.value().type(); auto input_data_type = table_t.value().type();
for (int64_t i = 0; i < ids_numel; ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
...@@ -114,7 +123,7 @@ class LookupTableV2Kernel : public framework::OpKernel<T> { ...@@ -114,7 +123,7 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
memcpy(output + i * row_width, table + id_index * row_width, memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} else { } else {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context_);
blas.VCOPY(row_width, table + id_index * row_width, blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width); output + i * row_width);
} }
...@@ -122,18 +131,36 @@ class LookupTableV2Kernel : public framework::OpKernel<T> { ...@@ -122,18 +131,36 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
} }
} }
} }
private:
const framework::ExecutionContext &context_;
const Tensor *ids_t_;
}; };
template <typename T> template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> { class LookupTableV2Kernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W"); const auto *ids = context.Input<Tensor>("Ids");
LookupTableV2CPUFunctor<T> functor(context, ids);
framework::VisitIntDataType(ids->type(), functor);
}
};
template <typename T>
struct LookupTableV2GradCPUFunctor {
LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
const Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}
template <typename IdT>
void apply() {
auto *table_var = context_.InputVar("W");
DDim table_dim; DDim table_dim;
if (table_var->IsType<LoDTensor>()) { if (table_var->template IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims(); table_dim = context_.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<pten::SelectedRows>()) { } else if (table_var->template IsType<pten::SelectedRows>()) {
auto *table_t = context.Input<pten::SelectedRows>("W"); auto *table_t = context_.Input<pten::SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -141,39 +168,30 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> { ...@@ -141,39 +168,30 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
"must be either LoDTensor or SelectedRows")); "must be either LoDTensor or SelectedRows"));
} }
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context_.Attr<bool>("is_sparse");
auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of // Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward. // paddings makes no sense and we don't deal with it in backward.
if (is_sparse) { if (is_sparse) {
auto *ids_t = context.Input<LoDTensor>("Ids"); auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W")); context_.Output<pten::SelectedRows>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids;
ids.reserve(ids_num);
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
d_table->set_rows(ids); d_table->set_rows(ids);
auto *d_table_value = d_table->mutable_value(); auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]}); d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->mutable_data<T>(context.GetPlace()); d_table_value->template mutable_data<T>(context_.GetPlace());
d_table->set_height(table_dim[0]); d_table->set_height(table_dim[0]);
auto *d_output_data = d_output->data<T>(); auto *d_output_data = d_output->template data<T>();
auto *d_table_data = d_table_value->data<T>(); auto *d_table_data = d_table_value->template data<T>();
auto d_output_dims = d_output->dims(); auto d_output_dims = d_output->dims();
auto d_output_dims_2d = auto d_output_dims_2d =
...@@ -188,29 +206,16 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> { ...@@ -188,29 +206,16 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else { } else {
auto *ids_t = context.Input<LoDTensor>("Ids"); auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids;
ids.reserve(ids_num);
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
auto *ids_data = ids.data(); auto *ids_data = ids.data();
int64_t N = table_dim[0]; int64_t N = table_dim[0];
int64_t D = table_dim[1]; int64_t D = table_dim[1];
auto *d_output_data = d_output->data<T>(); auto *d_output_data = d_output->template data<T>();
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace()); auto *d_table_data =
d_table->template mutable_data<T>(context_.GetPlace());
memset(d_table_data, 0, d_table->numel() * sizeof(T)); memset(d_table_data, 0, d_table->numel() * sizeof(T));
...@@ -240,6 +245,20 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> { ...@@ -240,6 +245,20 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
} }
} }
} }
private:
const framework::ExecutionContext &context_;
const Tensor *ids_t_;
};
// OpKernel whose Compute() delegates to LookupTableV2GradCPUFunctor<T>,
// selecting the functor's id element type from the dtype of the "Ids" input.
template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
public:
// Fetches the "Ids" input tensor and dispatches on its dtype.
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids = context.Input<Tensor>("Ids");
LookupTableV2GradCPUFunctor<T> functor(context, ids);
// VisitIntDataType calls functor.apply<IdT>() for the matching integral
// type and throws Unimplemented for any non-integral dtype.
framework::VisitIntDataType(ids->type(), functor);
}
}; };
} // namespace operators } // namespace operators
......
...@@ -1652,7 +1652,9 @@ class Embedding(layers.Layer): ...@@ -1652,7 +1652,9 @@ class Embedding(layers.Layer):
'is_distributed', self._is_distributed, 'remote_prefetch', 'is_distributed', self._is_distributed, 'remote_prefetch',
self._remote_prefetch, 'padding_idx', self._padding_idx) self._remote_prefetch, 'padding_idx', self._padding_idx)
check_variable_and_dtype(input, 'input', ['int64'], 'Embedding') check_variable_and_dtype(input, 'input',
['uint8', 'int8', 'int16', 'int32', 'int64'],
'Embedding')
attrs = { attrs = {
'is_sparse': self._is_sparse, 'is_sparse': self._is_sparse,
'is_distributed': self._is_distributed, 'is_distributed': self._is_distributed,
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest, skip_check_grad_ci from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
...@@ -25,29 +26,36 @@ import paddle.fluid as fluid ...@@ -25,29 +26,36 @@ import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
class TestDygraphEmbeddingAPIError(unittest.TestCase): class TestStaticGraphSupportMultipleInt(unittest.TestCase):
def test_errors(self): def test_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()): dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
dict_size = 20 if paddle.in_dynamic_mode():
layer = fluid.dygraph.nn.Embedding( paddle.enable_static()
size=[dict_size, 32], param_attr='emb.w', is_sparse=False) disable_static = True
# the input must be Variable. else:
x0 = fluid.create_lod_tensor( disable_static = False
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) for i, dtype in enumerate(dtypes):
self.assertRaises(TypeError, layer, x0) with paddle.static.program_guard(paddle.static.Program(),
# the input dtype must be int64 paddle.static.Program()):
data_t = fluid.data(name='word', shape=[1], dtype='int32') x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
self.assertRaises(TypeError, layer, data_t) emb = paddle.nn.Embedding(10, 20)
y = emb(x)
if disable_static:
paddle.disable_static()
class TestLookupTableOp(OpTest): class TestLookupTableOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "lookup_table_v2" self.op_type = "lookup_table_v2"
table = np.random.random((17, 31)).astype("float64") table = np.random.random((17, 31)).astype("float64")
ids = np.random.randint(0, 17, 4).astype("int64") ids = np.random.randint(0, 17, 4).astype(self.id_dtype())
self.inputs = {'W': table, 'Ids': ids} self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]} self.outputs = {'Out': table[ids]}
def id_dtype(self):
return "int64"
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -55,6 +63,21 @@ class TestLookupTableOp(OpTest): ...@@ -55,6 +63,21 @@ class TestLookupTableOp(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
class TestLookupTableOpInt16(TestLookupTableOp):
    """Re-runs the TestLookupTableOp checks with int16 ids.

    Must derive from TestLookupTableOp (not OpTest directly): that class's
    setUp() calls self.id_dtype() to choose the ids dtype, and it provides
    the test methods. Deriving from OpTest alone would leave this case with
    no setUp and no tests, so the override would never be exercised.
    """

    def id_dtype(self):
        return "int16"
class TestLookupTableOpInt8(TestLookupTableOp):
    """Re-runs the TestLookupTableOp checks with int8 ids.

    Must derive from TestLookupTableOp (not OpTest directly): that class's
    setUp() calls self.id_dtype() to choose the ids dtype, and it provides
    the test methods. Deriving from OpTest alone would leave this case with
    no setUp and no tests, so the override would never be exercised.
    """

    def id_dtype(self):
        return "int8"
class TestLookupTableOpUInt8(TestLookupTableOp):
    """Re-runs the TestLookupTableOp checks with uint8 ids.

    Must derive from TestLookupTableOp (not OpTest directly): that class's
    setUp() calls self.id_dtype() to choose the ids dtype, and it provides
    the test methods. Deriving from OpTest alone would leave this case with
    no setUp and no tests, so the override would never be exercised.
    """

    def id_dtype(self):
        return "uint8"
class TestLookupTableOpWithTensorIds(OpTest): class TestLookupTableOpWithTensorIds(OpTest):
def setUp(self): def setUp(self):
self.op_type = "lookup_table_v2" self.op_type = "lookup_table_v2"
...@@ -256,4 +279,5 @@ class TestEmbedOpError(unittest.TestCase): ...@@ -256,4 +279,5 @@ class TestEmbedOpError(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
...@@ -30,21 +30,6 @@ from paddle.fluid import Program, program_guard ...@@ -30,21 +30,6 @@ from paddle.fluid import Program, program_guard
paddle.enable_static() paddle.enable_static()
class TestDygraphEmbeddingAPIError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
dict_size = 20
layer = fluid.dygraph.nn.Embedding(
size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
# the input must be Variable
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], paddle.XPUPlace(0))
self.assertRaises(TypeError, layer, x0)
# the input dtype must be int64
data_t = fluid.data(name='word', shape=[1], dtype='int32')
self.assertRaises(TypeError, layer, data_t)
class TestLookupTableOp(OpTest): class TestLookupTableOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "lookup_table_v2" self.op_type = "lookup_table_v2"
......
...@@ -204,7 +204,9 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): ...@@ -204,7 +204,9 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
helper = LayerHelper('embedding', **locals()) helper = LayerHelper('embedding', **locals())
dtype = helper.input_dtype(input_param_name='weight') dtype = helper.input_dtype(input_param_name='weight')
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'embedding') check_variable_and_dtype(x, 'input',
['uint8', 'int8', 'int16', 'int32', 'int64'],
'embedding')
is_distributed = False is_distributed = False
remote_prefetch = sparse and (not is_distributed) remote_prefetch = sparse and (not is_distributed)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册