From e81773c962315395a5d8a6e15caa37b7fc1e34ea Mon Sep 17 00:00:00 2001 From: phlrain Date: Mon, 28 Feb 2022 09:40:54 +0000 Subject: [PATCH] move reset impl to phi; test=develop --- .../phi/kernels/cpu/embedding_grad_kernel.cc | 85 +++- paddle/phi/kernels/cpu/embedding_kernel.cc | 2 +- .../sparse_weight_embedding_grad_kernel.cc | 103 ++++- .../cpu/sparse_weight_embedding_kernel.cc | 16 +- paddle/phi/kernels/embedding_grad_kernel.h | 9 + .../phi/kernels/gpu/embedding_grad_kernel.cu | 109 ++++- paddle/phi/kernels/gpu/embedding_kernel.cu | 2 +- .../sparse_weight_embedding_grad_kernel.h | 8 + paddle/phi/ops/compat/embedding_sig.cc | 40 +- .../unittests/test_lookup_table_v2_op.py | 416 +++++++++--------- 10 files changed, 551 insertions(+), 239 deletions(-) diff --git a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc index 56161d62614..67c28eefc87 100644 --- a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc @@ -114,12 +114,95 @@ void EmbeddingGradKernel(const Context& ctx, paddle::framework::TransToProtoVarType(input.dtype()), functor); } +template +struct LookupTableV2SparseGradCPUFunctor { + LookupTableV2SparseGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + auto* d_table = weight_grad_; + auto* d_output = &out_grad_; + d_table->set_rows(ids); + + auto* d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + + d_table_value->template mutable_data(dev_ctx_.GetPlace()); + + d_table->set_height(table_dim[0]); + + auto* d_output_data = d_output->template data(); + auto* d_table_data = d_table_value->template data(); + + auto d_output_dims = d_output->dims(); + auto d_output_dims_2d = + flatten_to_2d(d_output_dims, d_output_dims.size() - 1); + PADDLE_ENFORCE_EQ(d_table_value->dims(), + d_output_dims_2d, + phi::errors::InvalidArgument( + "ShapeError: The shape of lookup_table@Grad and " + "output@Grad should be same. 
" + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + SelectedRows* weight_grad_; + int64_t padding_idx_; +}; + +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + LookupTableV2SparseGradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + paddle::framework::VisitIntDataType( + paddle::framework::TransToProtoVarType(input.dtype()), functor); +} + } // namespace phi -PT_REGISTER_KERNEL(embedding_grad, +PD_REGISTER_KERNEL(embedding_grad, CPU, ALL_LAYOUT, phi::EmbeddingGradKernel, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(embedding_sparse_grad, + CPU, + ALL_LAYOUT, + phi::EmbeddingSparseGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/embedding_kernel.cc b/paddle/phi/kernels/cpu/embedding_kernel.cc index fe3d1f9a37b..63ea7004d42 100644 --- a/paddle/phi/kernels/cpu/embedding_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_kernel.cc @@ -99,7 +99,7 @@ void EmbeddingKernel(const Context& ctx, } // namespace phi -PT_REGISTER_KERNEL(embedding, +PD_REGISTER_KERNEL(embedding, CPU, ALL_LAYOUT, phi::EmbeddingKernel, diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc index 1cc5f734357..743faa3e43e 100644 --- a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/embedding_grad_kernel.h" +#include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h" #include "paddle/phi/kernels/funcs/embedding_util.h" #include "paddle/fluid/framework/convert_utils.h" @@ -23,13 +23,13 @@ namespace phi { template -struct LookupTableV2GradCPUFunctor { - LookupTableV2GradCPUFunctor(const Context& dev_ctx, - const DenseTensor& input, - const SelectedRows& weight, - const DenseTensor& out_grad, - int64_t padding_idx, - DenseTensor* weight_grad) +struct SparseWeightLookupTableV2GradCPUFunctor { + SparseWeightLookupTableV2GradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) : dev_ctx_(dev_ctx), input_(input), weight_(weight), @@ -101,6 +101,68 @@ struct LookupTableV2GradCPUFunctor { int64_t padding_idx_; }; +template +struct SparseWeightLookupTableV2SparseGradCPUFunctor { + SparseWeightLookupTableV2SparseGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + auto* d_table = weight_grad_; + auto* d_output = &out_grad_; + d_table->set_rows(ids); + + auto* d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + + d_table_value->template mutable_data(dev_ctx_.GetPlace()); + + d_table->set_height(table_dim[0]); + + auto* d_output_data = d_output->template data(); + auto* d_table_data = d_table_value->template data(); + + auto d_output_dims = d_output->dims(); + auto d_output_dims_2d = + phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); + PADDLE_ENFORCE_EQ(d_table_value->dims(), + d_output_dims_2d, + phi::errors::InvalidArgument( + "ShapeError: The shape of lookup_table@Grad and " + "output@Grad should be same. 
" + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const SelectedRows& weight_; + const DenseTensor& out_grad_; + SelectedRows* weight_grad_; + int64_t padding_idx_; +}; + template void SparseWeightEmbeddingGradKernel(const Context& ctx, const DenseTensor& input, @@ -108,7 +170,20 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx, const DenseTensor& out_grad, int64_t padding_idx, DenseTensor* weight_grad) { - LookupTableV2GradCPUFunctor functor( + SparseWeightLookupTableV2GradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + paddle::framework::VisitIntDataType( + paddle::framework::TransToProtoVarType(input.dtype()), functor); +} + +template +void SparseWeightEmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + SparseWeightLookupTableV2SparseGradCPUFunctor functor( ctx, input, weight, out_grad, padding_idx, weight_grad); paddle::framework::VisitIntDataType( paddle::framework::TransToProtoVarType(input.dtype()), functor); @@ -116,10 +191,18 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx, } // namespace phi -PT_REGISTER_KERNEL(sparse_weight_embedding_grad, +PD_REGISTER_KERNEL(sparse_weight_embedding_grad, CPU, ALL_LAYOUT, phi::SparseWeightEmbeddingGradKernel, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad, + CPU, + ALL_LAYOUT, + phi::SparseWeightEmbeddingSparseGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc index 7a9fef47300..d8a53f42f60 100644 --- a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc +++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc @@ -24,12 +24,12 @@ namespace phi { template -struct LookupTableV2CPUFunctor { - LookupTableV2CPUFunctor(const Context& dev_ctx, - const DenseTensor& input, - const SelectedRows& weight, - int64_t padding_idx, - DenseTensor* out) +struct LookupTableV2CPUSparseFunctor { + LookupTableV2CPUSparseFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + int64_t padding_idx, + DenseTensor* out) : dev_ctx_(dev_ctx), input_(input), weight_(weight), @@ -94,7 +94,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx, const SelectedRows& weight, int64_t padding_idx, DenseTensor* out) { - LookupTableV2CPUFunctor functor( + LookupTableV2CPUSparseFunctor functor( ctx, input, weight, padding_idx, out); paddle::framework::VisitIntDataType( paddle::framework::TransToProtoVarType(input.dtype()), functor); @@ -102,7 +102,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx, } // namespace phi -PT_REGISTER_KERNEL(sparse_weight_embedding, +PD_REGISTER_KERNEL(sparse_weight_embedding, CPU, ALL_LAYOUT, phi::SparseWeightEmbeddingKernel, diff --git a/paddle/phi/kernels/embedding_grad_kernel.h b/paddle/phi/kernels/embedding_grad_kernel.h index 155e7329be6..40ffe6ec886 100644 --- a/paddle/phi/kernels/embedding_grad_kernel.h +++ b/paddle/phi/kernels/embedding_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" namespace phi { @@ -26,4 +27,12 
@@ void EmbeddingGradKernel(const Context& ctx, int64_t padding_idx, DenseTensor* weight_grad); +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad); + } // namespace phi diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index 0acec201c17..d04e221a538 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -21,7 +21,9 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + namespace phi { template @@ -120,12 +122,117 @@ void EmbeddingGradKernel(const Context& ctx, paddle::framework::VisitIntDataType( paddle::framework::TransToProtoVarType(input.dtype()), functor); } + +template +struct LookupTableV2SparseGradCUDAFunctor { + LookupTableV2SparseGradCUDAFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + padding_idx_(padding_idx), + weight_grad_(weight_grad) {} + + template + void apply() { + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + + const auto* ids_data = input_.template data(); + auto* d_table = weight_grad_; + auto* table = &weight_; + auto* d_output = &out_grad_; + int64_t ids_num = input_.numel(); + dim3 threads(128, 8); + dim3 grids(8, 1); + auto stream = dev_ctx_.stream(); + paddle::framework::Vector new_rows; + new_rows.resize(ids_num); + auto gpu_place = dev_ctx_.GetPlace(); + + paddle::framework::MixVector mixv_new_rows(&new_rows); + if (!std::is_same::value) { + InputTypeConvert<<>>( + ids_data, ids_num, mixv_new_rows.MutableData(gpu_place)); + } else { + paddle::memory::Copy(gpu_place, + mixv_new_rows.CUDAMutableData(gpu_place), + gpu_place, + ids_data, + ids_num * sizeof(int64_t), + stream); + } + + mixv_new_rows.CopyToCPU(); + d_table->set_rows(new_rows); + + auto* d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table->dims()[1]}); + d_table_value->template mutable_data(gpu_place); + + auto* d_table_data = d_table_value->template data(); + auto* d_output_data = d_output->template data(); + auto d_output_dims = d_output->dims(); + auto d_output_dims_2d = + phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); + PADDLE_ENFORCE_EQ(d_table_value->dims(), + d_output_dims_2d, + phi::errors::InvalidArgument( + "ShapeError: The shape of lookup_table@Grad and " + "output@Grad should be same. 
" + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + paddle::memory::Copy(gpu_place, + d_table_data, + gpu_place, + d_output_data, + d_output->numel() * sizeof(T), + stream); + } + + private: + const phi::GPUContext& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + int64_t padding_idx_; + SelectedRows* weight_grad_; +}; + +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + LookupTableV2SparseGradCUDAFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + paddle::framework::VisitIntDataType( + paddle::framework::TransToProtoVarType(input.dtype()), functor); +} + } // namespace phi -PT_REGISTER_KERNEL(embedding_grad, +PD_REGISTER_KERNEL(embedding_grad, GPU, ALL_LAYOUT, phi::EmbeddingGradKernel, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(embedding_sparse_grad, + GPU, + ALL_LAYOUT, + phi::EmbeddingSparseGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index 114942bfddd..6830af163e2 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -115,7 +115,7 @@ void EmbeddingKernel(const Context &ctx, } // namespace phi -PT_REGISTER_KERNEL(embedding, +PD_REGISTER_KERNEL(embedding, GPU, ALL_LAYOUT, phi::EmbeddingKernel, diff --git a/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h b/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h index 51627db7870..772268c2cc3 100644 --- a/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h +++ b/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h @@ -27,4 +27,12 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx, int64_t padding_idx, DenseTensor* weight_grad); +template +void SparseWeightEmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad); + } // namespace phi diff --git a/paddle/phi/ops/compat/embedding_sig.cc b/paddle/phi/ops/compat/embedding_sig.cc index 350da0b13c8..b79a381dcec 100644 --- a/paddle/phi/ops/compat/embedding_sig.cc +++ b/paddle/phi/ops/compat/embedding_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("W")) { - LOG(ERROR) << "is dense here"; return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); } else { - LOG(ERROR) << "is selcted rows"; return KernelSignature( "sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); } @@ -30,23 +28,37 @@ KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EmbeddingGradOpArgumentMapping( const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("W")) { - return KernelSignature("embedding_grad", - {"Ids", "W", GradVarName("Out")}, - {"padding_idx"}, - {GradVarName("W")}); + if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { + return KernelSignature("embedding_sparse_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } 
} else { - return KernelSignature("sparse_weight_embedding_grad", - {"Ids", "W", GradVarName("Out")}, - {"padding_idx"}, - {GradVarName("W")}); + if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { + return KernelSignature("sparse_weight_embedding_sparse_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("sparse_weight_embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } } } } // namespace phi -PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding); -PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad); +PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding); +PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad); -PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping); -PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad, +PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad, phi::EmbeddingGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 7fd70e0cc66..cad6437d1d3 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -25,23 +25,24 @@ import paddle.compat as cpt import paddle.fluid as fluid from paddle.fluid import Program, program_guard -# class TestStaticGraphSupportMultipleInt(unittest.TestCase): -# def test_main(self): -# dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64'] -# if paddle.in_dynamic_mode(): -# paddle.enable_static() -# disable_static = True -# else: -# disable_static = False -# for i, dtype in enumerate(dtypes): -# with paddle.static.program_guard(paddle.static.Program(), -# paddle.static.Program()): -# x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype) -# emb = paddle.nn.Embedding(10, 20) -# y = emb(x) - -# if disable_static: -# paddle.disable_static() + +class TestStaticGraphSupportMultipleInt(unittest.TestCase): + def test_main(self): + dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64'] + if paddle.in_dynamic_mode(): + paddle.enable_static() + disable_static = True + else: + disable_static = False + for i, dtype in enumerate(dtypes): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype) + emb = paddle.nn.Embedding(10, 20) + y = emb(x) + + if disable_static: + paddle.disable_static() class TestLookupTableOp(OpTest): @@ -62,17 +63,19 @@ class TestLookupTableOp(OpTest): self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) -# class TestLookupTableOpInt16(OpTest): -# def id_dtype(self): -# return "int16" +class TestLookupTableOpInt16(OpTest): + def id_dtype(self): + return "int16" -# class TestLookupTableOpInt8(OpTest): -# def id_dtype(self): -# return "int8" -# class TestLookupTableOpUInt8(OpTest): -# def id_dtype(self): -# return "uint8" +class TestLookupTableOpInt8(OpTest): + def id_dtype(self): + return "int8" + + +class TestLookupTableOpUInt8(OpTest): + def id_dtype(self): + return "uint8" class TestLookupTableOpWithTensorIds(OpTest): @@ -90,183 +93,190 @@ class TestLookupTableOpWithTensorIds(OpTest): self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) -# @skip_check_grad_ci( -# reason="Since paddings are not trainable and fixed in forward," -# "the gradient of paddings makes no sense and we don't " 
-# "test the gradient here.") -# class TestLookupTableOpWithPadding(TestLookupTableOp): -# def test_check_output(self): -# ids = np.squeeze(self.inputs['Ids']) -# padding_idx = np.random.choice(ids, 1)[0] -# self.outputs['Out'][ids == padding_idx] = np.zeros(31) -# self.attrs = {'padding_idx': int(padding_idx)} -# self.check_output() - -# @skip_check_grad_ci( -# reason="Since paddings are not trainable and fixed in forward," -# "the gradient of paddings makes no sense and we don't " -# "test the gradient here.") -# class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): -# def test_check_output(self): -# ids = self.inputs['Ids'] -# flatten_idx = ids.flatten() -# padding_idx = np.random.choice(flatten_idx, 1)[0] -# self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) -# self.attrs = {'padding_idx': cpt.long_type(padding_idx)} -# self.check_output() - -# class TestLookupTableWIsSelectedRows(unittest.TestCase): -# def prepare_ids(self, scope, place): -# ids_tensor = scope.var('Ids').get_tensor() -# ids_array = np.array([0, 4, 3, 5]).astype("int32") -# ids_tensor.set(ids_array, place) -# return ids_array - -# def prepare_w(self, scope, place): -# rows = [0, 1, 2, 3, 4, 5, 6] -# row_numel = 12 - -# w_selected_rows = scope.var('W').get_selected_rows() -# w_selected_rows.set_height(len(rows)) -# w_selected_rows.set_rows(rows) -# w_array = np.ones((len(rows), row_numel)).astype("float32") -# for i in range(len(rows)): -# w_array[i] *= i -# w_tensor = w_selected_rows.get_tensor() -# w_tensor.set(w_array, place) - -# def create_out_tensor(self, scope, place): -# return scope.var('Out').get_tensor() - -# def check_result(self, ids_array, result_array): -# # all(): return True if all elements of the iterable are true (or if the iterable is empty) -# for idx, row in enumerate(ids_array): -# assert (row == result_array[idx]).all() - -# def check_with_place(self, place): -# scope = core.Scope() - -# ids_array = self.prepare_ids(scope, place) - -# self.prepare_w(scope, place) - -# out_tensor = self.create_out_tensor(scope, place) - -# # create and run lookup_table operator -# lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out') -# lookup_table.run(scope, place) - -# # get result from Out -# result_array = np.array(out_tensor) - -# self.check_result(ids_array, result_array) - -# def test_w_is_selected_rows(self): -# places = [core.CPUPlace()] -# # currently only support CPU -# for place in places: -# self.check_with_place(place) - -# class TestLookupTableWithTensorIdsWIsSelectedRows( -# TestLookupTableWIsSelectedRows): -# def prepare_ids(self, scope, place): -# ids_tensor = scope.var('Ids').get_tensor() -# ids_array = np.random.randint( -# low=0, high=6, size=(2, 4, 3)).astype("int64") -# ids_tensor.set(ids_array, place) -# return ids_array - -# def check_result(self, ids_array, result_array): -# for idx, row in np.ndenumerate(ids_array): -# assert (row == result_array[idx]).all() - -# class TestLookupTableIsSparse(unittest.TestCase): -# def init_data(self): -# self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64") -# self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32") - -# def get_w_grad(self, is_sparse): -# self.init_data() -# main_program = fluid.Program() -# with fluid.program_guard(main_program, fluid.Program()): -# x = fluid.layers.data(name='x', shape=[5], dtype='int64') -# y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32') -# emb = fluid.input.embedding( -# input=x, -# size=[10, 16], -# 
param_attr=fluid.ParamAttr( -# name="emb_weight", -# learning_rate=10, -# initializer=fluid.initializer.NumpyArrayInitializer( -# self.w_data)), -# is_sparse=is_sparse) -# y = fluid.layers.reduce_sum(emb, dim=-1) - -# loss = fluid.layers.square_error_cost(input=y, label=y_) -# loss = fluid.layers.mean(loss) - -# sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4) -# sgd_optimizer.minimize(loss) - -# place = fluid.CPUPlace() -# exe = fluid.Executor(place) -# exe.run(fluid.default_startup_program()) -# ret = exe.run(feed={'x': self.x_data, -# 'y_': self.y_data}, -# fetch_list=['emb_weight'], -# return_numpy=False) -# return np.array(ret[0]) - -# def test_w_grad(self): -# self.w_data = np.random.random(size=(10, 16)).astype("float32") -# w_grad = self.get_w_grad(False) -# w_grad_with_sparse = self.get_w_grad(True) -# self.check_grad(w_grad, w_grad_with_sparse) - -# def check_grad(self, w_grad1, w_grad2, tolerance=1e-6): -# np.testing.assert_allclose( -# w_grad1, w_grad2, rtol=tolerance, atol=tolerance) - -# class TestLookupTableApi(unittest.TestCase): -# def test_api(self): -# x = fluid.layers.data(name='x', shape=[20], dtype='int64') -# emb = fluid.embedding(input=x, size=[128, 64]) - -# place = fluid.CPUPlace() -# x_data = np.random.randint(0, 127, [2, 20]).astype("int64") - -# exe = fluid.Executor(place) -# exe.run(fluid.default_startup_program()) -# ret = exe.run(feed={'x': x_data, }, -# fetch_list=[emb], -# return_numpy=False) - -# class TestEmbedOpError(unittest.TestCase): -# def test_errors(self): -# with program_guard(Program(), Program()): -# input_data = np.random.randint(0, 10, (4, 6)).astype("int64") - -# def test_Variable(): -# # the input type must be Variable -# fluid.embedding(input=input_data, size=(10, 64)) - -# self.assertRaises(TypeError, test_Variable) - -# def test_input_dtype(): -# # the input dtype must be int64 -# input = fluid.data(name='x1', shape=[4, 6], dtype='float32') -# fluid.embedding(input=input, size=(10, 64)) - -# self.assertRaises(TypeError, test_input_dtype) - -# def test_param_dtype(): -# # dtype must be float32 or float64 -# input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64') -# fluid.embedding(input=input2, size=(10, 64), dtype='int64') - -# self.assertRaises(TypeError, test_param_dtype) -# input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64') -# fluid.embedding(input=input3, size=(10, 64), dtype='float16') +@skip_check_grad_ci( + reason="Since paddings are not trainable and fixed in forward," + "the gradient of paddings makes no sense and we don't " + "test the gradient here.") +class TestLookupTableOpWithPadding(TestLookupTableOp): + def test_check_output(self): + ids = np.squeeze(self.inputs['Ids']) + padding_idx = np.random.choice(ids, 1)[0] + self.outputs['Out'][ids == padding_idx] = np.zeros(31) + self.attrs = {'padding_idx': int(padding_idx)} + self.check_output() + + +@skip_check_grad_ci( + reason="Since paddings are not trainable and fixed in forward," + "the gradient of paddings makes no sense and we don't " + "test the gradient here.") +class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): + def test_check_output(self): + ids = self.inputs['Ids'] + flatten_idx = ids.flatten() + padding_idx = np.random.choice(flatten_idx, 1)[0] + self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) + self.attrs = {'padding_idx': cpt.long_type(padding_idx)} + self.check_output() + + +class TestLookupTableWIsSelectedRows(unittest.TestCase): + def prepare_ids(self, scope, place): + ids_tensor = 
scope.var('Ids').get_tensor() + ids_array = np.array([0, 4, 3, 5]).astype("int32") + ids_tensor.set(ids_array, place) + return ids_array + + def prepare_w(self, scope, place): + rows = [0, 1, 2, 3, 4, 5, 6] + row_numel = 12 + + w_selected_rows = scope.var('W').get_selected_rows() + w_selected_rows.set_height(len(rows)) + w_selected_rows.set_rows(rows) + w_array = np.ones((len(rows), row_numel)).astype("float32") + for i in range(len(rows)): + w_array[i] *= i + w_tensor = w_selected_rows.get_tensor() + w_tensor.set(w_array, place) + + def create_out_tensor(self, scope, place): + return scope.var('Out').get_tensor() + + def check_result(self, ids_array, result_array): + # all(): return True if all elements of the iterable are true (or if the iterable is empty) + for idx, row in enumerate(ids_array): + assert (row == result_array[idx]).all() + + def check_with_place(self, place): + scope = core.Scope() + + ids_array = self.prepare_ids(scope, place) + + self.prepare_w(scope, place) + + out_tensor = self.create_out_tensor(scope, place) + + # create and run lookup_table operator + lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out') + lookup_table.run(scope, place) + + # get result from Out + result_array = np.array(out_tensor) + + self.check_result(ids_array, result_array) + + def test_w_is_selected_rows(self): + places = [core.CPUPlace()] + # currently only support CPU + for place in places: + self.check_with_place(place) + + +class TestLookupTableWithTensorIdsWIsSelectedRows( + TestLookupTableWIsSelectedRows): + def prepare_ids(self, scope, place): + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.random.randint( + low=0, high=6, size=(2, 4, 3)).astype("int64") + ids_tensor.set(ids_array, place) + return ids_array + + def check_result(self, ids_array, result_array): + for idx, row in np.ndenumerate(ids_array): + assert (row == result_array[idx]).all() + + +class TestLookupTableIsSparse(unittest.TestCase): + def init_data(self): + self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64") + self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32") + + def get_w_grad(self, is_sparse): + self.init_data() + main_program = fluid.Program() + with fluid.program_guard(main_program, fluid.Program()): + x = fluid.layers.data(name='x', shape=[5], dtype='int64') + y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32') + emb = fluid.input.embedding( + input=x, + size=[10, 16], + param_attr=fluid.ParamAttr( + name="emb_weight", + learning_rate=10, + initializer=fluid.initializer.NumpyArrayInitializer( + self.w_data)), + is_sparse=is_sparse) + y = fluid.layers.reduce_sum(emb, dim=-1) + + loss = fluid.layers.square_error_cost(input=y, label=y_) + loss = fluid.layers.mean(loss) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4) + sgd_optimizer.minimize(loss) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(feed={'x': self.x_data, + 'y_': self.y_data}, + fetch_list=['emb_weight'], + return_numpy=False) + return np.array(ret[0]) + + def test_w_grad(self): + self.w_data = np.random.random(size=(10, 16)).astype("float32") + w_grad = self.get_w_grad(False) + w_grad_with_sparse = self.get_w_grad(True) + self.check_grad(w_grad, w_grad_with_sparse) + + def check_grad(self, w_grad1, w_grad2, tolerance=1e-6): + np.testing.assert_allclose( + w_grad1, w_grad2, rtol=tolerance, atol=tolerance) + + +class TestLookupTableApi(unittest.TestCase): + def test_api(self): + x = fluid.layers.data(name='x', 
shape=[20], dtype='int64') + emb = fluid.embedding(input=x, size=[128, 64]) + + place = fluid.CPUPlace() + x_data = np.random.randint(0, 127, [2, 20]).astype("int64") + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(feed={'x': x_data, }, + fetch_list=[emb], + return_numpy=False) + + +class TestEmbedOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + input_data = np.random.randint(0, 10, (4, 6)).astype("int64") + + def test_Variable(): + # the input type must be Variable + fluid.embedding(input=input_data, size=(10, 64)) + + self.assertRaises(TypeError, test_Variable) + + def test_input_dtype(): + # the input dtype must be int64 + input = fluid.data(name='x1', shape=[4, 6], dtype='float32') + fluid.embedding(input=input, size=(10, 64)) + + self.assertRaises(TypeError, test_input_dtype) + + def test_param_dtype(): + # dtype must be float32 or float64 + input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64') + fluid.embedding(input=input2, size=(10, 64), dtype='int64') + + self.assertRaises(TypeError, test_param_dtype) + input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64') + fluid.embedding(input=input3, size=(10, 64), dtype='float16') + if __name__ == "__main__": paddle.enable_static() -- GitLab
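Note (not part of the patch): the new sparse-grad kernels are selected through the `is_sparse` attribute of `lookup_table_v2`, as wired up in embedding_sig.cc above. Below is a minimal dygraph sketch that should exercise both the dense and the sparse backward paths; it assumes the public `paddle.nn.Embedding` API, whose `sparse` flag is expected to map onto `is_sparse`, and is an illustration rather than part of the test suite added here.

    import numpy as np
    import paddle

    paddle.disable_static()
    ids = paddle.to_tensor(np.array([[1, 3, 0, 4, 7]], dtype="int64"))

    for use_sparse in (False, True):
        emb = paddle.nn.Embedding(10, 16, sparse=use_sparse)
        sgd = paddle.optimizer.SGD(learning_rate=1e-4,
                                   parameters=emb.parameters())
        loss = paddle.mean(emb(ids))
        # sparse=False should dispatch to embedding_grad,
        # sparse=True to the newly registered embedding_sparse_grad.
        loss.backward()
        sgd.step()
        sgd.clear_grad()

The re-enabled TestLookupTableIsSparse and TestLookupTableWIsSelectedRows cases in test_lookup_table_v2_op.py cover the static-graph is_sparse path and the SelectedRows-weight (`sparse_weight_embedding_*`) path, respectively.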