// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/arm/matmul_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

void MatMulCompute::PrepareForRun() {
  // No weight repacking or workspace sizing is done ahead of time; every
  // sgemm/sgemv call in Run() works directly on the parameter tensors.
  // Touch the ARM context here (but do not bind it to a name) so any lazy
  // context initialisation happens before the first Run(), without the
  // unused-variable warning the old `auto& ctx = ...` binding produced.
  this->ctx_->template As<ARMContext>();
}

void MatMulCompute::Run() {
  auto& param = Param<param_t>();

  const auto* x_data = param.X->data<float>();
  const auto* y_data = param.Y->data<float>();
  auto* o_data = param.Out->mutable_data<float>();

  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
W
Wilber 已提交
39
  auto o_dims = param.Out->dims();
Y
Yan Chunwei 已提交
40 41 42 43 44
  bool x_transpose = param.transpose_X;
  bool y_transpose = param.transpose_Y;
  float alpha = param.alpha;
  auto& ctx = this->ctx_->template As<ARMContext>();

45 46 47
  operators::ActivationParam act_param;
  act_param.has_active = false;

48 49
  if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&
      (x_dims.size() != 2 || y_dims.size() != 2)) {
Y
Yan Chunwei 已提交
50 51
    // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
    // x: [B, M, K], y: [K, N], out: [B, M, N]
52 53 54
    // or
    // x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
    // x: [M, K], y: [B, K, N], out: [B, M, N]
W
Wilber 已提交
55 56
    int lda, ldb, ldc;
    if (!x_transpose) {
Y
Yan Chunwei 已提交
57
      m_ = x_dims[x_dims.size() - 2];
W
Wilber 已提交
58 59 60 61 62 63 64 65 66
      k_ = x_dims[x_dims.size() - 1];
      lda = k_;
    } else {
      m_ = x_dims[x_dims.size() - 1];
      k_ = x_dims[x_dims.size() - 2];
      lda = m_;
    }

    if (!y_transpose) {
Y
Yan Chunwei 已提交
67
      n_ = y_dims[y_dims.size() - 1];
W
Wilber 已提交
68 69 70 71 72 73 74 75 76 77 78 79
      ldb = n_;
    } else {
      n_ = y_dims[y_dims.size() - 2];
      ldb = k_;
    }

    ldc = n_;

    int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
    int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
    int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1];

80
    if (x_dims.size() > 2 && y_dims.size() > 2) {
W
Wilber 已提交
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
      for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
        lite::arm::math::sgemm(x_transpose,
                               y_transpose,
                               m_,
                               n_,
                               k_,
                               alpha,
                               x_data + i * x_inner,
                               lda,
                               y_data + i * y_inner,
                               ldb,
                               0.f,
                               o_data + i * out_inner,
                               ldc,
                               nullptr,
                               false,
97
                               act_param,
W
Wilber 已提交
98
                               &ctx);
Y
Yan Chunwei 已提交
99
      }
100
    } else if (x_dims.size() > 2 && y_dims.size() == 2) {
W
Wilber 已提交
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
      for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
        lite::arm::math::sgemm(x_transpose,
                               y_transpose,
                               m_,
                               n_,
                               k_,
                               alpha,
                               x_data + i * x_inner,
                               lda,
                               y_data,
                               ldb,
                               0.f,
                               o_data + i * out_inner,
                               ldc,
                               nullptr,
                               false,
117
                               act_param,
W
Wilber 已提交
118
                               &ctx);
Y
Yan Chunwei 已提交
119
      }
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
    } else if (x_dims.size() == 2 && y_dims.size() > 2) {
      for (size_t i = 0; i < y_dims.count(0, y_dims.size() - 2); ++i) {
        lite::arm::math::sgemm(x_transpose,
                               y_transpose,
                               m_,
                               n_,
                               k_,
                               alpha,
                               x_data,
                               lda,
                               y_data + i * y_inner,
                               ldb,
                               0.f,
                               o_data + i * out_inner,
                               ldc,
                               nullptr,
                               false,
                               act_param,
                               &ctx);
      }
W
Wilber 已提交
140
    }
