From aacd16dbb427275913ed01e0c6de7cf0475f1fe3 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 28 Oct 2019 23:11:23 +0800 Subject: [PATCH] add pyramid_hash_op (#20698) --- paddle/fluid/operators/CMakeLists.txt | 12 +- paddle/fluid/operators/math/bloomfilter.h | 197 ++++++++ paddle/fluid/operators/pyramid_hash_op.cc | 445 ++++++++++++++++++ paddle/fluid/operators/search_compute.h | 17 +- python/paddle/fluid/contrib/layers/nn.py | 97 ++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + .../tests/unittests/test_pyramid_hash_op.py | 61 +++ 7 files changed, 827 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/math/bloomfilter.h create mode 100644 paddle/fluid/operators/pyramid_hash_op.cc create mode 100644 python/paddle/fluid/tests/unittests/test_pyramid_hash_op.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 07931de4ffd..652845e1aa6 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -48,14 +48,17 @@ if (WITH_DISTRIBUTE) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) endif() -SET(OP_ONLY_MKL "") +SET(OP_MKL_DEPS "") if (NOT WITH_MKL OR NOT WITH_AVX) - SET(OP_ONLY_MKL ${OP_ONLY_MKL} match_matrix_tensor_op) - SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op) + SET(OP_MKL_DEPS ${OP_MKL_DEPS} match_matrix_tensor_op) + SET(OP_MKL_DEPS ${OP_MKL_DEPS} var_conv_2d_op) +endif() +if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32) + SET(OP_MKL_DEPS ${OP_MKL_DEPS} pyramid_hash_op) endif() register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op - sync_batch_norm_op multihead_matmul_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) + sync_batch_norm_op multihead_matmul_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) if (WITH_GPU) # warpctc_op needs cudnn 7 above @@ -87,6 +90,7 @@ if (WITH_DGC) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc) endif() + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) diff --git a/paddle/fluid/operators/math/bloomfilter.h b/paddle/fluid/operators/math/bloomfilter.h new file mode 100644 index 00000000000..6b36251aa7f --- /dev/null +++ b/paddle/fluid/operators/math/bloomfilter.h @@ -0,0 +1,197 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#define BLOOMFILTER_MAGIC_NUM_NEW 17070416 + +#include +#include + +#include +#include + +#include + +namespace paddle { +namespace operators { +namespace math { + +#pragma pack(4) +struct bloomfilter { + uint64_t magic_num; + uint64_t m; + uint64_t k; + uint64_t count; + unsigned char bit_vector[1]; +}; +int bloomfilter_get(const struct bloomfilter *bloomfilter, const void *key, + size_t len); +int bloomfilter_check(struct bloomfilter *filter); + +#define bit_get(v, n) ((v)[(n) >> 3] & (0x1 << (0x7 - ((n)&0x7)))) +#define ROTL64(x, r) (((x) << (r)) | ((x) >> (64 - (r)))) +#define BIG_CONSTANT(x) (x##LLU) + +uint64_t fmix64(uint64_t k) { + k ^= k >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + return k; +} + +void murmurhash3_x64_128(const void *key, const int len, const uint32_t seed, + void *out) { + const uint8_t *data = (const uint8_t *)key; + const int nblocks = len / 16; + + uint64_t h1 = seed; + uint64_t h2 = seed; + int i = 0; + + const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); + const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + const uint64_t *blocks = (const uint64_t *)(data); + + uint64_t k1; + uint64_t k2; + + for (i = 0; i < nblocks; i++) { + k1 = blocks[i * 2 + 0]; + k2 = blocks[i * 2 + 1]; + + k1 *= c1; + k1 = ROTL64(k1, 31); + k1 *= c2; + h1 ^= k1; + + h1 = ROTL64(h1, 27); + h1 += h2; + h1 = h1 * 5 + 0x52dce729; + + k2 *= c2; + k2 = ROTL64(k2, 33); + k2 *= c1; + h2 ^= k2; + + h2 = ROTL64(h2, 31); + h2 += h1; + h2 = h2 * 5 + 0x38495ab5; + } + + //---------- + // tail + + const uint8_t *tail = (const uint8_t *)(data + nblocks * 16); + uint64_t nk1 = 0; + uint64_t nk2 = 0; + // no break here!!! + switch (len & 15) { + case 15: + nk2 ^= ((uint64_t)tail[14]) << 48; + case 14: + nk2 ^= ((uint64_t)tail[13]) << 40; + case 13: + nk2 ^= ((uint64_t)tail[12]) << 32; + case 12: + nk2 ^= ((uint64_t)tail[11]) << 24; + case 11: + nk2 ^= ((uint64_t)tail[10]) << 16; + case 10: + nk2 ^= ((uint64_t)tail[9]) << 8; + case 9: + nk2 ^= ((uint64_t)tail[8]) << 0; + nk2 *= c2; + nk2 = ROTL64(nk2, 33); + nk2 *= c1; + h2 ^= nk2; + case 8: + nk1 ^= ((uint64_t)tail[7]) << 56; + case 7: + nk1 ^= ((uint64_t)tail[6]) << 48; + case 6: + nk1 ^= ((uint64_t)tail[5]) << 40; + case 5: + nk1 ^= ((uint64_t)tail[4]) << 32; + case 4: + nk1 ^= ((uint64_t)tail[3]) << 24; + case 3: + nk1 ^= ((uint64_t)tail[2]) << 16; + case 2: + nk1 ^= ((uint64_t)tail[1]) << 8; + case 1: + nk1 ^= ((uint64_t)tail[0]) << 0; + nk1 *= c1; + nk1 = ROTL64(nk1, 31); + nk1 *= c2; + h1 ^= nk1; + } + + //---------- + // finalization + + h1 ^= len; + h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + h2 += h1; + + // ((uint64_t *)out)[0] = h1; + reinterpret_cast(out)[0] = h1; + // ((uint64_t *)out)[1] = h2; + reinterpret_cast(out)[1] = h2; +} + +int bloomfilter_check(struct bloomfilter *filter) { + if (filter->magic_num == BLOOMFILTER_MAGIC_NUM_NEW) { + return 1; + } else { + fprintf(stderr, "error magic_num %ld\n", filter->magic_num); + return 0; + } +} + +int bloomfilter_get(const struct bloomfilter *bloomfilter, const void *key, + size_t len) { + uint32_t i; + uint64_t result[2]; + + for (i = 0; i < bloomfilter->k; i++) { + murmurhash3_x64_128(key, len, i, &result); + result[0] %= bloomfilter->m; + result[1] %= bloomfilter->m; + if (!bit_get(bloomfilter->bit_vector, result[0])) { + return 0; + } + if (!bit_get(bloomfilter->bit_vector, result[1])) { + return 0; + } + } + return 1; +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc new file mode 100644 index 00000000000..e0b63d2ead7 --- /dev/null +++ b/paddle/fluid/operators/pyramid_hash_op.cc @@ -0,0 +1,445 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/search_compute.h" + +extern "C" { +#include "math/bloomfilter.h" +// void* memcpy1(void* dst, void* src, uint32_t length); +} + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; + +class PyramidHashOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "X (Tensor, MUST be Tensor) Input variable which " + "should contain lod information."); + AddInput("W", "W (Tensor)"); + AddInput("WhiteList", "WhiteList (Tensor)"); + AddInput("BlackList", "BlackList (Tensor)"); + AddAttr("num_emb", "num_emb").SetDefault(0).EqualGreaterThan(0); + AddAttr("space_len", "space_len").SetDefault(0).EqualGreaterThan(0); + AddAttr("pyramid_layer", "pyramid_layer (must be >= 2)") + .SetDefault(2) + .EqualGreaterThan(2); + AddAttr("rand_len", "rand_len").SetDefault(0).EqualGreaterThan(0); + AddAttr("drop_out_percent", "drop_out_percent") + .SetDefault(0) + .EqualGreaterThan(0); + AddAttr("is_training", "is_training") + .SetDefault(0) + .EqualGreaterThan(0); + AddAttr("use_filter", "use_filter").SetDefault(true); + AddAttr("white_list_len", "white_list_len") + .SetDefault(0) + .EqualGreaterThan(0); + AddAttr("black_list_len", "black_list_len") + .SetDefault(0) + .EqualGreaterThan(0); + AddAttr("seed", "seed").SetDefault(0).EqualGreaterThan(0); + AddAttr("lr", "learning rate").SetDefault(0.0).EqualGreaterThan(0.0); + + AddOutput("Out", "Out (Tensor, default Tensor) Output variable"); + AddOutput("DropPos", "Out (Tensor, Tensor) Output variable"); + AddOutput("X_Temp_Out", "Out (Tensor, Tensor) Output variable") + .AsIntermediate(); + + AddComment(R"DOC( + PyramidHash + + NOTE: only support 'float32' data type now. + + )DOC"); + } +}; + +class PyramidHashOP : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "X(Input) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "W(Input) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Out(Output) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("DropPos"), true, + "DropPos(TMP Output) should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of X(Input) should be 2."); + + auto w_dims = ctx->GetInputDim("W"); + PADDLE_ENFORCE_EQ(w_dims.size(), 2, "W should be 2-D tensor"); + + int space_len = ctx->Attrs().Get("space_len"); + int rand_len = ctx->Attrs().Get("rand_len"); + + PADDLE_ENFORCE_EQ(w_dims[0], space_len + rand_len, + "w_dims[0] should be equal to (space_len + rand_len)"); + PADDLE_ENFORCE_EQ(w_dims[1], 1, "w_dims[1] should be equal to 1"); + + int num_emb = ctx->Attrs().Get("num_emb"); + PADDLE_ENFORCE_EQ(num_emb % rand_len, 0, + "random length should mod embedding size"); + + int white_list_len = ctx->Attrs().Get("white_list_len"); + if (white_list_len > 0) { + PADDLE_ENFORCE_EQ( + ctx->HasInput("WhiteList"), true, + "WhiteList(Input) should not be null when white_list_len > 0"); + auto wl_dims = ctx->GetInputDim("WhiteList"); + PADDLE_ENFORCE_EQ(wl_dims.size(), 2, "WhiteList should be 2-D tensor"); + PADDLE_ENFORCE_EQ(wl_dims[0], white_list_len, + "wl_dims[0] should be equal to white_list_len"); + PADDLE_ENFORCE_EQ(wl_dims[1], 1, "wl_dims[1] should be equal to 1"); + } + + int black_list_len = ctx->Attrs().Get("black_list_len"); + if (black_list_len > 0) { + PADDLE_ENFORCE_EQ( + ctx->HasInput("BlackList"), true, + "BlackList(Input) should not be null when black_list_len > 0"); + auto bl_dims = ctx->GetInputDim("BlackList"); + PADDLE_ENFORCE_EQ(bl_dims.size(), 2, "BlackList should be 2-D tensor"); + PADDLE_ENFORCE_EQ(bl_dims[0], black_list_len, + "bl_dims[0] should be equal to black_list_len"); + PADDLE_ENFORCE_EQ(bl_dims[1], 1, "bl_dims[1] should be equal to 1"); + } + + if (ctx->IsRuntime()) { + // something to do in runtime. + } else { + // compile time + ctx->SetOutputDim("Out", framework::make_ddim({-1, num_emb})); + ctx->SetOutputDim("X_Temp_Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "W"), ctx.GetPlace()); + } +}; + +template +class CPUPyramidHashOPKernel : public framework::OpKernel { + public: + bool should_use_term(math::bloomfilter* _filter, + math::bloomfilter* _black_filter, const T* word_repr, + int len) const { + return (!_filter || + 1 == math::bloomfilter_get(_filter, word_repr, len * sizeof(T))) && + (!_black_filter || + 0 == math::bloomfilter_get(_black_filter, word_repr, + len * sizeof(T))); + } + + void hash_embedding_ff(const T* hash_id, int len, T* top_pos, + const T* weights, int _num_emb, int _rand_len, + int _space_len) const { + for (unsigned int j = 0; j != _num_emb; j += _rand_len) { + unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len; + memcpy(top_pos + j, const_cast(weights + pos), + _rand_len * sizeof(T)); + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto* bottom = ctx.Input("X"); + auto* _blobs_0 = ctx.Input("W"); + auto* _blobs_1 = ctx.Input("WhiteList"); + auto* _blobs_2 = ctx.Input("BlackList"); + auto* top = ctx.Output("Out"); + auto* drop_pos = ctx.Output("DropPos"); + + int _num_emb = ctx.Attr("num_emb"); + bool use_filter = ctx.Attr("use_filter"); + int white_list_len = ctx.Attr("white_list_len"); + int black_list_len = ctx.Attr("black_list_len"); + int _pyramid_layer = ctx.Attr("pyramid_layer"); + int _is_training = ctx.Attr("is_training"); + int seed = ctx.Attr("seed"); + unsigned int _seed = (unsigned int)seed; + int _rand_len = ctx.Attr("rand_len"); + int _space_len = ctx.Attr("space_len"); + float _drop_out_percent = ctx.Attr("drop_out_percent"); + + const auto& offset = bottom->lod()[0]; + const auto* bottom_data_ori = bottom->data(); + auto* buff = ctx.Output("X_Temp_Out"); + buff->Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]})); + T* bottom_data = buff->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < bottom->dims()[0]; i++) { + bottom_data[i] = bottom_data_ori[i]; + } + + const auto* weights = _blobs_0->data(); + + std::vector top_offset; + top_offset.resize(offset.size()); + top_offset[0] = 0; + + math::bloomfilter* _filter = NULL; + math::bloomfilter* _black_filter = NULL; + if (use_filter) { + if (white_list_len != 0) { + _filter = (math::bloomfilter*)_blobs_1->data(); + PADDLE_ENFORCE_EQ(math::bloomfilter_check(_filter), 1, + "white filter not load"); + } + if (black_list_len != 0) { + _black_filter = (math::bloomfilter*)_blobs_2->data(); + PADDLE_ENFORCE_EQ(math::bloomfilter_check(_black_filter), 1, + "black filter not load"); + } + } + + drop_pos->Resize(framework::make_ddim( + {bottom->dims()[0] * bottom->dims()[1] * _pyramid_layer, 1})); + std::vector drop_pos_offset; + drop_pos_offset.resize(offset.size()); + drop_pos_offset[0] = 0; + int* iter = drop_pos->mutable_data(ctx.GetPlace()); + int* iter_end = iter; + + for (int i = 0; i < top_offset.size() - 1; ++i) { + int w = offset[i + 1] - offset[i]; + int nsentense_with_pyramid = 0; + if (w < 2) { + nsentense_with_pyramid = 0; + } else { + for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { + for (int l = 0; l < w - ilayer; ++l) { + if (should_use_term(_filter, _black_filter, + (const T*)(bottom_data + offset[i] + l), + ilayer + 1)) { + if (_is_training != 0) { + unsigned int rand_val = rand_r(&_seed); + T rate = static_cast(rand_val) / (RAND_MAX); + *(iter_end++) = (rate < _drop_out_percent ? 0 : 1); + } else { + *(iter_end++) = 1; + } + } else { + *(iter_end++) = 0; + } + } + } + nsentense_with_pyramid = std::count(iter, iter_end, 1); + iter = iter_end; + } + drop_pos_offset[i + 1] = drop_pos_offset[i] + nsentense_with_pyramid; + top_offset[i + 1] = + top_offset[i] + + (nsentense_with_pyramid == 0 ? 1 : nsentense_with_pyramid); + } + + int top_l = top_offset[top_offset.size() - 1]; + + framework::LoD top_lod; + top_lod.push_back(top_offset); + top->set_lod(top_lod); + top->Resize(framework::make_ddim({top_l, _num_emb})); + auto* top_data = top->mutable_data(ctx.GetPlace()); + + framework::LoD drop_pos_lod; + drop_pos_lod.push_back(drop_pos_offset); + drop_pos->set_lod(drop_pos_lod); + + iter = drop_pos->mutable_data(ctx.GetPlace()); + int top_counter = 0; + for (int i = 0; i < offset.size() - 1; ++i) { + int w_drop = drop_pos_offset[i + 1] - drop_pos_offset[i]; + int w = offset[i + 1] - offset[i]; + if (w_drop == 0) { + if (w >= 2) { + for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; + ++ilayer) { + for (int l = 0; l < w - ilayer; ++l) { + iter++; + } + } + } + auto* top_pos = top_data + top_counter++ * _num_emb; + memset(top_pos, 0, _num_emb * sizeof(T)); + continue; + } + if (w >= 2) { + for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { + for (int l = 0; l < w - ilayer; ++l) { + if (*(iter++) == 0) { + // do nothing + } else { + auto* top_pos = top_data + top_counter++ * _num_emb; + hash_embedding_ff((const T*)(bottom_data + offset[i] + l), + ilayer + 1, top_pos, weights, _num_emb, + _rand_len, _space_len); + } + } + } + } + } + if (iter != iter_end) { + exit(1); + } + if (_is_training == 0) { + avx_axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1], + _drop_out_percent); + } + } +}; + +class PyramidHashOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "Input(W) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), true, + "Input(DropPos) should not be null."); + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Out")), true, + "Input(Out@GRAD) of PyramidHashGradOp should not be null."); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "W"), ctx.GetPlace()); + } +}; + +class PyramidHashGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType("pyramid_hash_grad"); + op_desc_ptr->SetInput("X", Input("X")); + op_desc_ptr->SetInput("W", Input("W")); + op_desc_ptr->SetInput("DropPos", Output("DropPos")); + + op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op_desc_ptr->SetAttrMap(Attrs()); + return std::unique_ptr(op_desc_ptr); + } +}; + +template +class CPUPyramidHashOPGradKernel : public framework::OpKernel { + public: + void hash_embedding_bp(const T* hash_id, int len, const T* top_pos, + T* weights, T mlr, int _num_emb, int _rand_len, + int _space_len) const { + for (unsigned int j = 0; j != _num_emb; j += _rand_len) { + unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len; + avx_axpy(top_pos + j, weights + pos, _rand_len, mlr); + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto* bottom = ctx.Input("X"); + auto* _blobs = ctx.Input("W"); + auto* drop_pos = ctx.Input("DropPos"); + auto* top = ctx.Input(framework::GradVarName("Out")); + + int _num_emb = ctx.Attr("num_emb"); + float _lr = ctx.Attr("lr"); + int _rand_len = ctx.Attr("rand_len"); + int _space_len = ctx.Attr("space_len"); + int _pyramid_layer = ctx.Attr("pyramid_layer"); + + const auto* bottom_data_ori = bottom->data(); + Tensor buff; + buff.Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]})); + T* bottom_data = buff.mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < bottom->dims()[0]; i++) { + bottom_data[i] = bottom_data_ori[i]; + } + + int _slot_len = bottom->dims()[0]; + if (_slot_len == bottom->lod()[0].size() - 1 && + std::count(bottom_data, bottom_data + _slot_len, -1) == _slot_len) { + return; + } + + auto& offset = bottom->lod()[0]; + auto& drop_pos_offset = drop_pos->lod()[0]; + + const auto* top_diff = top->data(); + T* weights = const_cast(_blobs->data()); + T mlr = -1.0 * _lr; + + const int* iter = drop_pos->data(); + int top_counter = 0; + for (int i = 0; i < offset.size() - 1; ++i) { + int w = offset[i + 1] - offset[i]; + int w_drop = drop_pos_offset[i + 1] - drop_pos_offset[i]; + if (w_drop == 0) { + top_counter++; + } + if (w > 1) { + for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { + for (int l = 0; l < w - ilayer; ++l) { + if (*(iter++) == 0) { + // do nothing + } else { + const T* top_pos = top_diff + top_counter++ * _num_emb; + hash_embedding_bp((const T*)(bottom_data + offset[i] + l), + ilayer + 1, top_pos, weights, mlr, _num_emb, + _rand_len, _space_len); + } + } + } + } else { + // do nothing + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plt = paddle::platform; +namespace frm = paddle::framework; +REGISTER_OPERATOR(pyramid_hash, ops::PyramidHashOP, ops::PyramidHashOpMaker, + ops::PyramidHashGradOpMaker); +REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad); + +REGISTER_OP_CPU_KERNEL( + pyramid_hash, ops::CPUPyramidHashOPKernel); +REGISTER_OP_CPU_KERNEL( + pyramid_hash_grad, + ops::CPUPyramidHashOPGradKernel); diff --git a/paddle/fluid/operators/search_compute.h b/paddle/fluid/operators/search_compute.h index 995c85d9ff4..a2bf3a6bed0 100644 --- a/paddle/fluid/operators/search_compute.h +++ b/paddle/fluid/operators/search_compute.h @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/dynload/mklml.h" namespace paddle { namespace operators { @@ -103,5 +102,21 @@ inline void avx_axpy(const T* x, T* y, size_t len, const T alpha) { } } +template +inline void avx_axpy_noadd(const T* x, T* y, size_t len, const T alpha) { + unsigned int jjj, lll; + jjj = lll = 0; + + lll = len & ~AVX_CUT_LEN_MASK; + __m256x mm_alpha = _mm256_broadcast_sx(&alpha); + for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) { + _mm256_store_px(y + jjj, _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))); + } + + for (; jjj < len; jjj++) { + y[jjj] = alpha * x[jjj]; + } +} + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 578ef1185e4..86e22444cb8 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -32,6 +32,7 @@ __all__ = [ 'tree_conv', 'fused_embedding_seq_pool', 'multiclass_nms2', + 'search_pyramid_hash', ] @@ -625,3 +626,99 @@ def multiclass_nms2(bboxes, if return_index: return output, index return output + + +def search_pyramid_hash(input, + num_emb, + space_len, + pyramid_layer, + rand_len, + drop_out_percent, + is_training, + use_filter, + white_list_len, + black_list_len, + seed, + lr, + param_attr=None, + param_attr_wl=None, + param_attr_bl=None, + name=None, + dtype='float32'): + """ + **Pyramid hash embedding** + + Args: + input (Variable): LoDTensor Variable contained the IDs' information. + num_emb (int): The embedding size of output. + space_len (int): The length of pyramid hash embedding space. + pyramid_layer (int): The number of pyramid layers. It should be greater than 2. + rand_len (int): The minimum length of pyramid hash cell. + drop_out_percent (float): The probability of dropping out the input token randomly. + It should satisfy: [0., 1.] + is_training (bool): Whether in training or testing phrase. + use_filter(bool): If set True, the white filter and black filter should be given by + :attr:`param_attr_wl` and :attr:`param_attr_bl` . + white_list_len(int): If set :math:`white_list_len>0` , white filter with shape [white_list_len, 1] + should be provided by param_attr_wl. + black_list_len(int): If set :math:`black_list_len>0` , black filter with shape [black_list_len, 1] + should be provided by param_attr_bl. + seed(int): The number of random seed. + lr(float): The learning rate of weight created by :attr:`param_attr` with shape [space_len+rand_len, 1] + in this layer. + param_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the + default weight parameter property is used. See usage for details in :ref:`api_fluid_ParamAttr` . + param_attr_wl(ParamAttr): Specified parameters of white filter. + param_attr_bl(ParamAttr): Specified parameters of black filter. + name(str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + dtype(str): The data type of output variable, float32. + Returns: + Variable: LoDTensor of pyramid hash embedding. + """ + helper = LayerHelper('search_pyramid_hash', **locals()) + + w_shape = [space_len + rand_len, 1] + w = helper.create_parameter( + attr=param_attr, shape=w_shape, dtype=dtype, is_bias=False) + w.stop_gradient = True + + input_vars = {'X': input, 'W': w} + if white_list_len > 0: + wl_shape = [white_list_len, 1] + white_list = helper.create_parameter( + attr=param_attr_wl, shape=wl_shape, dtype=dtype, is_bias=False) + white_list.stop_gradient = True + input_vars['WhiteList'] = white_list + + if black_list_len >= 0: + bl_shape = [black_list_len, 1] + black_list = helper.create_parameter( + attr=param_attr_bl, shape=bl_shape, dtype=dtype, is_bias=False) + black_list.stop_gradient = True + input_vars['BlackList'] = black_list + + res = helper.create_variable_for_type_inference(dtype) + drop_pos = helper.create_variable_for_type_inference(dtype) + x_temp_out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='pyramid_hash', + inputs=input_vars, + outputs={"Out": res, + "X_Temp_Out": x_temp_out, + 'DropPos': drop_pos}, + attrs={ + 'num_emb': num_emb, + 'space_len': space_len, + 'pyramid_layer': pyramid_layer, + 'rand_len': rand_len, + 'drop_out_percent': drop_out_percent, + 'is_training': is_training, + 'use_filter': use_filter, + 'white_list_len': white_list_len, + 'black_list_len': black_list_len, + 'seed': seed, + 'lr': lr, + }) + + return res diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f77f018f65b..10735ab030a 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -74,6 +74,9 @@ if(NOT WITH_MKL OR NOT WITH_AVX) list(REMOVE_ITEM TEST_OPS test_match_matrix_tensor_op) list(REMOVE_ITEM TEST_OPS test_var_conv_2d) endif() +if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32) + list(REMOVE_ITEM TEST_OPS test_pyramid_hash_op) +endif() if(WITH_GPU OR NOT WITH_MKLML) # matmul with multiple heads need MKL support diff --git a/python/paddle/fluid/tests/unittests/test_pyramid_hash_op.py b/python/paddle/fluid/tests/unittests/test_pyramid_hash_op.py new file mode 100644 index 00000000000..c1435d8781d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pyramid_hash_op.py @@ -0,0 +1,61 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid as fluid + + +class TestPyramidHashOpApi(unittest.TestCase): + def test_api(self): + num_voc = 128 + embed_dim = 64 + x_shape, x_lod = [16, 10], [[3, 5, 2, 6]] + x = fluid.data(name='x', shape=x_shape, dtype='int32', lod_level=1) + hash_embd = fluid.contrib.search_pyramid_hash( + input=x, + num_emb=embed_dim, + space_len=num_voc * embed_dim, + pyramid_layer=4, + rand_len=16, + drop_out_percent=0.5, + is_training=True, + use_filter=False, + white_list_len=6400, + black_list_len=2800, + seed=3, + lr=0.002, + param_attr=fluid.ParamAttr( + name="PyramidHash_emb_0", + learning_rate=0, ), + param_attr_wl=fluid.ParamAttr( + name="Filter", + learning_rate=0, ), + param_attr_bl=None, + name=None, ) + + place = fluid.CPUPlace() + x_tensor = fluid.create_lod_tensor( + np.random.randint(0, num_voc, x_shape).astype('int32'), x_lod, + place) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(feed={'x': x_tensor}, + fetch_list=[hash_embd], + return_numpy=False) + + +if __name__ == "__main__": + unittest.main() -- GitLab