// match_matrix_tensor_compute.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/match_matrix_tensor_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;

// Generic declaration. Only the float specialization below is defined in
// this file.
template <typename dtype>
void gpu_transpose(
    cublasHandle_t handle, const dtype* src, int M, int N, dtype* dst);

// Writes the transpose of the row-major M x N matrix `src` into `dst`, which
// ends up as a row-major N x M matrix. This is a single cublasSgeam call
// computing dst = 1 * src^T + 0 * dst on `handle`.
template <>
void gpu_transpose<float>(
    cublasHandle_t handle, const float* src, int M, int N, float* dst) {
  const float src_scale = 1.0f;
  const float dst_scale = 0.0f;
  // cuBLAS is column-major. The row-major M x N source is therefore seen as
  // an N x M matrix with leading dimension N, which is transposed
  // (CUBLAS_OP_T). The M x N column-major result with leading dimension M is
  // exactly an N x M row-major matrix. `dst` is also passed as the B operand;
  // in-place use of B is permitted because transb == CUBLAS_OP_N and
  // ldb == ldc.
  CUBLAS_CHECK(cublasSgeam(handle,
                           CUBLAS_OP_T,
                           CUBLAS_OP_N,
                           M,
                           N,
                           &src_scale,
                           src,
                           N,
                           &dst_scale,
                           dst,
                           M,
                           dst,
                           M));
}

// Copies the GEMM result `src` into the zero-padded output `dst`.
//
// Layouts, as implied by the index arithmetic below:
//   src:    row-major with `tl` values per row. Row (offset[s] + r) holds
//           position r of right-hand sequence s.
//   dst:    row-major [num_seq, tl, max_len_r]. dst[s][t][r] is src row
//           (offset[s] + r), column t, when r < len(s) = offset[s+1] -
//           offset[s]. Otherwise it is zero.
//   offset: device array of sequence start offsets into the src rows. Entry
//           s + 1 is read for every sequence s covered by `count`.
//   count:  number of dst elements to write (expected to be
//           num_seq * tl * max_len_r).
// `seq_num_r` is not read by the kernel.
//
// Launch: any 1-D grid. This is a grid-stride loop: each thread starts at
// blockIdx.x * blockDim.x + threadIdx.x and steps by
// gridDim.x * blockDim.x. The loop counter is 64-bit so the stride step
// cannot overflow when `count` is close to INT_MAX.
template <typename dtype>
__global__ void padding_out(const dtype* src,
                            const int* offset,
                            const int seq_num_r,
                            const int max_len_r,
                            const int tl,
                            const int count,
                            dtype* dst) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int per_seq = tl * max_len_r;
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < count;
       idx += stride) {
    // idx < count <= INT_MAX, so the 32-bit narrowing is exact. This keeps
    // the divisions below in (cheaper) 32-bit arithmetic.
    const int i = static_cast<int>(idx);
    const int seq_id = i / per_seq;
    const int tl_id = (i / max_len_r) % tl;
    const int r_id = i % max_len_r;
    const int seq_begin = offset[seq_id];
    const int cur_len = offset[seq_id + 1] - seq_begin;
    dst[i] = (r_id < cur_len) ? src[(seq_begin + r_id) * tl + tl_id]
                              : static_cast<dtype>(0);
  }
}

void MatchMatrixTensorCompute::PrepareForRun() {
  // One-time setup: allocate the float GEMM helper. Run() calls init() on it
  // with the concrete problem shape before each run().
  using FloatGemm = lite::cuda::math::Gemm<float, float>;
  gemm_impl_.reset(new FloatGemm);
}

