transform_test.cc 8.6 KB
Newer Older
T
TianXiaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"

DEFINE_string(input, "", "input_data");
DEFINE_int32(batch, 1, "batch");

namespace paddle {
namespace lite {

X
xiaogang 已提交
32
namespace test_transformer {
T
TianXiaogang 已提交
33
// Text lines collected by load_input_lines(), in file order, without the
// trailing '\n'.
std::vector<std::string> inputed_lines;

// Reads `filename` line by line and appends every line (newline stripped) to
// `inputed_lines`. Existing entries are kept, so repeated calls accumulate.
//
// If `filename` is null or the file cannot be opened, a message is written to
// stderr and `inputed_lines` is left unchanged; callers should check whether
// any lines were loaded.
void load_input_lines(const char* filename) {
  if (filename == nullptr) {
    std::fprintf(stderr, "load_input_lines: no input file name given\n");
    return;
  }

  // The stream closes itself on scope exit, and std::getline grows the line
  // buffer on demand, so no fixed-size scratch buffer is required.
  std::ifstream input_file(filename);
  if (!input_file.is_open()) {
    std::fprintf(
        stderr, "load_input_lines: cannot open input file '%s'\n", filename);
    return;
  }

  std::string line;
  while (std::getline(input_file, line)) {
    inputed_lines.push_back(line);
  }
}
X
xiaogang 已提交
51
// Splits `main_str` on every occurrence of `delimiter` and stores the pieces
// in `str_list` (which is cleared first).
//
// - Empty pieces between two adjacent delimiters are kept.
// - A trailing empty piece (input ending with the delimiter) is dropped.
// - An empty `main_str` yields an empty list.
// - An empty `delimiter` yields the whole input as a single piece.
void split2(const std::string& main_str,
            std::vector<std::string>& str_list,  // NOLINT
            const std::string& delimiter) {
  str_list.clear();
  if (main_str.empty()) {
    return;
  }
  // find("") matches at every offset, which would never terminate cleanly;
  // treat "no delimiter" as "nothing to split on".
  if (delimiter.empty()) {
    str_list.push_back(main_str);
    return;
  }

  size_t token_begin = 0;
  size_t hit = 0;
  while ((hit = main_str.find(delimiter, token_begin)) != std::string::npos) {
    str_list.push_back(main_str.substr(token_begin, hit - token_begin));
    // Skip the whole delimiter, not just its first character.
    token_begin = hit + delimiter.size();
  }

  if (token_begin < main_str.size()) {
    str_list.push_back(main_str.substr(token_begin));
  }
}
}  // NOLINT

X
xiaogang 已提交
77 78 79 80 81 82 83 84 85 86 87 88 89
void pad_batch_input(std::vector<std::string>& input_lines,  // NOLINT
                     int pad_idx,
                     int n_head,
                     Tensor* src_word,
                     Tensor* src_pos,
                     Tensor* src_attn_bias,
                     Tensor* trg_word,
                     Tensor* init_scores,
                     Tensor* init_idx,
                     Tensor* trg_bias,
                     int line_start,
                     int batch_size,
                     int bos_idx) {
T
TianXiaogang 已提交
90 91 92 93 94 95 96 97 98 99
  int max_len = 0;
  int max_line = input_lines.size();

  std::vector<std::vector<std::string>> batch_lines;
  for (int i = line_start; i < line_start + batch_size; ++i) {
    int i_index = i % max_line;
    std::string cur_line = input_lines[i_index];

    std::vector<std::string> split_str;

X
xiaogang 已提交
100
    test_transformer::split2(cur_line, split_str, " ");
T
TianXiaogang 已提交
101 102 103 104 105

    batch_lines.push_back(split_str);
    max_len = max_len >= split_str.size() ? max_len : split_str.size();
  }

X
xiaogang 已提交
106 107
  src_word->Resize(std::vector<DDim::value_type>({batch_size, max_len}));
  src_pos->Resize(std::vector<DDim::value_type>({batch_size, max_len}));
T
TianXiaogang 已提交
108 109 110
  src_attn_bias->Resize(
      std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len}));
  trg_bias->Resize(
X
xiaogang 已提交
111 112 113
      std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len}));
  auto* src_word_data = src_word->mutable_data<int64_t>();
  auto* src_pos_data = src_pos->mutable_data<int64_t>();
T
TianXiaogang 已提交
114 115 116 117 118 119
  float* src_bias_data = src_attn_bias->mutable_data<float>();
  float* trg_bias_data = trg_bias->mutable_data<float>();
  for (int i = 0; i < batch_size; ++i) {
    std::vector<std::string> cur_words = batch_lines[i];
    int fill_len = cur_words.size();
    int src_bias_start = i * n_head * max_len * max_len;
X
xiaogang 已提交
120
    int trg_bias_start = i * n_head * max_len * max_len;
T
TianXiaogang 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
    for (int j = 0; j < fill_len; ++j) {
      src_word_data[i * max_len + j] = (atoi(cur_words[j].c_str()));
      src_pos_data[i * max_len + j] = j;
      src_bias_data[src_bias_start + j] = 0;
      trg_bias_data[trg_bias_start + j] = 0;
    }
    for (int j = fill_len; j < max_len; ++j) {
      src_word_data[i * max_len + j] = pad_idx;
      src_pos_data[i * max_len + j] = 0;
      src_bias_data[src_bias_start + j] = -1000000000;
      trg_bias_data[trg_bias_start + j] = -1000000000;
    }
    for (int j = src_bias_start;
         j < src_bias_start + n_head * max_len * max_len;
         ++j) {
      int value_ind = j % max_len + src_bias_start;
      src_bias_data[j] = src_bias_data[value_ind];
    }
X
xiaogang 已提交
139 140 141
    for (int j = trg_bias_start;
         j < trg_bias_start + n_head * max_len * max_len;
         ++j) {
T
TianXiaogang 已提交
142 143 144 145 146
      int value_ind = j % max_len + trg_bias_start;
      trg_bias_data[j] = trg_bias_data[value_ind];
    }
  }

