mul_compute_test.cc 4.7 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

T
tensor-tang 已提交
15
#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
T
tensor-tang 已提交
16
#include <gtest/gtest.h>
T
tensor-tang 已提交
17 18
#include <algorithm>
#include <iostream>
T
tensor-tang 已提交
19
#include <memory>
T
tensor-tang 已提交
20
#include <random>
T
tensor-tang 已提交
21
#include <utility>
T
tensor-tang 已提交
22 23 24 25 26 27 28 29 30
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
// Row-major element accessors used by the reference GEMM. Each macro relies on
// a pointer (a / b / c) and a leading dimension (lda / ldb / ldc) being in
// scope at the point of use.
// The arguments are parenthesized so that expression arguments such as
// A(i + 1, j) index the intended element instead of being re-associated by
// operator precedence after expansion.
#define A(i, j) a[(i) * lda + (j)]
#define B(i, j) b[(i) * ldb + (j)]
#define C(i, j) c[(i) * ldc + (j)]

template <typename T>
void mul_gemm(const T* a, const int M, const int K, const T* b, const int K_,
              const int N, T* c) {
  EXPECT_TRUE(K_ == K && M > 0 && N > 0 && K > 0);
  EXPECT_TRUE(a && b && c);
  const int lda = K;
  const int ldb = N;
  const int ldc = N;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      C(m, n) = 0.0f;
      for (int k = 0; k < K; ++k) {
        C(m, n) += A(m, k) * B(k, n);
      }
    }
  }
}

T
tensor-tang 已提交
53 54 55 56 57 58 59 60 61 62 63
// Writes n pseudo-random values into a[0..n).
//
// Each value is a draw from a uniform real distribution over [0, 1), scaled by
// (upper - lower), shifted by lower and converted to T. With the defaults the
// values span -2 to 2; passing lower == upper fills the buffer with that
// constant.
//
// Every call constructs a fresh std::mt19937 seeded from a function-local
// static counter that starts at 100 and is incremented once per call, so
// successive calls (per instantiation of T) produce different but
// reproducible sequences.
template <typename T>
void FillData(T* a, const int n, const T lower = static_cast<T>(-2.f),
              const T upper = static_cast<T>(2.f)) {
  static unsigned int next_seed = 100;
  std::mt19937 engine(next_seed++);
  std::uniform_real_distribution<double> unit(0, 1);
  const auto width = upper - lower;
  std::generate_n(a, n, [&]() {
    return static_cast<T>(unit(engine) * width + lower);
  });
}

T
tensor-tang 已提交
64 65 66 67 68
// Checks that asking the global kernel registry for "mul" with the ARM target
// and float precision yields a non-empty result whose first entry is non-null.
TEST(mul_arm, retrive_op) {
  auto kernels =
      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
  // At least one kernel must have been created ...
  ASSERT_FALSE(kernels.empty());
  // ... and the first one must be usable.
  ASSERT_TRUE(kernels.front());
}

T
tensor-tang 已提交
71
// Checks that a default-constructed MulCompute reports float precision and the
// ARM target.
TEST(mul_arm, init) {
  MulCompute kernel;
  ASSERT_EQ(kernel.precision(), PRECISION(kFloat));
  ASSERT_EQ(kernel.target(), TARGET(kARM));
}

