search_grnn_compute.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <vector>
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

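// Utility that reorders the words of a LoD batch so that sequences are
// processed in length-sorted order, and maps the results back afterwards.
// `is_reverse` and `is_bi` are assumed to select reversed and
// bidirectional traversal of each sequence.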
class SeqSortedseqTranseUtil {
 public:
  explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
      : _is_reverse(is_reverse),
        _is_bi(is_bi),
        _dev_map_vec(nullptr),
        _dev_map_vec_length(0) {}

  ~SeqSortedseqTranseUtil() {
    if (_dev_map_vec != nullptr) {
      TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
    }
  }

  std::vector<int>& get_length_index() { return _length_index; }
  std::vector<int>& get_emit_offset_vec() { return _emit_offset_vec; }
  std::vector<int>& get_map_vec() { return _map_vec; }
  int* get_dev_map_vec() { return _dev_map_vec; }
  int get_emit_length() { return _emit_length; }

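  // Copies `input` (original sequence layout, `word_size` values per
  // word) into `output` in the sorted layout produced by get_sorted_map().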
  template <typename Dtype>
  void seq_2_sorted_seq(const Dtype* input,
                        Dtype* output,
                        int word_size,
                        cudaStream_t stream);

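  // Inverse of seq_2_sorted_seq: copies results (`hidden_size` values per
  // word) from the sorted layout back into the original sequence layout.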
  template <typename Dtype>
  void sorted_seq_2_seq(const Dtype* input,
                        Dtype* output,
                        int hidden_size,
                        cudaStream_t stream);

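  // Builds the sort/restore mapping from the LoD offsets and uploads it to
  // the device; the return value is assumed to indicate whether a reorder
  // is actually required for this batch.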
  bool get_sorted_map(const std::vector<int>& offset_vec,
                      cudaStream_t stream_id);

 private:
  std::vector<int> _length_index;
  std::vector<int> _emit_offset_vec;
  std::vector<int> _map_vec;
  int _emit_length;

  bool _is_reverse;
  bool _is_bi;
  int* _dev_map_vec;
  int _dev_map_vec_length;
};
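// A minimal usage sketch for the utility above (illustrative only;
// `offsets`, `stream`, the device pointers, and the sizes are
// hypothetical, and the bool is assumed to mean a reorder is needed):
//
//   SeqSortedseqTranseUtil util;
//   if (util.get_sorted_map(offsets, stream)) {
//     util.seq_2_sorted_seq(x_dev, x_sorted_dev, word_size, stream);
//     // ... run the recurrent GEMM steps on the sorted layout ...
//     util.sorted_seq_2_seq(h_sorted_dev, out_dev, hidden_size, stream);
//   }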

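// CUDA kernel implementing the search_grnn op (float, NCHW). The private
// WeightsPreprocess() below is presumably invoked once from
// PrepareForRun() to lay the weights out for the GEMM calls in Run().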
class SearchGrnnCompute
    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
 public:
  using param_t = operators::SearchGrnnParam;
  using TargetW = TargetWrapper<TARGET(kCUDA)>;

  void PrepareForRun() override;
  void Run() override;
  virtual ~SearchGrnnCompute() = default;

 private:
  // Weights preprocessing:
  // wi needs to be transposed with axes (2, 0, 1);
  // wh0 should be transposed, and {wh1, wh2} need to be transposed with
  // axes (2, 0, 1).
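  //
  // The assumed effect of a (2, 0, 1) transpose on a weight of shape
  // [d0, d1, d2] (illustrative indexing, not the actual CUDA kernel):
  //   out[i2][i0][i1] = in[i0][i1][i2]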
  void WeightsPreprocess();

 private:
  std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;

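  // Scratch tensors reused across Run() calls; judging by the names,
  // _temp_wx and _temp_wh hold the input-weight and hidden-weight GEMM
  // products, and _temp_zero a zero-initialized buffer.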
  lite::Tensor _temp_tensor_in;
  lite::Tensor _temp_tensor_out;
  lite::Tensor _temp_wx;
  lite::Tensor _temp_wh;
  lite::Tensor _temp_zero;
  lite::Tensor _temp_weights_h2h;

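  // Transposed copies of the input and hidden weights produced by
  // WeightsPreprocess().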
  lite::Tensor _wi;
  lite::Tensor _wh;

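  // Reorders sequences between the original and length-sorted layouts.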
  SeqSortedseqTranseUtil _seq_util;
};

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle