Commit 8dd9d6b1 authored by Liangliang He

Add kernels benchmark against eigen

Parent 2940ce5d
@@ -9,7 +9,7 @@ http_archive(
strip_prefix = "protobuf-3.4.0",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/protobuf/protobuf-3.4.0.zip",
"https://github.com/google/protobuf/archive/v3.4.0.zip"
"https://github.com/google/protobuf/archive/v3.4.0.zip",
],
)
@@ -20,7 +20,7 @@ new_http_archive(
strip_prefix = "googletest-release-1.8.0",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/googletest/googletest-release-1.8.0.zip",
"https://github.com/google/googletest/archive/release-1.8.0.zip"
"https://github.com/google/googletest/archive/release-1.8.0.zip",
],
)
@@ -31,7 +31,7 @@ new_http_archive(
strip_prefix = "OpenCL-Headers-master",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/OpenCL-Headers/OpenCL-Headers-master.zip",
"https://github.com/KhronosGroup/OpenCL-Headers/archive/master.zip"
"https://github.com/KhronosGroup/OpenCL-Headers/archive/master.zip",
],
)
@@ -42,7 +42,7 @@ new_http_archive(
strip_prefix = "OpenCL-CLHPP-4c6f7d56271727e37fb19a9b47649dd175df2b12",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/OpenCL-CLHPP/OpenCL-CLHPP-4c6f7d56271727e37fb19a9b47649dd175df2b12.zip",
"https://github.com/KhronosGroup/OpenCL-CLHPP/archive/4c6f7d56271727e37fb19a9b47649dd175df2b12.zip"
"https://github.com/KhronosGroup/OpenCL-CLHPP/archive/4c6f7d56271727e37fb19a9b47649dd175df2b12.zip",
],
)
@@ -53,7 +53,29 @@ new_http_archive(
strip_prefix = "half-code-356-trunk",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/half/half-code-356-trunk.zip",
"https://sourceforge.net/code-snapshots/svn/h/ha/half/code/half-code-356-trunk.zip"
"https://sourceforge.net/code-snapshots/svn/h/ha/half/code/half-code-356-trunk.zip",
],
)
+new_http_archive(
+name = "eigen",
+build_file = "third_party/eigen3/eigen.BUILD",
+sha256 = "ca7beac153d4059c02c8fc59816c82d54ea47fe58365e8aded4082ded0b820c4",
+strip_prefix = "eigen-eigen-f3a22f35b044",
+urls = [
+"http://cnbj1.fds.api.xiaomi.com/mace/third-party/eigen/f3a22f35b044.tar.gz",
+"http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz",
+"https://bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz",
+],
+)
+http_archive(
+name = "gemmlowp",
+sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
+strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
+urls = [
+"http://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
+"https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
+],
+)
@@ -81,7 +103,7 @@ http_archive(
strip_prefix = "gflags-30dbc81fb5ffdc98ea9b14b1918bfe4e8779b26e",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/gflags/gflags-30dbc81fb5ffdc98ea9b14b1918bfe4e8779b26e.zip",
"https://github.com/gflags/gflags/archive/30dbc81fb5ffdc98ea9b14b1918bfe4e8779b26e.zip"
"https://github.com/gflags/gflags/archive/30dbc81fb5ffdc98ea9b14b1918bfe4e8779b26e.zip",
],
)
......
@@ -18,14 +18,17 @@ cc_library(
],
exclude = [
"*_test.cc",
"*_benchmark.cc",
"arm/*_test.cc",
],
-) + if_android(glob([
+) + if_android(glob(
+[
"opencl/*.cc",
],
exclude = [
"opencl/*_test.cc",
-])),
+],
+)),
hdrs = glob(
[
"*.h",
@@ -35,16 +38,26 @@ cc_library(
"buffer_to_image.h",
],
) + if_android(glob([
"opencl/*.h",
"buffer_to_image.h",
])),
copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"] +
if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
"opencl/*.h",
"buffer_to_image.h",
])),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
]) + if_android_armv7([
"-mfloat-abi=softfp",
]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
@@ -62,13 +75,22 @@ cc_test(
"opencl/*_test.cc",
],
),
-copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"] +
-if_openmp_enabled(["-fopenmp"]) +
-if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
-if_android_armv7(["-mfpu=neon"]) +
-if_android_armv7(["-mfloat-abi=softfp"]) +
-if_android(["-DMACE_ENABLE_OPENCL"]) +
-if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
+copts = [
+"-Werror",
+"-Wextra",
+"-Wno-missing-field-initializers",
+] + if_openmp_enabled([
+"-fopenmp",
+]) + if_neon_enabled([
+"-DMACE_ENABLE_NEON",
+]) + if_android_armv7([
+"-mfpu=neon",
+"-mfloat-abi=softfp",
+]) + if_android([
+"-DMACE_ENABLE_OPENCL",
+]) + if_hexagon_enabled([
+"-DMACE_ENABLE_HEXAGON",
+]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
@@ -77,3 +99,32 @@ cc_test(
"@gtest//:gtest_main",
],
)
+cc_test(
+name = "kernels_benchmark",
+testonly = 1,
+srcs = glob(["*_benchmark.cc"]),
+copts = [
+"-Werror",
+"-Wextra",
+"-Wno-missing-field-initializers",
+] + if_openmp_enabled([
+"-fopenmp",
+]) + if_neon_enabled([
+"-DMACE_ENABLE_NEON",
+]) + if_android_armv7([
+"-mfpu=neon",
+"-mfloat-abi=softfp",
+]) + if_android([
+"-DMACE_ENABLE_OPENCL",
+]) + if_hexagon_enabled([
+"-DMACE_ENABLE_HEXAGON",
+]),
+linkopts = ["-fopenmp"],
+linkstatic = 1,
+deps = [
+":kernels",
+"//mace/core:test_benchmark_main",
+"//third_party/eigen3",
+],
+)
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Dense>
#include <algorithm>
#include <string>
#include <vector>
#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/gemm.h"
#include "public/gemmlowp.h"
namespace mace {
namespace kernels {
namespace test {
// Benchmark MACE matmul kernels against Eigen
namespace {
// Matmul with (m, k) x (k, n)
void MatmulBenchmark_Mace(int iters, int m, int k, int n) {
mace::testing::StopTiming();
std::vector<float> lhs(m * k);
std::vector<float> rhs(k * n);
std::vector<float> result(m * n);
// Warm up. The third argument to Gemm (1 here) is the batch size.
Gemm(lhs.data(), rhs.data(), 1, m, k, n, result.data());
mace::testing::StartTiming();
while (iters--) {
Gemm(lhs.data(), rhs.data(), 1, m, k, n, result.data());
}
}
void MatmulBenchmark_Eigen(int iters, int m, int k, int n) {
mace::testing::StopTiming();
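// Note: this benchmark runs Eigen in double precision (MatrixXd), while
// MatmulBenchmark_Mace above runs MACE's Gemm in single precision, so the
// two are not a strictly like-for-like comparison.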
Eigen::MatrixXd lhs = Eigen::MatrixXd::Random(m, k);
Eigen::MatrixXd rhs = Eigen::MatrixXd::Random(k, n);
Eigen::MatrixXd result = Eigen::MatrixXd::Zero(m, n);
// warm up
result = lhs * rhs;
mace::testing::StartTiming();
while (iters--) {
result = lhs * rhs;
}
}
} // namespace
#define MACE_BM_MATMUL_FUNC(M, K, N, FUNC) \
static void MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC(int iters) { \
const int64_t macc = static_cast<int64_t>(iters) * M * K * N; \
const int64_t tot = static_cast<int64_t>(iters) * (M + N) * K; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot * sizeof(float)); \
MatmulBenchmark_##FUNC(iters, M, K, N); \
} \
MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC)
#define MACE_BM_MATMUL(M, K, N) \
MACE_BM_MATMUL_FUNC(M, K, N, Mace); \
MACE_BM_MATMUL_FUNC(M, K, N, Eigen);
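// For illustration, MACE_BM_MATMUL(7, 384, 384) expands to a pair of
// registered benchmarks, one per backend; the Mace one is roughly:
//
// static void MACE_BM_MATMUL_7_384_384_Mace(int iters) {
//   // M * K * N multiply-accumulates per iteration
//   const int64_t macc = static_cast<int64_t>(iters) * 7 * 384 * 384;
//   // floats streamed in: the M x K lhs plus the K x N rhs, i.e. (M + N) * K
//   const int64_t tot = static_cast<int64_t>(iters) * (7 + 384) * 384;
//   mace::testing::MaccProcessed(macc);
//   mace::testing::BytesProcessed(tot * sizeof(float));
//   MatmulBenchmark_Mace(iters, 7, 384, 384);
// }
// MACE_BENCHMARK(MACE_BM_MATMUL_7_384_384_Mace)
//
// and the same again with FUNC = Eigen.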
// Embedding size 384
MACE_BM_MATMUL(7, 384, 384);
MACE_BM_MATMUL(7, 384, 1536);
MACE_BM_MATMUL(7, 1536, 384);
MACE_BM_MATMUL(15, 384, 384);
MACE_BM_MATMUL(15, 384, 1536);
MACE_BM_MATMUL(15, 1536, 384);
MACE_BM_MATMUL(1, 384, 384);
MACE_BM_MATMUL(1, 384, 1536);
MACE_BM_MATMUL(1, 1536, 384);
MACE_BM_MATMUL(1, 384, 44678);
// Embedding size 128
MACE_BM_MATMUL(1, 128, 1536);
MACE_BM_MATMUL(1, 128, 44678);
} // namespace test
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <string>
#include <vector>
#include "mace/core/testing/test_benchmark.h"
namespace mace {
namespace kernels {
namespace test {
// Test the speed of different access orders of an NHWC buffer
namespace {
void MemoryAccessBenchmark_NHWC(
int iters, int batch, int height, int width, int channels) {
mace::testing::StopTiming();
std::vector<float> buffer(batch * height * width * channels);
std::fill_n(buffer.begin(), buffer.size(), 0.1f);
mace::testing::StartTiming();
while (iters--) {
for (int n = 0; n < batch; ++n) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
buffer[n * height * width * channels + h * width * channels +
w * channels + c] = 1.0f;
}
}
}
}
}
}
void MemoryAccessBenchmark_NWCH(
int iters, int batch, int height, int width, int channels) {
mace::testing::StopTiming();
std::vector<float> buffer(batch * height * width * channels);
std::fill_n(buffer.begin(), buffer.size(), 0.1f);
mace::testing::StartTiming();
while (iters--) {
for (int n = 0; n < batch; ++n) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < height; ++h) {
buffer[n * height * width * channels + h * width * channels +
w * channels + c] = 1.0f;
}
}
}
}
}
}
void MemoryAccessBenchmark_NHCW(
int iters, int batch, int height, int width, int channels) {
mace::testing::StopTiming();
std::vector<float> buffer(batch * height * width * channels);
std::fill_n(buffer.begin(), buffer.size(), 0.1f);
mace::testing::StartTiming();
while (iters--) {
for (int n = 0; n < batch; ++n) {
for (int h = 0; h < height; ++h) {
for (int c = 0; c < channels; ++c) {
for (int w = 0; w < width; ++w) {
buffer[n * height * width * channels + h * width * channels +
w * channels + c] = 1.0f;
}
}
}
}
}
}
} // namespace
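// All three loops above address the same flattened NHWC buffer; only the
// loop nesting, and hence the innermost stride, differs. A hypothetical
// helper spelling out the offset they all compute:
//
// inline int64_t OffsetNHWC(int n, int h, int w, int c,
//                           int height, int width, int channels) {
//   return ((static_cast<int64_t>(n) * height + h) * width + w) * channels + c;
// }
//
// NHWC touches memory with unit stride (innermost c), NHCW with a stride of
// channels floats (innermost w), and NWCH with a stride of width * channels
// floats (innermost h), which is why their throughputs differ.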
#define MACE_BM_MEMORY_ACCESS(N, H, W, C, ORDER) \
static void MACE_BM_MEMORY_ACCESS_##N##_##H##_##W##_##C##_##ORDER( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot * sizeof(float)); \
MemoryAccessBenchmark_##ORDER(iters, N, H, W, C); \
} \
MACE_BENCHMARK(MACE_BM_MEMORY_ACCESS_##N##_##H##_##W##_##C##_##ORDER)
MACE_BM_MEMORY_ACCESS(10, 64, 64, 1024, NHWC);
MACE_BM_MEMORY_ACCESS(10, 64, 64, 1024, NHCW);
MACE_BM_MEMORY_ACCESS(10, 64, 64, 1024, NWCH);
MACE_BM_MEMORY_ACCESS(10, 64, 1024, 64, NHCW);
MACE_BM_MEMORY_ACCESS(10, 64, 1024, 64, NWCH);
} // namespace test
} // namespace kernels
} // namespace mace
# Description:
# Eigen is a C++ template library for linear algebra: vectors,
# matrices, and related algorithms.
# This file is mostly stolen from tensorflow.
licenses([
# Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
# We've taken special care to not reference any restricted code.
"reciprocal", # MPL2
"notice", # Portions BSD
])
exports_files(["LICENSE"])
cc_library(
name = "eigen3",
hdrs = glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + [
"Eigen/Core",
"Eigen/LU",
"Eigen/Cholesky",
"Eigen/Eigenvalues",
"Eigen/QR",
"Eigen/SVD",
"unsupported/Eigen/SpecialFunctions",
"unsupported/Eigen/CXX11/ThreadPool",
"unsupported/Eigen/CXX11/Tensor",
"unsupported/Eigen/CXX11/FixedPoint",
],
visibility = ["//visibility:public"],
deps = [
"@eigen//:eigen",
],
)
#include "Eigen/Eigenvalues"
This diff is collapsed.
# Description:
# Eigen is a C++ template library for linear algebra: vectors,
# matrices, and related algorithms.
# This file is mostly stolen from tensorflow.
licenses([
# Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
# We've taken special care to not reference any restricted code.
"reciprocal", # MPL2
"notice", # Portions BSD
])
exports_files(["COPYING.MPL2"])
# License-restricted (i.e. not reciprocal or notice) files inside Eigen/...
EIGEN_RESTRICTED_FILES = [
"Eigen/src/OrderingMethods/Amd.h",
"Eigen/src/SparseCholesky/**",
]
# Notable transitive dependencies of restricted files inside Eigen/...
EIGEN_RESTRICTED_DEPS = [
"Eigen/Eigen",
"Eigen/IterativeLinearSolvers",
"Eigen/MetisSupport",
"Eigen/Sparse",
"Eigen/SparseCholesky",
"Eigen/SparseLU",
]
# Note: unsupported/Eigen is unsupported and might go away at any time.
EIGEN_FILES = [
"Eigen/**",
"unsupported/Eigen/CXX11/**",
"unsupported/Eigen/FFT",
"unsupported/Eigen/KroneckerProduct",
"unsupported/Eigen/src/FFT/**",
"unsupported/Eigen/src/KroneckerProduct/**",
"unsupported/Eigen/MatrixFunctions",
"unsupported/Eigen/SpecialFunctions",
"unsupported/Eigen/src/SpecialFunctions/**",
]
# List of files picked up by glob but actually part of another target.
EIGEN_EXCLUDE_FILES = [
"Eigen/src/Core/arch/AVX/PacketMathGoogleTest.cc",
]
# Files known to be under MPL2 license.
EIGEN_MPL2_HEADER_FILES = glob(
EIGEN_FILES,
exclude = EIGEN_EXCLUDE_FILES +
EIGEN_RESTRICTED_FILES +
EIGEN_RESTRICTED_DEPS + [
# Guarantees any file missed by excludes above will not compile.
"Eigen/src/Core/util/NonMPL2.h",
"Eigen/**/CMakeLists.txt",
],
)
cc_library(
name = "eigen",
hdrs = EIGEN_MPL2_HEADER_FILES,
defines = [
# This define (mostly) guarantees we don't link any problematic
# code. We use it, but we do not rely on it, as evidenced above.
"EIGEN_MPL2_ONLY",
],
includes = ["."],
visibility = ["//visibility:public"],
)
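As a sketch of what the EIGEN_MPL2_ONLY define enforces: Eigen's license-restricted headers carry a guard that aborts compilation when the macro is set, so a stray include fails loudly instead of silently linking restricted code:

#define EIGEN_MPL2_ONLY
#include <Eigen/Core>            // OK: MPL2-covered code only
#include <Eigen/SparseCholesky>  // fails with an #error: guarded as non-MPL2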
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2013 Christian Seiler <christian@iwakd.de>
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_CORE_MODULE
#define EIGEN_CXX11_CORE_MODULE
#include <Eigen/Core>
#include <Eigen/src/Core/util/DisableStupidWarnings.h>
/** \defgroup CXX11_Core_Module C++11 Core Module
*
* This module provides common core features for all modules that
* explicitly depend on C++11. Currently, this is only the Tensor
* module. Note that at this stage, you should not need to include
* this module directly.
*
* It also provides a limited fallback for compilers that don't support
* CXX11 yet, such as nvcc.
*
* \code
* #include <Eigen/CXX11/Core>
* \endcode
*/
// Only a subset of cxx11 is allowed at Google, so we default to emulating
// the cxx11 functionality that we need.
#include "src/Core/util/FixedSizeVector.h"
#if 1
#include <vector>
#include "src/Core/util/EmulateCXX11Meta.h"
#else
#include "src/Core/util/CXX11Workarounds.h"
#include "src/Core/util/CXX11Meta.h"
#endif
#include <Eigen/src/Core/util/ReenableStupidWarnings.h>
#endif // EIGEN_CXX11_CORE_MODULE
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_FIXED_POINT_MODULE
#define EIGEN_CXX11_FIXED_POINT_MODULE
#include <Eigen/Core>
#include <stdint.h>
/** \defgroup CXX11_FixedPoint_Module Fixed Point Module
*
* This module provides common core features for all modules that
* explicitly depend on C++11. Currently, this is only the Tensor
* module. Note that at this stage, you should not need to include
* this module directly.
*
* It also provides a limited fallback for compilers that don't support
* CXX11 yet, such as nvcc.
*
* \code
* #include <Eigen/CXX11/FixedPoint>
* \endcode
*/
#include "src/FixedPoint/FixedPointTypes.h"
// Use optimized implementations whenever available
#if defined (EIGEN_VECTORIZE_AVX512DQ) || defined (EIGEN_VECTORIZE_AVX512BW)
#include "src/FixedPoint/PacketMathAVX512.h"
#include "src/FixedPoint/TypeCastingAVX512.h"
#elif defined EIGEN_VECTORIZE_AVX2
#define EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
#define EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT
#include "src/FixedPoint/PacketMathAVX2.h"
#include "src/FixedPoint/MatMatProductAVX2.h"
#include "src/FixedPoint/TypeCastingAVX2.h"
#elif defined EIGEN_VECTORIZE_NEON
#define EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
#include "src/FixedPoint/MatMatProductNEON.h"
#endif
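// For example, building with -mavx2 makes Eigen define EIGEN_VECTORIZE_AVX2,
// so the AVX2 packet math and mat-mat product kernels above are selected;
// an ARM build with NEON enabled defines EIGEN_VECTORIZE_NEON and picks up
// the NEON mat-mat product instead.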
// Use the default implementation when no optimized code is available
#include "src/FixedPoint/MatMatProduct.h"
#include "src/FixedPoint/MatVecProduct.h"
#endif // EIGEN_CXX11_FIXED_POINT_MODULE
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_NEURAL_NETWORKS_MODULE
#define EIGEN_CXX11_NEURAL_NETWORKS_MODULE
#include "unsupported/Eigen/CXX11/Tensor"
/** \defgroup CXX11_NeuralNetworks_Module Neural Networks Module
*
* This module provides an efficient implementation of the common primitives
* used by neural networks.
* The primitives are built on top of the tensor library.
*
* \code
* #include <Eigen/CXX11/NeuralNetworks>
* \endcode
*/
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/Activations.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/Attention.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/Pooling.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/SoftMax.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/BackwardCuboidConvolutions.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/CuboidConvolution.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/BackwardSpatialConvolutions.h"
#include "unsupported/Eigen/CXX11/src/NeuralNetworks/SpatialConvolutions.h"
#endif // EIGEN_CXX11_NEURAL_NETWORKS_MODULE
#include "unsupported/Eigen/CXX11/Tensor"
#ifdef _WIN32
#ifndef SLEEP_FUNC_HEADER_GUARD
#define SLEEP_FUNC_HEADER_GUARD
inline void sleep(unsigned int seconds) { Sleep(1000*seconds); }
#endif
// On Windows, Eigen will include Windows.h, which defines various
// macros that conflict with TensorFlow symbols. Undefine them here to
// prevent clashes.
#undef DeleteFile
#undef ERROR
#undef LoadLibrary
#endif // _WIN32
#include "unsupported/Eigen/CXX11/ThreadPool"
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_FIXED_POINT_TYPES_H
#define EIGEN_CXX11_FIXED_POINT_TYPES_H
#include <cmath>
#include <iostream>
namespace Eigen {
// The mantissa part of the fixed point representation. See
// go/tensorfixedpoint for details
struct QInt8;
struct QUInt8;
struct QInt16;
struct QUInt16;
struct QInt32;
template <>
struct NumTraits<QInt8> : GenericNumTraits<int8_t> {};
template <>
struct NumTraits<QUInt8> : GenericNumTraits<uint8_t> {};
template <>
struct NumTraits<QInt16> : GenericNumTraits<int16_t> {};
template <>
struct NumTraits<QUInt16> : GenericNumTraits<uint16_t> {};
template <>
struct NumTraits<QInt32> : GenericNumTraits<int32_t> {};
namespace internal {
template <>
struct scalar_product_traits<QInt32, double> {
enum {
// Cost = NumTraits<T>::MulCost,
Defined = 1
};
typedef QInt32 ReturnType;
};
} // namespace internal
// Wrap the 8bit int into a QInt8 struct instead of using a typedef, to
// prevent the compiler from silently casting the mantissa into a bigger or
// smaller representation.
struct QInt8 {
QInt8() {}
QInt8(const int8_t v) : value(v) {}
QInt8(const QInt32 v);
operator int() const { return static_cast<int>(value); }
int8_t value;
};
struct QUInt8 {
QUInt8() {}
QUInt8(const uint8_t v) : value(v) {}
QUInt8(const QInt32 v);
operator int() const { return static_cast<int>(value); }
uint8_t value;
};
struct QInt16 {
QInt16() {}
QInt16(const int16_t v) : value(v) {}
QInt16(const QInt32 v);
operator int() const { return static_cast<int>(value); }
int16_t value;
};
struct QUInt16 {
QUInt16() {}
QUInt16(const uint16_t v) : value(v) {}
QUInt16(const QInt32 v);
operator int() const { return static_cast<int>(value); }
uint16_t value;
};
struct QInt32 {
QInt32() {}
QInt32(const int8_t v) : value(v) {}
QInt32(const int32_t v) : value(v) {}
QInt32(const uint32_t v) : value(static_cast<int32_t>(v)) {}
QInt32(const QInt8 v) : value(v.value) {}
QInt32(const float v) : value(static_cast<int32_t>(lrint(v))) {}
#ifdef EIGEN_MAKING_DOCS
// Workaround to fix build on PPC.
QInt32(unsigned long v) : value(v) {}
#endif
operator float() const { return static_cast<float>(value); }
int32_t value;
};
EIGEN_STRONG_INLINE QInt8::QInt8(const QInt32 v)
: value(v.value > 127 ? 127 : (v.value < -128 ? -128 : v.value)) {}
EIGEN_STRONG_INLINE QUInt8::QUInt8(const QInt32 v)
: value(v.value > 255 ? 255 : (v.value < 0 ? 0 : v.value)) {}
EIGEN_STRONG_INLINE QInt16::QInt16(const QInt32 v)
: value(v.value > 32767 ? 32767 : (v.value < -32768 ? -32768 : v.value)) {}
EIGEN_STRONG_INLINE QUInt16::QUInt16(const QInt32 v)
: value(v.value > 65535 ? 65535 : (v.value < 0 ? 0 : v.value)) {}
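// For example, the narrowing conversions above saturate instead of wrapping:
// QInt8(QInt32(300)).value == 127, QInt8(QInt32(-300)).value == -128, and
// QUInt8(QInt32(-5)).value == 0.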
// Basic widening 8-bit operations: This will be vectorized in future CLs.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt8 b) {
return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QUInt8 b) {
return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt8 b) {
return QInt32(static_cast<int32_t>(a.value) + static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt8 b) {
return QInt32(static_cast<int32_t>(a.value) - static_cast<int32_t>(b.value));
}
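// For example, QInt8(100) * QInt8(50) is QInt32(5000): the product is formed
// in 32 bits, so it cannot wrap around the 8-bit operand range.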
// Basic widening 16-bit operations: This will be vectorized in future CLs.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt16 b) {
return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QUInt16 b) {
return QInt32(static_cast<int32_t>(a.value) * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt16 b) {
return QInt32(static_cast<int32_t>(a.value) + static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt16 b) {
return QInt32(static_cast<int32_t>(a.value) - static_cast<int32_t>(b.value));
}
// Mixed QInt32 op QInt8 operations. This will be vectorized in future CLs.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt8 b) {
return QInt32(a.value + static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) + b.value);
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt8 b) {
return QInt32(a.value - static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) - b.value);
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt8 b) {
return QInt32(a.value * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) * b.value);
}
// Mixed QInt32 op QInt16 operations. This will be vectorized in future CLs.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt16 b) {
return QInt32(a.value + static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) + b.value);
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt16 b) {
return QInt32(a.value - static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) - b.value);
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt16 b) {
return QInt32(a.value * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) * b.value);
}
// Mixed QInt32 op QUInt8 operations. This will be vectorized in future CLs.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt8 b) {
return QInt32(a.value + static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator+(const QUInt8 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) + b.value);
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt8 b) {
return QInt32(a.value - static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator-(const QUInt8 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) - b.value);
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt8 b) {
return QInt32(a.value * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator*(const QUInt8 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) * b.value);
}
// Mixed QInt32 op QUInt16 operations. This will be vectorized in future CLs.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt16 b) {
return QInt32(a.value + static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator+(const QUInt16 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) + b.value);
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt16 b) {
return QInt32(a.value - static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator-(const QUInt16 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) - b.value);
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt16 b) {
return QInt32(a.value * static_cast<int32_t>(b.value));
}
EIGEN_STRONG_INLINE QInt32 operator*(const QUInt16 a, const QInt32 b) {
return QInt32(static_cast<int32_t>(a.value) * b.value);
}
// Basic arithmetic operations on QInt32, which behaves like a int32_t.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt32 b) {
return a.value + b.value;
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt32 b) {
return a.value - b.value;
}
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt32 b) {
return a.value * b.value;
}
EIGEN_STRONG_INLINE QInt32 operator/(const QInt32 a, const QInt32 b) {
return a.value / b.value;
}
EIGEN_STRONG_INLINE QInt32& operator+=(QInt32& a, const QInt32 b) {
a.value += b.value;
return a;
}
EIGEN_STRONG_INLINE QInt32& operator-=(QInt32& a, const QInt32 b) {
a.value -= b.value;
return a;
}
EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const QInt32 b) {
a.value *= b.value;
return a;
}
EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) {
a.value /= b.value;
return a;
}
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) {
return -a.value;
}
// Scaling QInt32 by double. We do the arithmetic in double because
// float only has 23 bits of mantissa, so casting QInt32 to float might reduce
// accuracy by discarding up to 7 (least significant) bits.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const double b) {
return static_cast<int32_t>(lrint(static_cast<double>(a.value) * b));
}
EIGEN_STRONG_INLINE QInt32 operator*(const double a, const QInt32 b) {
return static_cast<int32_t>(lrint(a * static_cast<double>(b.value)));
}
EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const double b) {
a.value = static_cast<int32_t>(lrint(static_cast<double>(a.value) * b));
return a;
}
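// For example, QInt32(100) * 0.25 is QInt32(25): the product is computed in
// double and rounded back with lrint, so no low-order bits are lost to a
// float intermediate.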
// Comparisons
EIGEN_STRONG_INLINE bool operator==(const QInt8 a, const QInt8 b) {
return a.value == b.value;
}
EIGEN_STRONG_INLINE bool operator==(const QUInt8 a, const QUInt8 b) {
return a.value == b.value;
}
EIGEN_STRONG_INLINE bool operator==(const QInt16 a, const QInt16 b) {
return a.value == b.value;
}
EIGEN_STRONG_INLINE bool operator==(const QUInt16 a, const QUInt16 b) {
return a.value == b.value;
}
EIGEN_STRONG_INLINE bool operator==(const QInt32 a, const QInt32 b) {
return a.value == b.value;
}
EIGEN_STRONG_INLINE bool operator<(const QInt8 a, const QInt8 b) {
return a.value < b.value;
}
EIGEN_STRONG_INLINE bool operator<(const QUInt8 a, const QUInt8 b) {
return a.value < b.value;
}
EIGEN_STRONG_INLINE bool operator<(const QInt16 a, const QInt16 b) {
return a.value < b.value;
}
EIGEN_STRONG_INLINE bool operator<(const QUInt16 a, const QUInt16 b) {
return a.value < b.value;
}
EIGEN_STRONG_INLINE bool operator<(const QInt32 a, const QInt32 b) {
return a.value < b.value;
}
EIGEN_STRONG_INLINE bool operator>(const QInt8 a, const QInt8 b) {
return a.value > b.value;
}
EIGEN_STRONG_INLINE bool operator>(const QUInt8 a, const QUInt8 b) {
return a.value > b.value;
}
EIGEN_STRONG_INLINE bool operator>(const QInt16 a, const QInt16 b) {
return a.value > b.value;
}
EIGEN_STRONG_INLINE bool operator>(const QUInt16 a, const QUInt16 b) {
return a.value > b.value;
}
EIGEN_STRONG_INLINE bool operator>(const QInt32 a, const QInt32 b) {
return a.value > b.value;
}
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt8 a) {
os << static_cast<int>(a.value);
return os;
}
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt8 a) {
os << static_cast<int>(a.value);
return os;
}
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt16 a) {
os << static_cast<int>(a.value);
return os;
}
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt16 a) {
os << static_cast<int>(a.value);
return os;
}
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt32 a) {
os << a.value;
return os;
}
} // namespace Eigen
#endif // EIGEN_CXX11_FIXED_POINT_TYPES_H
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H
#define EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H
namespace Eigen {
namespace internal {
// Accumulate the product of 2 QInt8 inputs on 32 bits to prevent
// overflows
template<> struct scalar_product_traits<QInt8, QInt8>
{
enum {
Defined = 1
};
typedef QInt32 ReturnType;
};
// Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits
// to prevent overflows
template<> struct scalar_product_traits<QInt8, QUInt8>
{
enum {
Defined = 1
};
typedef QInt32 ReturnType;
};
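// With these traits in place, a QInt8-by-QInt8 (or QInt8-by-QUInt8) product
// expression is given a QInt32 result scalar, e.g. (a sketch):
//
// Eigen::Matrix<Eigen::QInt8, Eigen::Dynamic, Eigen::Dynamic> a(4, 8), b(8, 4);
// Eigen::Matrix<Eigen::QInt32, Eigen::Dynamic, Eigen::Dynamic> c = a * b;
//
// and the gebp_kernel specializations below perform the accumulation.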
// Description of the product implementation. It's pretty simple now since
// nothing is vectorized yet.
// This definition tackles the case where both lhs and rhs are encoded using
// signed 8bit integers
#ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT
template<bool _ConjLhs, bool _ConjRhs>
class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs>
{
public:
typedef QInt8 LhsScalar;
typedef QInt8 RhsScalar;
typedef QInt32 ResScalar;
enum {
// register block size along the M and N directions
// One for the current implementation
nr = 1,
mr = 1,
// Progress made at each iteration of the product loop
// also 1 for the current implementation
LhsProgress = 1,
RhsProgress = 1
};
};
// The signed 8bit Mat-Mat product itself.
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
EIGEN_DONT_INLINE
void operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
Index rows, Index depth, Index cols, QInt32 alpha,
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
::operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
Index rows, Index depth, Index cols, QInt32 alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
eigen_assert(alpha.value == 1);
eigen_assert(strideA == -1);
eigen_assert(strideB == -1);
eigen_assert(offsetA == 0);
eigen_assert(offsetB == 0);
eigen_assert(rows > 0);
eigen_assert(cols > 0);
eigen_assert(depth > 0);
eigen_assert(blockA);
eigen_assert(blockB);
for (Index j = 0; j < cols; ++j) {
Index startB = j * depth;
for (Index i = 0; i < rows; ++i) {
Index startA = i * depth;
for (Index k = 0; k < depth; ++k) {
res(i, j) += blockA[startA + k] * blockB[startB + k];
}
}
}
}
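// Note the packing convention implied above: blockA stores each of the
// `rows` lhs rows as `depth` contiguous values, and blockB stores each of
// the `cols` rhs columns the same way, so every dot product is a pair of
// linear scans.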
#endif
// This definition tackles the case where the lhs is encoded using signed 8bit
// integers and the rhs using unsigned 8bit integers.
#ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
template<bool _ConjLhs, bool _ConjRhs>
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
{
public:
typedef QInt8 LhsScalar;
typedef QUInt8 RhsScalar;
typedef QInt32 ResScalar;
enum {
// register block size along the M and N directions
// One for the current implementation
nr = 1,
mr = 1,
// Progress made at each iteration of the product loop
// also 1 for the current implementation
LhsProgress = 1,
RhsProgress = 1
};
};
// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
EIGEN_DONT_INLINE
void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
Index rows, Index depth, Index cols, QInt32 alpha,
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
Index rows, Index depth, Index cols, QInt32 alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
eigen_assert(alpha.value == 1);
eigen_assert(strideA == -1);
eigen_assert(strideB == -1);
eigen_assert(offsetA == 0);
eigen_assert(offsetB == 0);
eigen_assert(rows > 0);
eigen_assert(cols > 0);
eigen_assert(depth > 0);
eigen_assert(blockA);
eigen_assert(blockB);
for (Index j = 0; j < cols; ++j) {
Index startB = j * depth;
for (Index i = 0; i < rows; ++i) {
Index startA = i * depth;
for (Index k = 0; k < depth; ++k) {
res(i, j) += blockA[startA + k] * blockB[startB + k];
}
}
}
}
#endif
// This definition tackles the case where the lhs is encoded using unsigned 8bit
// integers and the rhs using signed 8bit integers.
#ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT
template<bool _ConjLhs, bool _ConjRhs>
class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs>
{
public:
typedef QUInt8 LhsScalar;
typedef QInt8 RhsScalar;
typedef QInt32 ResScalar;
enum {
// register block size along the M and N directions
// One for the current implementation
nr = 1,
mr = 1,
// Progress made at each iteration of the product loop
// also 1 for the current implementation
LhsProgress = 1,
RhsProgress = 1
};
};
// Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
EIGEN_DONT_INLINE
void operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
Index rows, Index depth, Index cols, QInt32 alpha,
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
::operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
Index rows, Index depth, Index cols, QInt32 alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
eigen_assert(alpha.value == 1);
eigen_assert(strideA == -1);
eigen_assert(strideB == -1);
eigen_assert(offsetA == 0);
eigen_assert(offsetB == 0);
eigen_assert(rows > 0);
eigen_assert(cols > 0);
eigen_assert(depth > 0);
eigen_assert(blockA);
eigen_assert(blockB);
for (Index j = 0; j < cols; ++j) {
Index startB = j * depth;
for (Index i = 0; i < rows; ++i) {
Index startA = i * depth;
for (Index k = 0; k < depth; ++k) {
res(i, j) += blockA[startA + k] * blockB[startB + k];
}
}
}
}
#endif
} // namespace internal
} // namespace Eigen
#endif // EIGEN_CXX11_FIXED_POINT_MAT_MAT_PRODUCT_H
#include "unsupported/Eigen/SpecialFunctions"
This diff is collapsed.