diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index ae41992835584d8106337aacae0a09ba20e72ce3..c6bafb64405cd65f60d66071dea31bc85061578c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -112,7 +112,8 @@ paddle.fluid.initializer.force_init_on_cpu (ArgSpec(args=[], varargs=None, keywo paddle.fluid.initializer.init_on_cpu (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'eaa04fd68661a3af59abd0e19b3b6eda')) paddle.fluid.initializer.NumpyArrayInitializer ('paddle.fluid.initializer.NumpyArrayInitializer', ('document', '064f134a27c16372967d450f499762ab')) paddle.fluid.initializer.NumpyArrayInitializer.__init__ (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.input.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'c79292312a35b99ff2801a274b666358')) +paddle.fluid.embedding (ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')), ('document', 'd4ac047e0d5e6b7b1c5ff6ef7d7cfff5')) +paddle.fluid.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'eef66730acc806088f9e8ba90252bda1')) paddle.fluid.layers.fc (ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, None)), ('document', '0dc8181f14a33f91fbae9385a9b3d9fd')) paddle.fluid.layers.center_loss (ArgSpec(args=['input', 'label', 'num_classes', 'alpha', 'param_attr', 'update_center'], varargs=None, keywords=None, defaults=(True,)), ('document', '7129819d94625c6104054e8187768589')) paddle.fluid.layers.embedding (ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')), ('document', 'd8e405486a1e4e189b51d6ee28d67b1e')) diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1b982356a80ff299b16550b7f7eb57122ced418 --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -0,0 +1,192 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/lookup_table_v2_op.h" + +#include + +#include "paddle/fluid/framework/no_need_buffer_vars_inference.h" +#include "paddle/fluid/framework/var_type_inference.h" + +namespace paddle { +namespace operators { + +class LookupTableV2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, + "Input(W) of LookupTableV2Op should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true, + "Input(Ids) of LookupTableV2Op should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of LookupTableV2Op should not be null."); + + auto table_dims = ctx->GetInputDim("W"); + auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); + VLOG(5) << "ids rank is " << ids_rank << std::endl; + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + + auto output_dims = framework::vectorize(ids_dims); + output_dims.push_back(table_dims[1]); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + + if (ctx->GetOutputsVarType("Out")[0] == + framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("Ids", /*->*/ "Out"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("W", + "(Tensor) The input represents embedding tensors, " + "which is a learnable parameter."); + AddInput("Ids", + "An input with type int32 or int64 " + "contains the ids to be looked up in W. " + "The last dimension size must be 1."); + AddOutput("Out", "The lookup results, which have the same type as W."); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); + AddAttr("is_distributed", + "(boolean, default false) distributed lookup table.") + .SetDefault(false); + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(kNoPadding); + + // for parameter prefetch + AddAttr("remote_prefetch", "").SetDefault(false); + AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); + AddAttr>( + "epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input variables for mapping") + .SetDefault({}); + AddAttr>( + "table_names", + "(string vector, the splited table names that will be fetched from " + "parameter server)" + "in the order of input variables for mapping") + .SetDefault({}); + + AddComment(R"DOC( +Lookup Table V2 Operator. + +This operator is used to perform lookups on the parameter W, +then concatenated into a dense tensor. + +The input Ids can carry the LoD (Level of Details) information, +or not. And the output only shares the LoD information with input Ids. + +)DOC"); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LookupTableV2GradOpNoBuffer, "W"); + +class LookupTableV2GradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + + op->SetType("lookup_table_v2_grad"); + + op->SetInput("W", Input("W")); + op->SetInput("Ids", Input("Ids")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + + op->SetOutput(framework::GradVarName("W"), InputGrad("W")); + + op->SetAttrMap(Attrs()); + return op; + } +}; + +class LookupTableV2OpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar( + ctx.InputVar(framework::GradVarName("Out"))); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class LookupTableV2OpGradVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const override { + auto out_var_name = ctx->Output(framework::GradVarName("W")).front(); + auto attr = ctx->GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); + } + ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lookup_table_v2, ops::LookupTableV2Op, + ops::LookupTableV2OpMaker, ops::LookupTableV2GradOpDescMaker); + +REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad, + ops::LookupTableV2GradOpNoBuffer, + ops::LookupTableV2OpGradVarTypeInference); + +REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel, + ops::LookupTableV2Kernel); +REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad, + ops::LookupTableV2GradKernel, + ops::LookupTableV2GradKernel); diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e7f580c5fdbbde74bd739a81d8a5abed80788fd2 --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op.cu @@ -0,0 +1,201 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/lookup_table_v2_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +__global__ void LookupTableV2(T *output, const T *table, const int64_t *ids, + const int64_t N, const int64_t K, const int64_t D, + const int64_t padding_idx) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * GridDimX; + + while (idy < K) { + int64_t id = ids[idy]; + PADDLE_ENFORCE( + id >= 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); + PADDLE_ENFORCE( + id < N, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); + T *out = output + idy * D; + const T *tab = table + id * D; + for (int i = idx; i < D; i += BlockDimX) { + if (PaddingFlag) { + if (id == padding_idx) + out[i] = static_cast(0); + else + out[i] = tab[i]; + } else { + out[i] = tab[i]; + } + } + idy += BlockDimY * GridDimX; + } +} + +template +__global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids, + const int64_t N, const int64_t K, + const int64_t D) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * GridDimX; + + while (idy < K) { + int64_t id = ids[idy]; + PADDLE_ENFORCE( + id >= 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); + PADDLE_ENFORCE( + id < N, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); + const T *out = output + idy * D; + T *tab = table + id * D; + for (int i = idx; i < D; i += BlockDimX) { + paddle::platform::CudaAtomicAdd(&tab[i], out[i]); + } + idy += BlockDimY * GridDimX; + } +} + +template +class LookupTableV2CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *table_t = context.Input("W"); + auto *ids_t = context.Input("Ids"); + auto *output_t = context.Output("Out"); + int64_t padding_idx = context.Attr("padding_idx"); + + auto id_name = context.Inputs("Ids").front(); + auto out_name = context.Outputs("Out").front(); + + size_t N = table_t->dims()[0]; + size_t D = table_t->dims()[1]; + size_t K = ids_t->numel(); + + auto *ids = ids_t->data(); + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + dim3 threads(128, 8); + dim3 grids(8, 1); + + if (padding_idx == -1) + LookupTableV2< + T, 128, 8, 8, + false><<>>( + output, table, ids, N, K, D, padding_idx); + else + LookupTableV2< + T, 128, 8, 8, + true><<>>( + output, table, ids, N, K, D, padding_idx); + } +}; + +template +class LookupTableV2GradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto &dev_ctx = + context.template device_context(); + bool is_sparse = context.Attr("is_sparse"); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + if (is_sparse) { + auto *ids = context.Input("Ids"); + auto *table = context.Input("W"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); + + auto *ids_data = ids->data(); + int64_t ids_num = ids->numel(); + + auto stream = dev_ctx.stream(); + // copy GPU memory to CPU pinned memory + framework::Vector new_rows; + new_rows.resize(ids_num); + auto gpu_place = boost::get(context.GetPlace()); + + // TODO(yuyang18): Strange code here. + memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()), + gpu_place, ids_data, ids_num * sizeof(int64_t), stream); + d_table->set_rows(new_rows); + + auto *d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table->dims()[1]}); + d_table_value->mutable_data(context.GetPlace()); + + auto *d_table_data = d_table_value->data(); + auto *d_output_data = d_output->data(); + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); + memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data, + d_output->numel() * sizeof(T), stream); + + } else { + auto ids_t = context.Input("Ids"); + auto d_output_t = context.Input(framework::GradVarName("Out")); + auto d_table_t = context.Output(framework::GradVarName("W")); + + int N = d_table_t->dims()[0]; + int D = d_table_t->dims()[1]; + int K = ids_t->numel(); + const int64_t *ids = ids_t->data(); + const T *d_output = d_output_t->data(); + T *d_table = d_table_t->mutable_data(context.GetPlace()); + + auto t = framework::EigenVector::Flatten(*d_table_t); + t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); + + dim3 threads(128, 8); + dim3 grids(8, 1); + LookupTableV2Grad<<>>( + d_table, d_output, ids, N, K, D); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel, + ops::LookupTableV2CUDAKernel, + ops::LookupTableV2CUDAKernel); +REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad, + ops::LookupTableV2GradCUDAKernel, + ops::LookupTableV2GradCUDAKernel, + ops::LookupTableV2GradCUDAKernel); diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..16f4d7c4171b0b160d32352deeb0fa0f460a3291 --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op.h @@ -0,0 +1,218 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" + +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/fluid/operators/distributed/parameter_prefetch.h" +#endif + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +constexpr int64_t kNoPadding = -1; + +template +class LookupTableV2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *ids_t = context.Input("Ids"); // int tensor + auto *output_t = context.Output("Out"); // float tensor + auto *table_var = context.InputVar("W"); + + auto id_name = context.Inputs("Ids").front(); + auto embedding_name = context.Inputs("W").front(); + auto out_name = context.Outputs("Out").front(); + + // for remote prefetch + auto epmap = context.Attr>("epmap"); + auto remote_prefetch = context.Attr("remote_prefetch"); + auto height_sections = + context.Attr>("height_sections"); + auto table_names = context.Attr>("table_names"); + + if (remote_prefetch && !epmap.empty()) { +// if epmap is not empty, then the parameter will be fetched from remote +// parameter server + +#ifdef PADDLE_WITH_DISTRIBUTE + operators::distributed::prefetch(id_name, out_name, embedding_name, false, + table_names, epmap, height_sections, + context, context.scope()); +#else + PADDLE_THROW( + "paddle is not compiled with distribute support, can not do " + "parameter prefetch!"); +#endif + } else { + int64_t padding_idx = context.Attr("padding_idx"); + int64_t *ids = const_cast(ids_t->data()); + int64_t ids_numel = ids_t->numel(); + + if (table_var->IsType()) { + auto *table_t = context.Input("W"); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_LT( + ids[i], row_number, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, ids[i]); + PADDLE_ENFORCE_GE( + ids[i], 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, ids[i]); + memcpy(output + i * row_width, table + ids[i] * row_width, + row_width * sizeof(T)); + } + } + } else if (table_var->IsType()) { + const auto &table_t = table_var->Get(); + int64_t row_width = table_t.value().dims()[1]; + const auto *table = table_t.value().data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_GE(ids[i], 0); + auto id_index = table_t.Index(ids[i]); + PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); + } + } + } + } + } +}; + +template +class LookupTableV2GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *table_var = context.InputVar("W"); + DDim table_dim; + if (table_var->IsType()) { + table_dim = context.Input("W")->dims(); + } else if (table_var->IsType()) { + auto *table_t = context.Input("W"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter W of a LookupTableV2 " + "must be either LoDTensor or SelectedRows"); + } + + int64_t padding_idx = context.Attr("padding_idx"); + bool is_sparse = context.Attr("is_sparse"); + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + if (is_sparse) { + auto *ids = context.Input("Ids"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); + + auto *ids_data = ids->data(); + int64_t ids_num = ids->numel(); + + std::vector new_rows; + new_rows.resize(ids_num); + std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); + d_table->set_rows(new_rows); + + auto *d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + + d_table_value->mutable_data(context.GetPlace()); + + d_table->set_height(table_dim[0]); + + auto *d_output_data = d_output->data(); + auto *d_table_data = d_table_value->data(); + + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + + } else { + auto *ids = context.Input("Ids"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); + + auto *ids_data = ids->data(); + + int64_t N = table_dim[0]; + int64_t D = table_dim[1]; + + auto *d_output_data = d_output->data(); + auto *d_table_data = d_table->mutable_data(context.GetPlace()); + + memset(d_table_data, 0, d_table->numel() * sizeof(T)); + + for (int64_t i = 0; i < ids->numel(); ++i) { + if (padding_idx != kNoPadding && ids_data[i] == padding_idx) { + // the gradient of padding_idx should be 0, already done by memset, so + // do nothing. + } else { + PADDLE_ENFORCE_LT( + ids_data[i], N, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, ids_data[i]); + PADDLE_ENFORCE_GE( + ids_data[i], 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, ids_data[i]); + for (int j = 0; j < D; ++j) { + d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 6f266c576f5045a2c1ef10cd17ff00e02f204b89..180fae663161a6b540ed7e91e14f8a05953bdec5 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -62,7 +62,7 @@ from . import average from . import metrics from . import transpiler from . import incubate -from . import input +from .input import embedding, one_hot from . import distribute_lookup_table from .param_attr import ParamAttr, WeightNormParamAttr from .data_feeder import DataFeeder @@ -93,7 +93,8 @@ __all__ = framework.__all__ + executor.__all__ + \ data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [ 'io', 'initializer', - 'input', + 'embedding', + 'one_hot', 'layers', 'contrib', 'dygraph', diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py index 4169f646c0d9f0cdc132dbd31791b58549893fed..8afbd662ad7b87157872c61d2cafad1c84aa3b77 100644 --- a/python/paddle/fluid/input.py +++ b/python/paddle/fluid/input.py @@ -16,7 +16,7 @@ from __future__ import print_function from .framework import Variable, in_dygraph_mode from .layer_helper import LayerHelper -__all__ = ['one_hot'] +__all__ = ['one_hot', 'embedding'] def one_hot(input, depth, allow_out_of_range=False): @@ -40,7 +40,7 @@ def one_hot(input, depth, allow_out_of_range=False): import paddle.fluid as fluid label = fluid.layers.data(name="label", shape=[1], dtype="int64") - one_hot_label = fluid.input.one_hot(input=label, depth=10) + one_hot_label = fluid.one_hot(input=label, depth=10) """ helper = LayerHelper("one_hot_v2", **locals()) @@ -65,3 +65,73 @@ def one_hot(input, depth, allow_out_of_range=False): outputs={'Out': one_hot_out}, stop_gradient=True) return one_hot_out + + +def embedding(input, + size, + is_sparse=False, + is_distributed=False, + padding_idx=None, + param_attr=None, + dtype='float32'): + """ + **Embedding Layer** + + This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in + a lookup table. The result of this lookup is the embedding of each ID in the + :attr:`input`. + + All the input variables are passed in as local variables to the LayerHelper + constructor. + + Args: + input(Variable): Input is a Tensor Variable, which contains the IDs information. + The value of the input IDs should satisfy :math:`0<= id < size[0]`. + size(tuple|list): The shape of the look up table parameter. It should + have two elements which indicate the size of the dictionary of + embeddings and the size of each embedding vector respectively. + is_sparse(bool): The flag indicating whether to use sparse update. + is_distributed(bool): Whether to run lookup table from remote parameter server. + padding_idx(int|long|None): It will output all-zero padding data whenever + lookup encounters :math:`padding\_idx` in Ids. If set :attr:`None`, it makes + no effect to output. If :math:`padding\_idx < 0`, the :math:`padding\_idx` + will automatically be converted to :math:`size[0] + padding\_idx` to use. + Default: None. + param_attr(ParamAttr): Parameters for this layer. + dtype(np.dtype|core.VarDesc.VarType|str): The dtype refers to the data type of output + tensor. It can be float32, float_16, int etc. + + Returns: + Variable: The tensor variable storing the embeddings of the \ + supplied inputs. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + # [batch_size, 20] -> [batch_size, 20, 64] + data = fluid.layers.data(name='sequence', shape=[20], dtype='int64') + emb = fluid.embedding(input=data, size=[128, 64]) + """ + + helper = LayerHelper('embedding', **locals()) + remote_prefetch = is_sparse and (not is_distributed) + if remote_prefetch: + assert is_sparse is True and is_distributed is False + w = helper.create_parameter( + attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False) + tmp = helper.create_variable_for_type_inference(dtype) + padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( + size[0] + padding_idx) + helper.append_op( + type='lookup_table_v2', + inputs={'Ids': input, + 'W': w}, + outputs={'Out': tmp}, + attrs={ + 'is_sparse': is_sparse, + 'is_distributed': is_distributed, + 'remote_prefetch': remote_prefetch, + 'padding_idx': padding_idx + }) + return tmp diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py new file mode 100644 index 0000000000000000000000000000000000000000..46a219bbb2fd9ad131793a2b52768b975e1debdb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -0,0 +1,216 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid.op import Operator +import paddle.compat as cpt + + +class TestLookupTableOp(OpTest): + def setUp(self): + self.op_type = "lookup_table_v2" + table = np.random.random((17, 31)).astype("float32") + ids = np.random.randint(0, 17, 4).astype("int64") + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) + + +class TestLookupTableOpWithTensorIds(OpTest): + def setUp(self): + self.op_type = "lookup_table_v2" + table = np.random.random((17, 31)).astype("float32") + ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype("int64") + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) + + +class TestLookupTableOpWithPadding(TestLookupTableOp): + def test_check_output(self): + ids = np.squeeze(self.inputs['Ids']) + padding_idx = np.random.choice(ids, 1)[0] + self.outputs['Out'][ids == padding_idx] = np.zeros(31) + self.attrs = {'padding_idx': int(padding_idx)} + self.check_output() + + def test_check_grad(self): + # Since paddings are not trainable and fixed in forward, the gradient of + # paddings makes no sense and we don't test the gradient here. + pass + + +class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): + def test_check_output(self): + ids = self.inputs['Ids'] + flatten_idx = ids.flatten() + padding_idx = np.random.choice(flatten_idx, 1)[0] + self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) + self.attrs = {'padding_idx': cpt.long_type(padding_idx)} + self.check_output() + + def test_check_grad(self): + # Since paddings are not trainable and fixed in forward, the gradient of + # paddings makes no sense and we don't test the gradient here. + pass + + +class TestLookupTableWIsSelectedRows(OpTest): + def prepare_ids(self, scope, place): + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.array([0, 4, 3, 5]).astype("int64") + ids_tensor.set(ids_array, place) + return ids_array + + def prepare_w(self, scope, place): + rows = [0, 1, 2, 3, 4, 5, 6] + row_numel = 12 + + w_selected_rows = scope.var('W').get_selected_rows() + w_selected_rows.set_height(len(rows)) + w_selected_rows.set_rows(rows) + w_array = np.ones((len(rows), row_numel)).astype("float32") + for i in range(len(rows)): + w_array[i] *= i + w_tensor = w_selected_rows.get_tensor() + w_tensor.set(w_array, place) + + def create_out_tensor(self, scope, place): + return scope.var('Out').get_tensor() + + def check_result(self, ids_array, result_array): + # all(): return True if all elements of the iterable are true (or if the iterable is empty) + for idx, row in enumerate(ids_array): + assert (row == result_array[idx]).all() + + def check_with_place(self, place): + scope = core.Scope() + + ids_array = self.prepare_ids(scope, place) + + self.prepare_w(scope, place) + + out_tensor = self.create_out_tensor(scope, place) + + # create and run lookup_table operator + lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out') + lookup_table.run(scope, place) + + # get result from Out + result_array = np.array(out_tensor) + + self.check_result(ids_array, result_array) + + def test_w_is_selected_rows(self): + places = [core.CPUPlace()] + # currently only support CPU + for place in places: + self.check_with_place(place) + + +class TestLookupTableWithTensorIdsWIsSelectedRows( + TestLookupTableWIsSelectedRows): + def prepare_ids(self, scope, place): + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.random.randint( + low=0, high=6, size=(2, 4, 3)).astype("int64") + ids_tensor.set(ids_array, place) + return ids_array + + def check_result(self, ids_array, result_array): + for idx, row in np.ndenumerate(ids_array): + assert (row == result_array[idx]).all() + + +class TestLookupTableIsSparse(unittest.TestCase): + def init_data(self): + self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64") + self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32") + + def get_w_grad(self, is_sparse): + self.init_data() + main_program = fluid.Program() + with fluid.program_guard(main_program, fluid.Program()): + x = fluid.layers.data(name='x', shape=[5], dtype='int64') + y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32') + emb = fluid.input.embedding( + input=x, + size=[10, 16], + param_attr=fluid.ParamAttr( + name="emb_weight", + learning_rate=10, + initializer=fluid.initializer.NumpyArrayInitializer( + self.w_data)), + is_sparse=is_sparse) + y = fluid.layers.reduce_sum(emb, dim=-1) + + loss = fluid.layers.square_error_cost(input=y, label=y_) + loss = fluid.layers.mean(loss) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4) + sgd_optimizer.minimize(loss) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(feed={'x': self.x_data, + 'y_': self.y_data}, + fetch_list=['emb_weight'], + return_numpy=False) + return np.array(ret[0]) + + def test_w_grad(self): + self.w_data = np.random.random(size=(10, 16)).astype("float32") + w_grad = self.get_w_grad(False) + w_grad_with_sparse = self.get_w_grad(True) + self.check_grad(w_grad, w_grad_with_sparse) + + def check_grad(self, w_grad1, w_grad2, tolerance=1e-6): + np.testing.assert_allclose( + w_grad1, w_grad2, rtol=tolerance, atol=tolerance) + + +class TestLookupTableApi(unittest.TestCase): + def test_api(self): + x = fluid.layers.data(name='x', shape=[20], dtype='int64') + emb = fluid.embedding(input=x, size=[128, 64]) + + place = fluid.CPUPlace() + x_data = np.random.randint(0, 127, [2, 20]).astype("int64") + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(feed={'x': x_data, }, + fetch_list=[emb], + return_numpy=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py b/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py index 85069b0203984ed41fe92c651294922642adcc4a..dc948c42bc6d6a568f99e8c709514e4196c5a81c 100644 --- a/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py @@ -186,12 +186,12 @@ class TestOneHotOpApi(unittest.TestCase): label = np.array([np.random.randint(0, depth - 1) for i in range(6)]).reshape([6, 1]) with fluid.dygraph.guard(): - one_hot_label = fluid.input.one_hot( + one_hot_label = fluid.one_hot( input=fluid.dygraph.to_variable(label), depth=depth) def _run(self, depth): label = fluid.layers.data(name="label", shape=[1], dtype="int64") - one_hot_label = fluid.input.one_hot(input=label, depth=depth) + one_hot_label = fluid.one_hot(input=label, depth=depth) place = fluid.CPUPlace() label_data = np.array([np.random.randint(0, 10 - 1)