void MatchMatrixTensorCompute::Run() {
  // Matches the left input X against every right-hand sequence of Y through
  // the dim_t slices of W. The result is written zero-padded to Out, which
  // is resized to [batch, dim_t, len_l, max_len_r], where batch is the
  // number of sequences in Y. Out inherits Y's LoD.
  //
  // Steps:
  //   1. input_l_transform = GEMM(W, X) with M = dim_t * dim_in, N = len_l,
  //      K = dim_in. Both operands are flagged as transposed; presumably
  //      C = op(W) * op(X).
  //   2. Each of the dim_t (dim_in x len_l) slices is transposed into
  //      input_l_transform_reorganize, laid out as [dim_t, len_l, dim_in].
  //   3. output_tmp = GEMM(Y, reorganize) with M = len_r,
  //      N = dim_t * len_l, K = dim_in. This gives one row of
  //      dim_t * len_l values per row of Y.
  //   4. padding_out scatters output_tmp into Out. Positions past the end of
  //      each right-hand sequence are zero-filled.
  //
  // NOTE(review): the first GEMM uses N = len_l starting at x->data(), so
  // only the first len_l rows of X (its first sequence) take part. Every
  // right-hand sequence is matched against that one left sequence; the loop
  // below only checks that all left sequences share that length. Confirm
  // callers never rely on per-batch left sequences.
  //
  // The Tmp output (param.tmp) is not written here.
  CHECK(ctx_) << "running context should be set first";
  auto& param = this->Param<param_t>();
  auto& context = this->ctx_->template As<CUDAContext>();
  auto stream = context.exec_stream();

  auto* x = param.x;
  auto* w = param.w;
  auto* y = param.y;
  auto* out = param.out;
  const int dim_t = param.dim_t;
  const int dim_in = x->dims()[1];

  // offset_l[1] and the last entry of offset_r are read unconditionally
  // below. Both LoDs must therefore describe at least one sequence;
  // otherwise those reads go out of bounds.
  CHECK(!x->lod().empty() && x->lod()[0].size() >= 2)
      << "X must carry a level-0 LoD with at least one sequence";
  CHECK(!y->lod().empty() && y->lod()[0].size() >= 2)
      << "Y must carry a level-0 LoD with at least one sequence";
  const auto& offset_l = x->lod()[0];
  const auto& offset_r = y->lod()[0];

  // The padding kernel consumes 32-bit offsets.
  std::vector<int> offset_r_int(offset_r.size());
  std::transform(offset_r.begin(),
                 offset_r.end(),
                 offset_r_int.begin(),
                 [](int64_t v) -> int { return static_cast<int>(v); });

  const int batch = static_cast<int>(offset_r.size()) - 1;
  const int len_l = offset_l[1] - offset_l[0];
  for (size_t i = 1; i + 1 < offset_l.size(); ++i) {
    const int cur_len = offset_l[i + 1] - offset_l[i];
    CHECK_EQ(cur_len, len_l)
        << "each sequence of left matrix is the same length";
  }

  int max_len_r = 0;
  for (size_t i = 0; i + 1 < offset_r.size(); ++i) {
    const int cur_len = offset_r[i + 1] - offset_r[i];
    max_len_r = std::max(max_len_r, cur_len);
  }

  _input_l_transform.Resize({batch, dim_t, dim_in, len_l});
  _input_l_transform_reorganize.Resize({batch, dim_t, len_l, dim_in});
  _output_tmp.Resize({batch, max_len_r, dim_t, len_l});
  out->Resize({batch, dim_t, len_l, max_len_r});

  _offset_r.Resize({static_cast<int64_t>(offset_r_int.size())});
  // NOTE(review): offset_r_int is a pageable local that is destroyed when Run
  // returns. This relies on the CUDA runtime staging pageable host-to-device
  // sources before cudaMemcpyAsync returns. It assumes
  // TargetWrapperCuda::MemcpyAsync forwards to cudaMemcpyAsync; confirm.
  TargetWrapperCuda::MemcpyAsync(_offset_r.mutable_data<int>(TARGET(kCUDA)),
                                 offset_r_int.data(),
                                 sizeof(int) * offset_r_int.size(),
                                 IoDirection::HtoD,
                                 stream);

  const int len_r = static_cast<int>(offset_r.back());
  const float* input_l = x->data<float>();
  const float* input_r = y->data<float>();
  const float* weight_data = w->data<float>();
  float* input_l_transform =
      _input_l_transform.mutable_data<float>(TARGET(kCUDA));
  float* input_l_transform_reorganize =
      _input_l_transform_reorganize.mutable_data<float>(TARGET(kCUDA));
  float* output_tmp = _output_tmp.mutable_data<float>(TARGET(kCUDA));
  float* out_data = out->mutable_data<float>(TARGET(kCUDA));

  // Step 1: project the left sequence through W.
  gemm_impl_->init(true, true, dim_t * dim_in, len_l, dim_in, &context);
  gemm_impl_->run(
      1.0f, 0.0f, weight_data, input_l, input_l_transform, &context);

  // Step 2: per channel, turn the (dim_in x len_l) slice into (len_l x dim_in).
  for (int i = 0; i < dim_t; ++i) {
    const int offset = i * dim_in * len_l;
    gpu_transpose(gemm_impl_->get_handle(),
                  input_l_transform + offset,
                  dim_in,
                  len_l,
                  input_l_transform_reorganize + offset);
  }

  // Step 3: match every right-hand row against all dim_t * len_l rows.
  gemm_impl_->init(false, true, len_r, dim_t * len_l, dim_in, &context);
  gemm_impl_->run(
      1.0f, 0.0f, input_r, input_l_transform_reorganize, output_tmp, &context);

  // Step 4: scatter into the padded output.
  const int tl = dim_t * len_l;
  const int count = batch * max_len_r * tl;
  // A zero-sized grid is an invalid launch configuration. Skip the kernel
  // when there is nothing to write (e.g. every right-hand sequence is empty).
  if (count > 0) {
    const int threads = 512;
    const int grids = (count + threads - 1) / threads;
    padding_out<float><<<grids, threads, 0, stream>>>(_output_tmp.data<float>(),
                                                      _offset_r.data<int>(),
                                                      batch,
                                                      max_len_r,
                                                      tl,
                                                      count,
                                                      out_data);
    // Kernel launches do not return errors; surface configuration errors here.
    const cudaError_t launch_err = cudaGetLastError();
    CHECK(launch_err == cudaSuccess)
        << "padding_out launch failed: " << cudaGetErrorString(launch_err);
  }
  out->set_lod(y->lod());
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Registers MatchMatrixTensorCompute as the "def" kernel for the
// match_matrix_tensor op on kCUDA with kFloat precision and kNCHW layout.
// Inputs X, W and Y and outputs Out and Tmp are all bound as CUDA float
// NCHW tensors.
// NOTE(review): Tmp is bound here, but MatchMatrixTensorCompute::Run does not
// write to it. Confirm whether any downstream op reads it.
REGISTER_LITE_KERNEL(match_matrix_tensor,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::MatchMatrixTensorCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("W",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("Y",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .BindOutput("Tmp",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();