diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h
new file mode 100644
index 0000000000000000000000000000000000000000..94cd2bdc38b48e61a97eb807dbadb1297dd2d3de
--- /dev/null
+++ b/mace/kernels/gemm.h
@@ -0,0 +1,66 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_GEMM_H_
+#define MACE_KERNELS_GEMM_H_
+
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+// Batched matrix multiplication C = A * B, generic host implementation.
+//
+// The loops index A as [batch, height, 1, K] and B as [batch, K, 1, width],
+// stored contiguously with the last dimension varying fastest. C is resized
+// to [A->dim(0), A->dim(1), 1, B->dim(3)].
+// `future` is not used by this implementation.
+template <DeviceType D, typename T>
+struct GEMMFunctor {
+  void operator()(const Tensor *A,
+                  const Tensor *B,
+                  Tensor *C,
+                  StatsFuture *future) {
+    std::vector<index_t> c_shape = {A->dim(0), A->dim(1), 1, B->dim(3)};
+    C->Resize(c_shape);
+    const index_t N = C->dim(0);
+    const index_t height = C->dim(1);
+    const index_t width = C->dim(3);
+    const index_t K = A->dim(3);
+    Tensor::MappingGuard guarda(A);
+    Tensor::MappingGuard guardb(B);
+    Tensor::MappingGuard guardc(C);
+    const T *a_ptr_base = A->data<T>();
+    const T *b_ptr_base = B->data<T>();
+    T *c_ptr = C->mutable_data<T>();
+    // Loop counters use index_t so they have the same type as their bounds.
+    for (index_t i = 0; i < N; ++i) {
+      for (index_t h = 0; h < height; ++h) {
+        for (index_t w = 0; w < width; ++w) {
+          // Dot product of row h of A with column w of B; consecutive
+          // elements of a B column are `width` elements apart.
+          const T *a_ptr = a_ptr_base + h * K;
+          const T *b_ptr = b_ptr_base + w;
+          *c_ptr = 0;
+          for (index_t k = 0; k < K; ++k) {
+            *c_ptr += *a_ptr * *b_ptr;
+            ++a_ptr;
+            b_ptr += width;
+          }
+          ++c_ptr;
+        }
+      }
+      // Advance both inputs to the next matrix pair of the batch.
+      a_ptr_base += height * K;
+      b_ptr_base += K * width;
+    }
+  }
+};
+
+// OpenCL specialization; operator() is defined out of line.
+template <typename T>
+struct GEMMFunctor<DeviceType::OPENCL, T> {
+  void operator()(const Tensor *A,
+                  const Tensor *B,
+                  Tensor *C,
+                  StatsFuture *future);
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_GEMM_H_
diff --git a/mace/kernels/opencl/cl/gemm.cl b/mace/kernels/opencl/cl/gemm.cl
new file mode 100644
index 0000000000000000000000000000000000000000..994a190a180272b2633f91edc4eeae2b3df487a8
--- /dev/null
+++ b/mace/kernels/opencl/cl/gemm.cl
@@ -0,0 +1,60 @@
+#include <common.h>
+
+// C = A * B
+//
+// Global id 0 (gx) selects the pixel column of B and C. Global id 1 (hb)
+// encodes the batch (hb / height_blocks) and a block of four consecutive
+// rows (gy = (hb % height_blocks) * 4). Each work-item accumulates four
+// output pixels, one per row, and writes only the rows that satisfy
+// gy + r < M.
+//
+//   M             rows per batch in A and C (row offset of a batch is
+//                 batch * M)
+//   height_blocks number of four-row blocks per batch along global id 1
+//   K             reduction length; B rows of a batch start at batch * K
+__kernel void gemm(__read_only image2d_t A,
+                   __read_only image2d_t B,
+                   __write_only image2d_t C,
+                   __private const int M,
+                   __private const int height_blocks,
+                   __private const int K) {
+  const int gx = get_global_id(0);
+  const int hb = get_global_id(1);
+  const int batch = hb / height_blocks;
+  const int gy = (hb % height_blocks) << 2;
+  const int bm = mul24(batch, M);
+  const int bk = mul24(batch, K);
+
+  float4 a0, a1, a2, a3;
+  float4 b0, b1, b2, b3;
+  // The accumulators feed mad() on the first iteration, so they must start
+  // at zero; private variables are not zero-initialized.
+  float4 c0 = (float4)(0.0f);
+  float4 c1 = (float4)(0.0f);
+  float4 c2 = (float4)(0.0f);
+  float4 c3 = (float4)(0.0f);
+
+  // `pos` is an int like K so the counter cannot overflow for K > 32767.
+  for (int pos = 0; pos < K; pos += 4) {
+    // One A pixel holds four consecutive K elements of a row.
+    a0 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy)));
+    a1 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 1)));
+    a2 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 2)));
+    a3 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 3)));
+
+    // Four consecutive B rows at pixel column gx.
+    b0 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos)));
+    b1 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 1)));
+    b2 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 2)));
+    b3 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 3)));
+
+    c0 = mad(a0.x, b0, c0);
+    c0 = mad(a0.y, b1, c0);
+    c0 = mad(a0.z, b2, c0);
+    c0 = mad(a0.w, b3, c0);
+
+    c1 = mad(a1.x, b0, c1);
+    c1 = mad(a1.y, b1, c1);
+    c1 = mad(a1.z, b2, c1);
+    c1 = mad(a1.w, b3, c1);
+
+    c2 = mad(a2.x, b0, c2);
+    c2 = mad(a2.y, b1, c2);
+    c2 = mad(a2.z, b2, c2);
+    c2 = mad(a2.w, b3, c2);
+
+    c3 = mad(a3.x, b0, c3);
+    c3 = mad(a3.y, b1, c3);
+    c3 = mad(a3.z, b2, c3);
+    c3 = mad(a3.w, b3, c3);
+  }
+  // Skip rows past the end of the matrix when M is not a multiple of four.
+  if (gy >= M) return;
+  WRITE_IMAGET(C, (int2)(gx, (bm + gy)), c0);
+  if ((gy + 1) >= M) return;
+  WRITE_IMAGET(C, (int2)(gx, (bm + gy + 1)), c1);
+  if ((gy + 2) >= M) return;
+  WRITE_IMAGET(C, (int2)(gx, (bm + gy + 2)), c2);
+  if ((gy + 3) >= M) return;
+  WRITE_IMAGET(C, (int2)(gx, (bm + gy + 3)), c3);
+}
diff --git a/mace/kernels/opencl/gemm.cc b/mace/kernels/opencl/gemm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..934dea79ec172cbaff3bb8e2214771efb232165e
--- /dev/null
+++ b/mace/kernels/opencl/gemm.cc
@@ -0,0 +1,115 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/gemm.h"
+
+#include <algorithm>
+#include <set>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+
+namespace mace {
+namespace kernels {
+
+// OpenCL implementation of C = A * B on image-backed tensors.
+//
+// Resizes C to [A->dim(0), A->dim(1), 1, B->dim(3)], builds the "gemm"
+// kernel, and launches it through the tuner over a 2-D range of
+// (width_blocks, height_blocks * batch). When `future` is non-null, its
+// wait_fn is set to wait on the launch event and optionally collect call
+// statistics.
+template <typename T>
+void GEMMFunctor<DeviceType::OPENCL, T>::operator()(
+    const Tensor *A,
+    const Tensor *B,
+    Tensor *C,
+    StatsFuture *future) {
+  std::vector<index_t> c_shape = {A->dim(0), A->dim(1), 1, B->dim(3)};
+  std::vector<size_t> c_image_shape;
+  CalImage2DShape(c_shape, BufferType::IN_OUT, c_image_shape);
+  C->ResizeImage(c_shape, c_image_shape);
+
+  const index_t batch = C->dim(0);
+  const index_t height = C->dim(1);
+  const index_t width = C->dim(3);
+
+  // The kernel derives its first row as (block index << 2), i.e. it works
+  // on blocks of four rows; width is blocked the same way here.
+  const index_t width_blocks = RoundUpDiv4(width);
+  const index_t height_blocks = RoundUpDiv4(height);
+
+  auto runtime = OpenCLRuntime::Global();
+  std::set<std::string> built_options;
+  auto dt = DataTypeToEnum<T>::value;
+  std::string kernel_name = MACE_OBFUSCATE_SYMBOL("gemm");
+  built_options.emplace("-Dgemm=" + kernel_name);
+  built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+  built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+  auto gemm_kernel = runtime->BuildKernel("gemm", kernel_name, built_options);
+
+  // Argument order follows the kernel signature:
+  // (A, B, C, M, height_blocks, K).
+  uint32_t idx = 0;
+  gemm_kernel.setArg(idx++,
+                     *(static_cast<const cl::Image2D *>(A->buffer())));
+  gemm_kernel.setArg(idx++,
+                     *(static_cast<const cl::Image2D *>(B->buffer())));
+  gemm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer())));
+  gemm_kernel.setArg(idx++, static_cast<int>(height));
+  gemm_kernel.setArg(idx++, static_cast<int>(height_blocks));
+  gemm_kernel.setArg(idx++, static_cast<int>(A->dim(3)));
+
+  // Two-dimensional launch: only two extents are ever read.
+  const uint32_t gws[2] = {
+      static_cast<uint32_t>(width_blocks),
+      static_cast<uint32_t>(height_blocks * batch),
+  };
+  const std::vector<uint32_t> lws = {16, 64};
+  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(gemm_kernel);
+  // Candidate local work sizes offered to the tuner.
+  // NOTE(review): entries such as kwg_size / 512 evaluate to 0 when kwg_size
+  // is smaller than the divisor; confirm such candidates cannot reach
+  // enqueueNDRangeKernel on devices with a small maximum work-group size.
+  auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(2, 0);
+    local_ws[0] = std::min<uint32_t>(width_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(height_blocks * batch,
+                                     kwg_size / local_ws[0]);
+    return {{local_ws[0], local_ws[1]},
+            {local_ws[1], local_ws[0]},
+            {kwg_size / 4, 4},
+            {kwg_size / 16, 16},
+            {kwg_size / 32, 32},
+            {kwg_size / 64, 64},
+            {kwg_size / 128, 128},
+            {kwg_size / 256, 256},
+            {kwg_size / 512, 512},
+            {kwg_size, 1},
+            {1, kwg_size}
+           };
+  };
+  cl::Event event;
+  // Launches the kernel with the local size chosen by the tuner.
+  auto func = [&](const std::vector<uint32_t>& params)->cl_int {
+    cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+        gemm_kernel, cl::NullRange,
+        cl::NDRange(gws[0], gws[1]),
+        cl::NDRange(params[0], params[1]),
+        nullptr, &event);
+
+    MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+    return error;
+  };
+  // Tuning key: one entry per output shape.
+  std::stringstream ss;
+  ss << "gemm_opencl_kernel_"
+     << C->dim(0) << "_"
+     << C->dim(1) << "_"
+     << C->dim(2) << "_"
+     << C->dim(3);
+  OpenCLProfilingTimer timer(&event);
+  Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
+                                                     lws,
+                                                     params_generator,
+                                                     func,
+                                                     &timer);
+  if (future != nullptr) {
+    // `event` is captured by value so the callback does not refer to this
+    // function's local after it returns.
+    future->wait_fn = [runtime, event](CallStats *stats) {
+      event.wait();
+      if (stats != nullptr) {
+        runtime->GetCallStats(event, stats);
+      }
+    };
+  }
+}
+
+template
+struct GEMMFunctor<DeviceType::OPENCL, float>;
+
+template
+struct GEMMFunctor<DeviceType::OPENCL, half>;
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/ops/gemm.cc b/mace/ops/gemm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8dfc83a34244b652781491f9fbc93a0ac9486d9c
--- /dev/null
+++ b/mace/ops/gemm.cc
@@ -0,0 +1,29 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/gemm.h"
+
+namespace mace {
+
+// Registers the GEMM operator for CPU (float) and for OpenCL (float and
+// half, the two types the OpenCL functor is instantiated for).
+void Register_GEMM(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM")
+                                     .Device(DeviceType::CPU)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    GEMMOp<DeviceType::CPU, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    GEMMOp<DeviceType::OPENCL, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<half>("T")
+                                     .Build(),
+                    GEMMOp<DeviceType::OPENCL, half>);
+}
+
+}  // namespace mace
diff --git a/mace/ops/gemm.h b/mace/ops/gemm.h
new file mode 100644
index 0000000000000000000000000000000000000000..78d7e3b1d4c6e254b9963e7a38a663e2d7e40cd3
--- /dev/null
+++ b/mace/ops/gemm.h
@@ -0,0 +1,39 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_GEMM_H_
+#define MACE_OPS_GEMM_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/gemm.h"
+
+namespace mace {
+
+// Operator wrapper around kernels::GEMMFunctor.
+//
+// Input 0 is A, input 1 is B, output 0 is C = A * B. Both inputs must be
+// 4-D with equal batch size (dim 0), a unit third dimension, and
+// A->dim(3) == B->dim(1).
+template <DeviceType D, class T>
+class GEMMOp : public Operator<D, T> {
+ public:
+  GEMMOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {}
+
+  // Validates the input shapes and runs the functor; always returns true
+  // (shape violations are reported through MACE_CHECK).
+  bool Run(StatsFuture *future) override {
+    const Tensor *A = this->Input(0);
+    const Tensor *B = this->Input(1);
+    Tensor *C = this->Output(0);
+    MACE_CHECK(A->dim_size() == 4 && 4 == B->dim_size())
+        << "The dimension of A and B should be 4";
+    MACE_CHECK(A->dim(0) == B->dim(0)) << "A and B must have same batch size";
+    // The functors build C with a unit third dimension and stride through
+    // the inputs as if theirs were 1 as well; reject anything else instead
+    // of silently computing a wrong product.
+    MACE_CHECK(A->dim(2) == 1 && B->dim(2) == 1)
+        << "The third dimension of A and B should be 1";
+    MACE_CHECK(A->dim(3) == B->dim(1))
+        << "the number of A's column must be equal to B's row";
+
+    functor_(A, B, C, future);
+    return true;
+  }
+
+ private:
+  kernels::GEMMFunctor<D, T> functor_;
+};
+
+}  // namespace mace
+
+#endif  // MACE_OPS_GEMM_H_
diff --git a/mace/ops/gemm_benchmark.cc b/mace/ops/gemm_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..76dcc02abbf3abf939aac47f7376d786e2e835e2
--- /dev/null
+++ b/mace/ops/gemm_benchmark.cc
@@ -0,0 +1,69 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+// NOTE(review): the next #include has no header name, and several
+// angle-bracket argument lists below (the bare "template",
+// "static_cast(...)", "DataTypeToEnum::value", the test-utility calls)
+// appear to have been lost from this text; restore them from the original
+// commit before building.
+#include
+#include "mace/core/operator.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+// Benchmarks the GEMM op on device D.
+//
+// Builds random inputs A = {batch, height, 1, channels} and
+// B = {batch, channels, 1, out_width}, runs five untimed warm-up
+// iterations, then times `iters` runs followed by a Sync().
+template
+static void GEMMBenchmark(
+    int iters, int batch, int height, int channels, int out_width) {
+  // Setup below is excluded from the measurement.
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput("A", {batch, height, 1, channels});
+  net.AddRandomInput("B", {batch, channels, 1, out_width});
+
+  if (D == DeviceType::OPENCL) {
+    // The OpenCL op consumes image tensors, so convert both inputs first.
+    BufferToImage(net, "A", "AImage",
+                  kernels::BufferType::IN_OUT);
+    BufferToImage(net, "B", "BImage",
+                  kernels::BufferType::IN_OUT);
+
+    OpDefBuilder("GEMM", "GEMMBM")
+        .Input("AImage")
+        .Input("BImage")
+        .Output("Output")
+        .AddIntArg("T", static_cast(DataTypeToEnum::value))
+        .Finalize(net.NewOperatorDef());
+  } else {
+    OpDefBuilder("GEMM", "GEMMBM")
+        .Input("A")
+        .Input("B")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  // Sync inside the timed region so the measurement covers the queued runs.
+  net.Sync();
+}
+
+// Defines and registers one benchmark function. The reported item count is
+// iters * N * C * H * W and the byte count is that number times
+// sizeof(TYPE).
+// NOTE(review): N, H, C and W are pasted into the function name without
+// separators, so different shape tuples could produce the same identifier.
+#define BM_GEMM_MACRO(N, H, C, W, TYPE, DEVICE)                         \
+  static void BM_GEMM_##N##H##C##W##_##TYPE##_##DEVICE(int iters) {   \
+    const int64_t tot = static_cast(iters) * N * C * H * W;            \
+    mace::testing::ItemsProcessed(tot);                                \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                \
+    GEMMBenchmark(iters, N, H, C, W);                                  \
+  }                                                                    \
+  BENCHMARK(BM_GEMM_##N##H##C##W##_##TYPE##_##DEVICE)
+
+// Only the OPENCL device is benchmarked.
+#define BM_GEMM(N, H, C, W, TYPE)       \
+  BM_GEMM_MACRO(N, H, C, W, TYPE, OPENCL);
+
+BM_GEMM(16, 32, 128, 1024, half);
+BM_GEMM(36, 32, 128, 256, half);
+}  // namespace mace
diff --git a/mace/ops/gemm_test.cc b/mace/ops/gemm_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c1d9488995090293ca01ddf966f324ed3b594191
--- /dev/null
+++ b/mace/ops/gemm_test.cc
@@ -0,0 +1,169 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+// NOTE(review): the next #include has no header name, and the angle-bracket
+// argument lists in this file (bare "template", "std::vector", the
+// Simple(...) / Complex(...) calls, "static_cast(...)") appear to have been
+// lost from this text; restore them from the original commit before
+// building.
+#include
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+
+class GEMMOpTest : public OpsTestBase {};
+
+// Runs GEMM on device D with fixed inputs and compares the result with the
+// given expected values (tolerance 1e-5).
+//
+// For DeviceType::OPENCL the inputs are converted to images before the op
+// runs and the output image is converted back to a buffer afterwards.
+template
+void Simple(const std::vector &A_shape,
+            const std::vector &A_value,
+            const std::vector &B_shape,
+            const std::vector &B_value,
+            const std::vector &C_shape,
+            const std::vector &C_value) {
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray("A", A_shape, A_value);
+  net.AddInputFromArray("B", B_shape, B_value);
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage(net, "A", "AImage",
+                  kernels::BufferType::IN_OUT);
+    BufferToImage(net, "B", "BImage",
+                  kernels::BufferType::IN_OUT);
+
+    OpDefBuilder("GEMM", "GEMMTest")
+        .Input("AImage")
+        .Input("BImage")
+        .Output("OutputImage")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+
+    // Transfer output
+    ImageToBuffer(net, "OutputImage", "Output",
+                  kernels::BufferType::IN_OUT);
+  } else {
+    OpDefBuilder("GEMM", "GEMMTest")
+        .Input("A")
+        .Input("B")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+  }
+
+  // Check
+  auto expected =
+      CreateTensor(C_shape, C_value);
+
+  ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5);
+}
+
+// A 2x3 by 3x2 product and the square of a 5x5 matrix holding 1..25.
+TEST_F(GEMMOpTest, SimpleCPU) {
+  Simple({1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
+         {1, 3, 1, 2}, {1, 2, 3, 4, 5, 6},
+         {1, 2, 1, 2}, {22, 28, 49, 64});
+  Simple({1, 5, 1, 5},
+         {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+          16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+         {1, 5, 1, 5},
+         {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+          16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+         {1, 5, 1, 5},
+         {215, 230, 245, 260, 275, 490, 530, 570, 610, 650,
+          765, 830, 895, 960, 1025, 1040, 1130, 1220, 1310, 1400,
+          1315, 1430, 1545, 1660, 1775});
+}
+
+// Same two cases as SimpleCPU.
+TEST_F(GEMMOpTest, SimpleOPENCL) {
+  Simple({1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
+         {1, 3, 1, 2}, {1, 2, 3, 4, 5, 6},
+         {1, 2, 1, 2}, {22, 28, 49, 64});
+  Simple({1, 5, 1, 5},
+         {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+          16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+         {1, 5, 1, 5},
+         {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+          16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+         {1, 5, 1, 5},
+         {215, 230, 245, 260, 275, 490, 530, 570, 610, 650,
+          765, 830, 895, 960, 1025, 1040, 1130, 1220, 1310, 1400,
+          1315, 1430, 1545, 1660, 1775});
+}
+
+// Cross-checks the OpenCL result against a reference run on random inputs
+// A = {batch, height, 1, channels} and B = {batch, channels, 1, out_width}.
+//
+// The first RunOp() call (labelled "run cpu" below) produces the reference;
+// the inputs are then converted to images, the op is rebuilt with the "T"
+// argument and run on DeviceType::OPENCL. The comparison tolerance is 1e-1
+// when the data type is DT_HALF and 1e-4 otherwise.
+template
+void Complex(const index_t batch,
+             const index_t height,
+             const index_t channels,
+             const index_t out_width) {
+  // NOTE(review): seeds the C library RNG from the clock; whether
+  // AddRandomInput draws from rand() is not visible here.
+  srand(time(NULL));
+
+  // Construct graph
+  OpsTestNet net;
+  OpDefBuilder("GEMM", "GEMMTest")
+      .Input("A")
+      .Input("B")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddRandomInput(
+      "A", {batch, height, 1, channels});
+  net.AddRandomInput(
+      "B", {batch, channels, 1, out_width});
+
+  // run cpu
+  net.RunOp();
+
+  // Check
+  // Keep a copy of the reference output before the second run.
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
+
+  // Run on opencl
+  BufferToImage(net, "A", "AImage",
+                kernels::BufferType::IN_OUT);
+  BufferToImage(net, "B", "BImage",
+                kernels::BufferType::IN_OUT);
+
+  OpDefBuilder("GEMM", "GEMMTest")
+      .Input("AImage")
+      .Input("BImage")
+      .Output("OutputImage")
+      .AddIntArg("T", static_cast(DataTypeToEnum::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Run on opencl
+  net.RunOp(DeviceType::OPENCL);
+  net.Sync();
+
+  ImageToBuffer(net, "OutputImage", "OPENCLOutput",
+                kernels::BufferType::IN_OUT);
+  if (DataTypeToEnum::value == DataType::DT_HALF) {
+    ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-1);
+  } else {
+    ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-4);
+  }
+}
+
+// Shapes whose dimensions are all multiples of four, single batch.
+TEST_F(GEMMOpTest, OPENCLAlignedWithoutBatch) {
+  Complex(1, 64, 128, 32);
+  Complex(1, 64, 32, 128);
+}
+// Shapes that are not multiples of four, single batch.
+TEST_F(GEMMOpTest, OPENCLUnAlignedWithoutBatch) {
+  Complex(1, 31, 113, 61);
+  Complex(1, 113, 31, 73);
+}
+// Batched inputs, mixing unaligned and aligned shapes.
+TEST_F(GEMMOpTest, OPENCLUnAlignedWithBatch) {
+  Complex(2, 31, 113, 61);
+  Complex(16, 32, 64, 64);
+  Complex(31, 31, 61, 67);
+}
+TEST_F(GEMMOpTest, OPENCLHalfAlignedWithoutBatch) {
+  Complex(1, 64, 128, 32);
+  Complex(1, 64, 32, 128);
+}
+TEST_F(GEMMOpTest, OPENCLHalfUnAlignedWithBatch) {
+  Complex(2, 31, 113, 61);
+  Complex(16, 32, 64, 64);
+  Complex(31, 31, 61, 67);
+}
+
+}