/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <xxhash.h>

#include <algorithm>
#include <cmath>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/search_compute.h"

extern "C" {
#include "math/bloomfilter.h"
}

namespace paddle {
namespace operators {

using LoD = framework::LoD;

class PyramidHashOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "X (Tensor, must be Tensor<int32>) Input variable which "
             "should contain LoD information.");
    AddInput("W", "W (Tensor)");
    AddInput("WhiteList", "WhiteList (Tensor)");
    AddInput("BlackList", "BlackList (Tensor)");
    AddAttr<int>("num_emb", "num_emb").SetDefault(0).EqualGreaterThan(0);
    AddAttr<int>("space_len", "space_len").SetDefault(0).EqualGreaterThan(0);
    AddAttr<int>("pyramid_layer", "pyramid_layer (must be >= 2)")
        .SetDefault(2)
        .EqualGreaterThan(2);
    AddAttr<int>("rand_len", "rand_len").SetDefault(0).EqualGreaterThan(0);
    AddAttr<float>("drop_out_percent", "drop_out_percent")
        .SetDefault(0)
        .EqualGreaterThan(0);
    AddAttr<int>("is_training", "is_training")
        .SetDefault(0)
        .EqualGreaterThan(0);
    AddAttr<bool>("use_filter", "use_filter").SetDefault(true);
    AddAttr<int>("white_list_len", "white_list_len")
        .SetDefault(0)
        .EqualGreaterThan(0);
    AddAttr<int>("black_list_len", "black_list_len")
        .SetDefault(0)
        .EqualGreaterThan(0);
    AddAttr<int>("seed", "seed").SetDefault(0).EqualGreaterThan(0);
    AddAttr<float>("lr", "learning rate").SetDefault(0.0).EqualGreaterThan(0.0);
    AddAttr<std::string>(
        "distribute_update_vars",
        "['PyramidHash_emb_0','Filter']"
        "Decides which params should be updated in distributed training. "
        "Used in Distribute Transpiler to create a trainer/server program.")
        .SetDefault("");
    AddOutput("Out", "Out (Tensor, default Tensor<float>) Output variable");
    AddOutput("DropPos", "DropPos (Tensor, Tensor<int>) Output variable");
    AddOutput("X_Temp_Out", "X_Temp_Out (Tensor) Intermediate output variable")
        .AsIntermediate();

    AddComment(R"DOC(
      PyramidHash

      NOTE: only the 'float32' data type is supported now.

    )DOC");
  }
};

class PyramidHashOP : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"),
        true,
        platform::errors::NotFound("Input(X) of PyramidHashOP is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("W"),
        true,
        platform::errors::NotFound("Input(W) of PyramidHashOP is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"),
                      true,
                      platform::errors::NotFound(
                          "Output(Out) of PyramidHashOP is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("DropPos"),
                      true,
                      platform::errors::NotFound(
                          "Output(DropPos) of PyramidHashOP is not found."));

    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(X) of PyramidHashOP is invalid. "
                          "It should be 2, but got %d",
                          x_dims.size()));

    auto w_dims = ctx->GetInputDim("W");
    PADDLE_ENFORCE_EQ(w_dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(W) of PyramidHashOP is invalid. "
                          "It should be 2, but got %d",
                          w_dims.size()));

    int space_len = ctx->Attrs().Get<int>("space_len");
    int rand_len = ctx->Attrs().Get<int>("rand_len");

    PADDLE_ENFORCE_EQ(
        w_dims[0],
        space_len + rand_len,
        platform::errors::InvalidArgument(
            "The first dimension of Input(W) of PyramidHashOP is invalid. "
            "It should be space_len + rand_len, but now %d != %d + %d",
            w_dims[0],
            space_len,
            rand_len));
    PADDLE_ENFORCE_EQ(
        w_dims[1],
        1,
        platform::errors::InvalidArgument(
            "The second dimension of Input(W) of PyramidHashOP is invalid."
            " It should be 1, but got %d",
            w_dims[1]));

    int num_emb = ctx->Attrs().Get<int>("num_emb");
    PADDLE_ENFORCE_EQ(
        num_emb % rand_len,
        0,
        platform::errors::InvalidArgument(
            "The Attr(num_emb) of PyramidHashOP should be divisible by "
            "Attr(rand_len), but num_emb is %d, rand_len is %d",
            num_emb,
            rand_len));

    int white_list_len = ctx->Attrs().Get<int>("white_list_len");
    if (white_list_len > 0) {
      PADDLE_ENFORCE_EQ(
          ctx->HasInput("WhiteList"),
          true,
          platform::errors::NotFound("Input(WhiteList) of PyramidHashOP is not "
                                     "found but white_list_len > 0."));
      auto wl_dims = ctx->GetInputDim("WhiteList");
      PADDLE_ENFORCE_EQ(
          wl_dims.size(),
          2,
          platform::errors::InvalidArgument(
              "The rank of Input(WhiteList) of PyramidHashOP is invalid."
              " It should be 2, but got %d",
              wl_dims.size()));
      PADDLE_ENFORCE_EQ(wl_dims[0],
                        white_list_len,
                        platform::errors::InvalidArgument(
                            "The first dimension of Input(WhiteList) of "
                            "PyramidHashOP is invalid."
                            " It should be equal to Attr(white_list_len) "
                            ", but first dimension is %d, white_list_len is %d",
                            wl_dims[0],
                            white_list_len));
      PADDLE_ENFORCE_EQ(wl_dims[1],
                        1,
                        platform::errors::InvalidArgument(
                            "The second dimension of Input(WhiteList) of "
                            "PyramidHashOP is invalid."
                            " It should be 1, but got %d",
                            wl_dims[1]));
    }

    int black_list_len = ctx->Attrs().Get<int>("black_list_len");
    if (black_list_len > 0) {
      PADDLE_ENFORCE_EQ(
          ctx->HasInput("BlackList"),
          true,
          platform::errors::NotFound("Input(BlackList) of PyramidHashOP is not "
                                     "found but black_list_len > 0."));
      auto bl_dims = ctx->GetInputDim("BlackList");
      PADDLE_ENFORCE_EQ(
          bl_dims.size(),
          2,
          platform::errors::InvalidArgument(
              "The rank of Input(BlackList) of PyramidHashOP is invalid."
              " It should be 2, but got %d",
              bl_dims.size()));
      PADDLE_ENFORCE_EQ(bl_dims[0],
                        black_list_len,
                        platform::errors::InvalidArgument(
                            "The first dimension of Input(BlackList) of "
                            "PyramidHashOP is invalid."
                            " It should be equal to Attr(black_list_len)"
                            ", but first dimension is %d, black_list_len is %d",
                            bl_dims[0],
                            black_list_len));
      PADDLE_ENFORCE_EQ(bl_dims[1],
                        1,
                        platform::errors::InvalidArgument(
                            "The second dimension of Input(BlackList) of "
                            "PyramidHashOP is invalid."
                            " It should be 1, but got %d",
                            bl_dims[1]));
    }

    if (ctx->IsRuntime()) {
      // something to do in runtime.
    } else {
      // compile time
      ctx->SetOutputDim("Out", phi::make_ddim({-1, num_emb}));
      ctx->SetOutputDim("X_Temp_Out", x_dims);
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"),
                          ctx.GetPlace());
  }
};

template <typename DeviceContext, typename T>
class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
 public:
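  // A term is kept only if it hits the white-list bloom filter (when one is
  // supplied) and misses the black-list bloom filter (when one is supplied).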
  bool should_use_term(math::bloomfilter* _filter,
                       math::bloomfilter* _black_filter,
                       const float* word_repr,
                       int len) const {
    return (!_filter || 1 == math::bloomfilter_get(
                                 _filter, word_repr, len * sizeof(float))) &&
           (!_black_filter ||
            0 == math::bloomfilter_get(
                     _black_filter, word_repr, len * sizeof(float)));
  }

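  // Forward embedding lookup for one hashed term: the num_emb-wide output row
  // is assembled from num_emb / rand_len chunks, each copied from the 1-D
  // weight table at XXH32(term, seed) % space_len with a per-chunk seed. The
  // position of the next chunk is computed one step ahead so it can be
  // prefetched together with the destination.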
  void hash_embedding_ff(const float* hash_id,
                         int len,
                         T* top_pos,
                         const T* weights,
                         int _num_emb,
                         int _rand_len,
                         int _space_len) const {
    unsigned int pos1 = XXH32(hash_id, len * sizeof(float), 0) % _space_len;
    unsigned int pos2 =
        XXH32(hash_id, len * sizeof(float), _rand_len) % _space_len;

    for (int j = 0; j != _num_emb; j += _rand_len) {
      if (j + _rand_len < _num_emb) {
        __builtin_prefetch(weights + pos2);
        __builtin_prefetch(top_pos + j + _rand_len);
      }

      unsigned int pos3 =
          XXH32(hash_id, len * sizeof(float), j + 2 * _rand_len) % _space_len;
      memcpy(
          top_pos + j, const_cast<T*>(weights + pos1), _rand_len * sizeof(T));
      pos1 = pos2;
      pos2 = pos3;
    }
  }

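  // Forward pass: enumerate all pyramid n-grams (length 2 .. pyramid_layer) of
  // every input sequence, filter them through the white/black bloom filters,
  // optionally drop them at random during training, and emit one hashed
  // embedding row per surviving n-gram.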
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* bottom = ctx.Input<phi::DenseTensor>("X");
    auto* _blobs_0 = ctx.Input<phi::DenseTensor>("W");
    auto* _blobs_1 = ctx.Input<phi::DenseTensor>("WhiteList");
    auto* _blobs_2 = ctx.Input<phi::DenseTensor>("BlackList");
    auto* top = ctx.Output<phi::DenseTensor>("Out");
    auto* drop_pos = ctx.Output<phi::DenseTensor>("DropPos");

    int _num_emb = ctx.Attr<int>("num_emb");
    bool use_filter = ctx.Attr<bool>("use_filter");
    int white_list_len = ctx.Attr<int>("white_list_len");
    int black_list_len = ctx.Attr<int>("black_list_len");
    int _pyramid_layer = ctx.Attr<int>("pyramid_layer");
    int _is_training = ctx.Attr<int>("is_training");
    int seed = ctx.Attr<int>("seed");
    unsigned int _seed = (unsigned int)seed;
    int _rand_len = ctx.Attr<int>("rand_len");
    int _space_len = ctx.Attr<int>("space_len");
    float _drop_out_percent = ctx.Attr<float>("drop_out_percent");

    const auto& offset = bottom->lod()[0];
    const auto* bottom_data_ori = bottom->data<int32_t>();
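    // The int32 ids are copied into a float buffer (X_Temp_Out) because the
    // hashing and bloom-filter lookups below operate on float data.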
    auto* buff = ctx.Output<phi::DenseTensor>("X_Temp_Out");
    buff->Resize(phi::make_ddim({bottom->dims()[0], bottom->dims()[1]}));
    float* bottom_data = buff->mutable_data<float>(ctx.GetPlace());
    for (int i = 0; i < bottom->dims()[0]; i++) {
      bottom_data[i] = bottom_data_ori[i];
    }

    const auto* weights = _blobs_0->data<T>();

    std::vector<size_t> top_offset;
    top_offset.resize(offset.size());
    top_offset[0] = 0;

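    // WhiteList/BlackList are serialized bloom filters stored in float
    // tensors; validate them before use.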
    math::bloomfilter* _filter = NULL;
    math::bloomfilter* _black_filter = NULL;
    if (use_filter) {
      if (white_list_len != 0) {
        _filter = (math::bloomfilter*)_blobs_1->data<float>();
        PADDLE_ENFORCE_EQ(
            math::bloomfilter_check(_filter),
            1,
            platform::errors::PreconditionNotMet(
                "The white filter is not loaded successfully, please make sure "
                "'white_list_len': %d is valid for Input(WhiteList).",
                white_list_len));
      }
      if (black_list_len != 0) {
        _black_filter = (math::bloomfilter*)_blobs_2->data<float>();
        PADDLE_ENFORCE_EQ(
            math::bloomfilter_check(_black_filter),
            1,
            platform::errors::PreconditionNotMet(
                "The black filter is not loaded successfully, please make sure "
                "'black_list_len': %d is valid for Input(BlackList).",
                black_list_len));
      }
    }

    drop_pos->Resize(phi::make_ddim(
        {bottom->dims()[0] * bottom->dims()[1] * _pyramid_layer, 1}));
    std::vector<size_t> drop_pos_offset;
    drop_pos_offset.resize(offset.size());
    drop_pos_offset[0] = 0;
    int* iter = drop_pos->mutable_data<int>(ctx.GetPlace());
    int* iter_end = iter;

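    // First pass: walk every pyramid n-gram, record in DropPos whether it is
    // kept (1) or dropped (0), and accumulate per-sequence offsets for both
    // DropPos and the output LoD.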
    for (size_t i = 0; i < top_offset.size() - 1; ++i) {
      int w = offset[i + 1] - offset[i];
      int nsentense_with_pyramid = 0;
      if (w < 2) {
        nsentense_with_pyramid = 0;
      } else {
        for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) {
          for (int l = 0; l < w - ilayer; ++l) {
            if (should_use_term(_filter,
                                _black_filter,
                                (const float*)(bottom_data + offset[i] + l),
                                ilayer + 1)) {
              if (_is_training != 0) {
                unsigned int rand_val = rand_r(&_seed);
                float rate = static_cast<float>(rand_val) / (RAND_MAX);
                *(iter_end++) = (rate < _drop_out_percent ? 0 : 1);
              } else {
                *(iter_end++) = 1;
              }
            } else {
              *(iter_end++) = 0;
            }
          }
        }
        nsentense_with_pyramid = std::count(iter, iter_end, 1);
        iter = iter_end;
      }
      drop_pos_offset[i + 1] = drop_pos_offset[i] + nsentense_with_pyramid;
      top_offset[i + 1] =
          top_offset[i] +
          (nsentense_with_pyramid == 0 ? 1 : nsentense_with_pyramid);
    }

    int top_l = top_offset[top_offset.size() - 1];

    framework::LoD top_lod;
    top_lod.push_back(top_offset);
    top->set_lod(top_lod);
    top->Resize(phi::make_ddim({top_l, _num_emb}));
    auto* top_data = top->mutable_data<T>(ctx.GetPlace());

    framework::LoD drop_pos_lod;
    drop_pos_lod.push_back(drop_pos_offset);
    drop_pos->set_lod(drop_pos_lod);

    iter = drop_pos->mutable_data<int>(ctx.GetPlace());
    int top_counter = 0;
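    // Second pass: for every n-gram marked as kept in DropPos, write its
    // hashed embedding into the next output row; a sequence with no surviving
    // n-gram gets a single all-zero row.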
    for (size_t i = 0; i < offset.size() - 1; ++i) {
      int w_drop = drop_pos_offset[i + 1] - drop_pos_offset[i];
      int w = offset[i + 1] - offset[i];
      if (w_drop == 0) {
        if (w >= 2) {
          for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w;
               ++ilayer) {
            for (int l = 0; l < w - ilayer; ++l) {
              iter++;
            }
          }
        }
        auto* top_pos = top_data + top_counter++ * _num_emb;
        memset(top_pos, 0, _num_emb * sizeof(T));
        continue;
      }
      if (w >= 2) {
        for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) {
          for (int l = 0; l < w - ilayer; ++l) {
            if (*(iter++) == 0) {
              // do nothing
            } else {
              auto* top_pos = top_data + top_counter++ * _num_emb;
              hash_embedding_ff((const float*)(bottom_data + offset[i] + l),
                                ilayer + 1,
                                top_pos,
                                weights,
                                _num_emb,
                                _rand_len,
                                _space_len);
            }
          }
        }
      }
    }
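    // Sanity check: every DropPos entry written in the first pass must have
    // been consumed by the second pass.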
    if (iter != iter_end) {
      exit(1);
    }
    auto weight_type = framework::TransToProtoVarType(_blobs_0->dtype());
    if (_is_training == 0 && weight_type != framework::proto::VarType::INT8) {
      axpy_noadd(top_data,
                 top_data,
                 top->dims()[0] * top->dims()[1],
                 _drop_out_percent);
    }
  }
};

class PyramidHashOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
                      platform::errors::NotFound(
                          "Input(X) of PyramidHashOpGrad is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("W"),
                      true,
                      platform::errors::NotFound(
                          "Input(W) of PyramidHashOpGrad is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"),
                      true,
                      platform::errors::NotFound(
                          "Input(DropPos) of PyramidHashOpGrad is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X_Temp_Out"),
        true,
        platform::errors::NotFound(
            "Input(X_Temp_Out) of PyramidHashOpGrad is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Out")),
        true,
        platform::errors::NotFound(
            "Input(Out@Grad) of PyramidHashOpGrad is not found."));
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"),
                          ctx.GetPlace());
  }
};

template <typename T>
class PyramidHashGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op_desc_ptr) const override {
    op_desc_ptr->SetType("pyramid_hash_grad");
    op_desc_ptr->SetInput("X", this->Input("X"));
    op_desc_ptr->SetInput("W", this->Input("W"));
    op_desc_ptr->SetInput("DropPos", this->Output("DropPos"));
    op_desc_ptr->SetInput("X_Temp_Out", this->Output("X_Temp_Out"));

    op_desc_ptr->SetInput(framework::GradVarName("Out"),
                          this->OutputGrad("Out"));
    op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op_desc_ptr->SetAttrMap(this->Attrs());
  }
};

template <typename DeviceContext, typename T>
class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
 public:
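  // Backward for one n-gram: rehash the term to find the weight-table chunks
  // read by the forward lookup and apply the in-place update
  // weights += mlr * grad, where mlr = -lr.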
  void hash_embedding_bp(const T* hash_id,
                         int len,
                         const T* top_pos,
                         T* weights,
                         T mlr,
                         int _num_emb,
                         int _rand_len,
                         int _space_len) const {
    for (int j = 0; j != _num_emb; j += _rand_len) {
      unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len;
      axpy(top_pos + j, weights + pos, _rand_len, mlr);
    }
  }

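  // Backward pass: replay the forward n-gram enumeration using DropPos and the
  // cached float ids in X_Temp_Out, scattering the output gradient back into
  // the weight table W in place.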
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* bottom = ctx.Input<phi::DenseTensor>("X");
    auto* _blobs = ctx.Input<phi::DenseTensor>("W");
    auto* drop_pos = ctx.Input<phi::DenseTensor>("DropPos");
    auto* top = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));

    int _num_emb = ctx.Attr<int>("num_emb");
    float _lr = ctx.Attr<float>("lr");
    int _rand_len = ctx.Attr<int>("rand_len");
    int _space_len = ctx.Attr<int>("space_len");
    int _pyramid_layer = ctx.Attr<int>("pyramid_layer");

    auto* buff = ctx.Input<phi::DenseTensor>("X_Temp_Out");
    auto* bottom_data = buff->data<T>();

    int _slot_len = bottom->dims()[0];
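    // Skip backprop when every sequence holds a single id and all ids are -1,
    // i.e. the whole batch is padding.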
    if (static_cast<size_t>(_slot_len) == bottom->lod()[0].size() - 1 &&
        std::count(bottom_data, bottom_data + _slot_len, -1) == _slot_len) {
      return;
    }

    auto& offset = bottom->lod()[0];
    auto& drop_pos_offset = drop_pos->lod()[0];

    const auto* top_diff = top->data<T>();
    // Weights are updated in place, so const_cast is needed.
    T* weights = const_cast<T*>(_blobs->data<T>());
    T mlr = -1.0 * _lr;

    const int* iter = drop_pos->data<int>();
    int top_counter = 0;
    for (size_t i = 0; i < offset.size() - 1; ++i) {
      int w = offset[i + 1] - offset[i];
      int w_drop = drop_pos_offset[i + 1] - drop_pos_offset[i];
      if (w_drop == 0) {
        top_counter++;
      }
      if (w > 1) {
        for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) {
          for (int l = 0; l < w - ilayer; ++l) {
            if (*(iter++) == 0) {
              // do nothing
            } else {
              const T* top_pos = top_diff + top_counter++ * _num_emb;
              hash_embedding_bp((const T*)(bottom_data + offset[i] + l),
                                ilayer + 1,
                                top_pos,
                                weights,
                                mlr,
                                _num_emb,
                                _rand_len,
                                _space_len);
            }
          }
        }
      } else {
        // do nothing
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plt = paddle::platform;
namespace frm = paddle::framework;
REGISTER_OPERATOR(pyramid_hash,
                  ops::PyramidHashOP,
                  ops::PyramidHashOpMaker,
                  ops::PyramidHashGradOpMaker<paddle::framework::OpDesc>,
                  ops::PyramidHashGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad);

REGISTER_OP_CPU_KERNEL(pyramid_hash,
                       ops::CPUPyramidHashOPKernel<phi::CPUContext, float>,
                       ops::CPUPyramidHashOPKernel<phi::CPUContext, int8_t>);
REGISTER_OP_CPU_KERNEL(pyramid_hash_grad,
                       ops::CPUPyramidHashOPGradKernel<phi::CPUContext, float>);