// attention_padding_mask_compute_test.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/attention_padding_mask_compute.cc"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

void attention_padding_mask_ref(
    const Tensor& x,
    const Tensor& y,
    Tensor* out,
    Tensor* pad_begin,
    const operators::AttentionPaddingMaskParam& param) {
  auto attn_offset = x.lod()[0];
  auto src_offset = y.lod()[0];
  int attn_seq_num = attn_offset.size() - 1;
  int src_seq_num = src_offset.size() - 1;
  int attn_seq_len = attn_offset[1];
  int src_seq_len = x.dims()[1];
  CHECK_EQ(attn_seq_num % src_seq_num, 0);

  auto count = x.numel();
  auto attn_data = x.data<float>();
  out->Resize(x.dims());
  out->set_lod(x.lod());
  auto out_data = out->mutable_data<float>();
  memcpy(out_data, attn_data, count * sizeof(float));

  for (int i = 0; i < attn_seq_num; ++i) {
    for (int j = 0; j < attn_seq_len; ++j) {
      auto tmp_out_data = out_data + src_seq_len * (attn_seq_len * i + j);
      int src_seq_idx = i % src_seq_num;
      int cur_len = src_offset[src_seq_idx + 1] - src_offset[src_seq_idx];
      for (int k = cur_len; k < src_seq_len; k++) {
        tmp_out_data[k] = param.mask;
      }
    }
  }
}

void prepare_input(Tensor* x, const LoD& lod, int64_t dim2rd) {
  std::vector<int64_t> x_dims{static_cast<int64_t>(lod[0].back()), dim2rd};
  x->Resize(x_dims);
  x->set_lod(lod);
  auto x_data = x->mutable_data<float>();
  auto x_num = x->numel();
  for (int i = 0; i < x_num; i++) {
    x_data[i] = (i - x_num) * 1.1;
  }
}

int get_max_len(const LoD& lod) {
  int max_len = 0;
  auto offset = lod[0];
  for (int i = 0; i < offset.size() - 1; i++) {
    int cur_len = offset[i + 1] - offset[i];
    max_len = max_len < cur_len ? cur_len : max_len;
  }
  return max_len;
}

TEST(attention_padding_mask_x86, retrive_op) {
  auto attention_padding_mask =
S
Shibo Tao 已提交
86
      KernelRegistry::Global().Create("attention_padding_mask");
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
  ASSERT_FALSE(attention_padding_mask.empty());
  ASSERT_TRUE(attention_padding_mask.front());
}

TEST(attention_padding_mask_x86, init) {
  AttentionPaddingMaskCompute<float> attention_padding_mask;
  ASSERT_EQ(attention_padding_mask.precision(), PRECISION(kFloat));
  ASSERT_EQ(attention_padding_mask.target(), TARGET(kX86));
}

TEST(attention_padding_mask_x86, run_test) {
  lite::Tensor x, y;
  lite::Tensor out, pad_begin, out_ref, pad_begin_ref;

  LoD x_lod{{0, 3, 6, 9, 12}}, y_lod{{0, 4, 6}};
  prepare_input(&x, x_lod, get_max_len(y_lod));
  prepare_input(&y, y_lod, 1);

  operators::AttentionPaddingMaskParam param;
  param.X = &x;
  param.Y = &y;
  param.pad_id = 12800001;
  param.mask = -90000000.f;
  param.Out = &out;
  param.pad_begin = &pad_begin;

  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  AttentionPaddingMaskCompute<float> attention_padding_mask_kernel;
  attention_padding_mask_kernel.SetParam(param);
  attention_padding_mask_kernel.SetContext(std::move(ctx));
  attention_padding_mask_kernel.Run();

  attention_padding_mask_ref(x, y, &out_ref, &pad_begin_ref, param);
  auto out_data = out.data<float>();
  auto out_ref_data = out_ref.data<float>();
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
  }
}

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

133
USE_LITE_KERNEL(search_attention_padding_mask, kX86, kFloat, kNCHW, def);