Y
Yan Chunwei 已提交
141 142
  } else if (x_dims.size() == 2 && y_dims.size() == 2) {
    // x: [M, K], y: [K, N], out: [M, N]
W
Wilber 已提交
143 144 145 146 147
    int lda, ldb, ldc;
    if (!x_transpose) {
      m_ = x_dims[0];
      k_ = x_dims[1];
      lda = k_;
Y
Yan Chunwei 已提交
148
    } else {
W
Wilber 已提交
149 150 151
      m_ = x_dims[1];
      k_ = x_dims[0];
      lda = m_;
Y
Yan Chunwei 已提交
152
    }
W
Wilber 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
    if (!y_transpose) {
      n_ = y_dims[1];
      ldb = n_;
    } else {
      n_ = y_dims[0];
      ldb = k_;
    }
    ldc = n_;

    lite::arm::math::sgemm(x_transpose,
                           y_transpose,
                           m_,
                           n_,
                           k_,
                           alpha,
                           x_data,
                           lda,
                           y_data,
                           ldb,
                           0.f,
                           o_data,
                           ldc,
                           nullptr,
                           false,
177
                           act_param,
W
Wilber 已提交
178
                           &ctx);
Y
Yan Chunwei 已提交
179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
  } else if (x_dims.size() > 2 && y_dims.size() == 1) {
    // x: [B, M, K], y: [K], out: [B, M]
    CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
        << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
        << ")";
    for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 1); ++i) {
      o_data[i] = 0;
      for (size_t j = 0; j < y_dims[0]; ++j) {
        o_data[i] += x_data[i * y_dims[0] + j] * y_data[j] * alpha;
      }
    }
  } else if (x_dims.size() == 1 && y_dims.size() == 1) {
    // x: [K], y: [K], out: [1]
    if (x_dims[0] == y_dims[0] && x_transpose == false &&
        y_transpose == false) {
      o_data[0] = 0.;
      for (size_t i = 0; i < x_dims[0]; ++i) {
        o_data[0] += x_data[i] * y_data[i] * alpha;
      }
    }
    // x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
    if (x_transpose == true && y_transpose == true) {
      m_ = x_dims[0];
      k_ = 1;
      n_ = y_dims[0];
W
Wilber 已提交
204 205 206
      int lda = k_;
      int ldb = n_;
      int ldc = n_;
Y
Yan Chunwei 已提交
207
      if (n_ == 1) {
208 209 210 211 212 213 214 215 216 217 218
        lite::arm::math::sgemv(x_data,
                               y_data,
                               o_data,
                               false,
                               m_,
                               k_,
                               false,
                               nullptr,
                               false,
                               lite_api::ActivationType::kIndentity,
                               &ctx);
Y
Yan Chunwei 已提交
219 220 221 222 223 224
        if (fabsf(alpha - 1.f) > 1e-8f) {
          for (size_t i = 0; i < param.Out->dims().production(); ++i) {
            o_data[i] *= alpha;
          }
        }
      } else {
W
Wilber 已提交
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
        lite::arm::math::sgemm(false,
                               false,
                               m_,
                               n_,
                               k_,
                               alpha,
                               x_data,
                               lda,
                               y_data,
                               ldb,
                               0.f,
                               o_data,
                               ldc,
                               nullptr,
                               false,
240
                               act_param,
W
Wilber 已提交
241
                               &ctx);
Y
Yan Chunwei 已提交
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
      }
    }
  } else {
    LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
               << ")";
  }
}

}  // namespace arm
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Register this kernel for the `matmul` op on the ARM backend, float32
// precision, NCHW layout, under the variant name `def`. Inputs X and Y and
// output Out all bind to plain ARM-target tensors.
REGISTER_LITE_KERNEL(
    matmul, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::MatMulCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();