X
xiaogang 已提交
147 148 149
  trg_word->Resize(std::vector<DDim::value_type>({batch_size, max_len}));
  auto* trg_word_data = trg_word->mutable_data<int64_t>();
  for (int i = 0; i < batch_size * max_len; ++i) {
T
TianXiaogang 已提交
150 151 152 153 154 155
    trg_word_data[i] = bos_idx;
  }

  init_scores->Resize(std::vector<DDim::value_type>({batch_size, 1}));
  init_idx->Resize(std::vector<DDim::value_type>({batch_size}));
  float* score_data = init_scores->mutable_data<float>();
X
xiaogang 已提交
156
  auto* idx_data = init_idx->mutable_data<int32_t>();
T
TianXiaogang 已提交
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
  for (int i = 0; i < init_scores->numel(); ++i) {
    score_data[i] = 0;
  }
  std::vector<std::vector<uint64_t>> lod_s;
  lod_s.resize(2);
  for (int i = 0; i < batch_size; ++i) {
    lod_s[0].push_back(i);
    lod_s[1].push_back(i);
    idx_data[i] = i;
  }
  lod_s[0].push_back(batch_size);
  lod_s[1].push_back(batch_size);
  auto score_lod = init_scores->mutable_lod();
  *score_lod = lod_s;

  auto trg_word_lod = trg_word->mutable_lod();
  *trg_word_lod = lod_s;
}

void TestModel(const std::vector<Place>& valid_places,
               const Place& preferred_place,
               bool use_npu = false) {
X
xiaogang 已提交
179
#ifdef LITE_WITH_ARM
T
TianXiaogang 已提交
180 181
  DeviceInfo::Init();
  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
X
xiaogang 已提交
182
#endif
T
TianXiaogang 已提交
183 184 185
  lite::Predictor predictor;
  std::string test_data_path = FLAGS_input;

X
xiaogang 已提交
186 187 188 189 190
  predictor.Build("",
                  FLAGS_model_dir + "/__model__",
                  FLAGS_model_dir + "/weights",
                  valid_places);
  // predictor.Build(FLAGS_model_dir, "", "", valid_places);
T
TianXiaogang 已提交
191 192 193 194 195 196

  int n_head = 8;
  int batch_size = FLAGS_batch;
  int bos_idx = 0;
  int eos_idx = 1;

X
xiaogang 已提交
197
  test_transformer::load_input_lines(test_data_path.c_str());
T
TianXiaogang 已提交
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212

  auto* trg_bias = predictor.GetInput(6);
  auto* src_word = predictor.GetInput(0);
  auto* src_pos = predictor.GetInput(1);
  auto* src_bias = predictor.GetInput(2);
  auto* trg_word = predictor.GetInput(3);
  auto* init_score = predictor.GetInput(4);
  auto* init_idx = predictor.GetInput(5);

  for (int i = 0; i < FLAGS_warmup; ++i) {
    predictor.Run();
  }

  auto start = GetCurrentUS();
  for (int i = 0; i < FLAGS_repeats; ++i) {
X
xiaogang 已提交
213 214 215 216 217 218 219 220 221 222 223 224 225
    pad_batch_input(test_transformer::inputed_lines,
                    eos_idx,
                    n_head,
                    src_word,    // src_word
                    src_pos,     // src_pos
                    src_bias,    // src_bias
                    trg_word,    // trg_word
                    init_score,  // init_score
                    init_idx,    // init_idx
                    trg_bias,    // trg_bias
                    i * batch_size,
                    batch_size,
                    bos_idx);
T
TianXiaogang 已提交
226
    predictor.Run();
X
xiaogang 已提交
227 228 229 230 231 232 233 234 235 236 237
    auto* outs = predictor.GetOutput(0);
    auto o_data = outs->data<int64_t>();
    auto lod = outs->lod();
    for (int i = 0; i < outs->numel(); ++i) {
      LOG(INFO) << o_data[i];
    }
    for (int i = 0; i < lod.size(); ++i) {
      for (int j = 0; j < lod[i].size(); ++j) {
        LOG(INFO) << lod[i][j];
      }
    }
T
TianXiaogang 已提交
238 239 240 241 242 243 244 245 246
  }

  LOG(INFO) << "================== Speed Report ===================";
  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
            << " ms in average.";
}

X
xiaogang 已提交
247 248 249 250 251
}  // namespace lite
}  // namespace paddle
using namespace paddle::lite;  // NOLINT
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
T
TianXiaogang 已提交
252
  std::vector<Place> valid_places({
X
xiaogang 已提交
253
      Place{TARGET(kARM), PRECISION(kInt64)},
T
TianXiaogang 已提交
254
      Place{TARGET(kARM), PRECISION(kFloat)},
X
xiaogang 已提交
255
      Place{TARGET(kHost), PRECISION(kFloat)},
T
TianXiaogang 已提交
256 257 258 259
  });

  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
}