T
tensor-tang 已提交
77
// Runs the ARM mul kernel on 2-D inputs for every (m, n, k) combination with
// each dimension in {1, 2, 3, 4} and compares the result element-wise against
// the naive reference mul_gemm with an absolute tolerance of 1e-3.
TEST(mul_arm, compare_test) {
  using T = float;

  const std::vector<int> sizes = {1, 2, 3, 4};
  for (const int m : sizes) {
    for (const int n : sizes) {
      for (const int k : sizes) {
        VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k;

        // lhs is m x k, rhs is k x n; both results are m x n.
        lite::Tensor lhs, rhs, actual, expected;
        lhs.Resize({m, k});
        rhs.Resize({k, n});
        actual.Resize({m, n});
        expected.Resize({m, n});

        auto* lhs_data = lhs.mutable_data<T>();
        auto* rhs_data = rhs.mutable_data<T>();
        auto* actual_data = actual.mutable_data<T>();
        auto* expected_data = expected.mutable_data<T>();

        // Random inputs; both output buffers start zeroed (lower == upper == 0).
        FillData<T>(lhs_data, lhs.dims().production());
        FillData<T>(rhs_data, rhs.dims().production());
        FillData<T>(actual_data, actual.dims().production(), 0, 0);
        FillData<T>(expected_data, expected.dims().production(), 0, 0);

        MulCompute kernel;
        operators::MulParam param;
        param.x = &lhs;
        param.y = &rhs;
        param.output = &actual;

        DeviceInfo::Init();
        std::unique_ptr<KernelContext> context(new KernelContext);
        context->As<ARMContext>();
        kernel.SetParam(param);
        kernel.SetContext(std::move(context));
        kernel.PrepareForRun();

        kernel.Run();

        // Reference result computed from the same input buffers.
        mul_gemm<T>(lhs_data, m, k, rhs_data, k, n, expected_data);

        const auto count = actual.dims().production();
        for (int idx = 0; idx < count; ++idx) {
          EXPECT_NEAR(actual_data[idx], expected_data[idx], 1e-3);
        }
      }
    }
  }
}

// Runs the ARM mul kernel on 3-D inputs with explicit x_num_col_dims /
// y_num_col_dims and compares the result against the naive reference mul_gemm
// applied to the equivalent flattened 2-D problem.
//
// x has shape {2, 3, 4} with x_num_col_dims = 1 and y has shape {3, 4, 5} with
// y_num_col_dims = 2; the reference treats x as 2 x 12 and y as 12 x 5.
// NOTE(review): the flattening rule (leading num_col_dims axes form the rows,
// the remaining axes form the columns) is taken from the mul operator's
// documented semantics, not from code visible in this file.
TEST(mul_arm, num_col_dims) {
  using T = float;

  // Shape constants; the flattened GEMM sizes below are derived from these so
  // the reference call cannot drift away from the tensor shapes.
  const int rows = 2;
  const int inner0 = 3;
  const int inner1 = 4;
  const int cols = 5;
  const int inner = inner0 * inner1;  // shared (contracted) dimension: 12

  lite::Tensor x, y, out, ref;
  x.Resize({rows, inner0, inner1});
  y.Resize({inner0, inner1, cols});
  out.Resize({rows, cols});
  ref.Resize({rows, cols});

  auto* x_data = x.mutable_data<T>();
  auto* y_data = y.mutable_data<T>();
  auto* out_data = out.mutable_data<T>();
  auto* ref_data = ref.mutable_data<T>();

  FillData<T>(x_data, x.dims().production());
  FillData<T>(y_data, y.dims().production());
  // Both output buffers start with random contents, so a passing comparison
  // requires the kernel and the reference to overwrite every element.
  FillData<T>(out_data, out.dims().production());
  // Size the ref fill from ref's own dims rather than out's.
  FillData<T>(ref_data, ref.dims().production());

  MulCompute mul;
  operators::MulParam param;

  param.x = &x;
  param.y = &y;
  param.output = &out;
  param.x_num_col_dims = 1;
  param.y_num_col_dims = 2;

  DeviceInfo::Init();
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<ARMContext>();
  mul.SetParam(param);
  mul.SetContext(std::move(ctx));
  mul.PrepareForRun();

  mul.Run();

  // Reference on the flattened problem: (rows x inner) * (inner x cols).
  mul_gemm<T>(x_data, rows, inner, y_data, inner, cols, ref_data);

  for (int i = 0; i < out.dims().production(); i++) {
    EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
  }
}

}  // namespace arm
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

T
tensor-tang 已提交
175
// References the "def" variant of the mul kernel for the ARM target, float
// precision and NCHW layout.
// NOTE(review): presumably this forces the kernel's registration to be linked
// into the test binary so the registry lookup in the tests above can find it;
// confirm against the USE_LITE_KERNEL definition in op_registry.h.
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);