“24509f4af942bb250564756ad636691c7921e1df”上不存在“paddle/fluid/framework/unused_var_check.h”
提交 83c88e62 编写于 作者: L liuqi

Add Gemm op.

上级 cdd81997
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_GEMM_H_
#define MACE_KERNELS_GEMM_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct GEMMFunctor {
void operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
StatsFuture *future) {
std::vector<index_t> c_shape = {A->dim(0), A->dim(1), 1, B->dim(3)};
C->Resize(c_shape);
const index_t N = C->dim(0);
const index_t height = C->dim(1);
const index_t width = C->dim(3);
const index_t K = A->dim(3);
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
const T *a_ptr_base = A->data<T>();
const T *b_ptr_base = B->data<T>();
T *c_ptr = C->mutable_data<T>();
for (int i = 0; i < N; ++i) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
const T *a_ptr = a_ptr_base + h * K;
const T *b_ptr = b_ptr_base + w;
*c_ptr = 0;
for (int k = 0; k < K; ++k) {
*c_ptr += *a_ptr * *b_ptr;
a_ptr++;
b_ptr += width;
}
c_ptr++;
}
}
a_ptr_base += height * K;
b_ptr_base += K * width;
}
}
};
template <typename T>
struct GEMMFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
StatsFuture *future);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_GEMM_H_
#include <common.h>
// C = A * B
__kernel void gemm(__read_only image2d_t A,
__read_only image2d_t B,
__write_only image2d_t C,
__private const int M,
__private const int height_blocks,
__private const int K) {
const int gx = get_global_id(0);
const int hb = get_global_id(1);
const int batch = hb / height_blocks;
const int gy = (hb % height_blocks) << 2;
const int bm = mul24(batch, M);
const int bk = mul24(batch, K);
float4 a0, a1, a2, a3;
float4 b0, b1, b2, b3;
float4 c0, c1, c2, c3;
for (short pos = 0; pos < K; pos += 4) {
a0 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy)));
a1 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 1)));
a2 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 2)));
a3 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 3)));
b0 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos)));
b1 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 1)));
b2 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 2)));
b3 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 3)));
c0 = mad(a0.x, b0, c0);
c0 = mad(a0.y, b1, c0);
c0 = mad(a0.z, b2, c0);
c0 = mad(a0.w, b3, c0);
c1 = mad(a1.x, b0, c1);
c1 = mad(a1.y, b1, c1);
c1 = mad(a1.z, b2, c1);
c1 = mad(a1.w, b3, c1);
c2 = mad(a2.x, b0, c2);
c2 = mad(a2.y, b1, c2);
c2 = mad(a2.z, b2, c2);
c2 = mad(a2.w, b3, c2);
c3 = mad(a3.x, b0, c3);
c3 = mad(a3.y, b1, c3);
c3 = mad(a3.z, b2, c3);
c3 = mad(a3.w, b3, c3);
}
if (gy >= M) return;
WRITE_IMAGET(C, (int2)(gx, (bm + gy)), c0);
if ((gy + 1) >= M) return;
WRITE_IMAGET(C, (int2)(gx, (bm + gy + 1)), c1);
if ((gy + 2) >= M) return;
WRITE_IMAGET(C, (int2)(gx, (bm + gy + 2)), c2);
if ((gy + 3) >= M) return;
WRITE_IMAGET(C, (int2)(gx, (bm + gy + 3)), c3);
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/gemm.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void GEMMFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *A,
const Tensor *B,
Tensor *C,
StatsFuture *future) {
std::vector<index_t> c_shape = {A->dim(0), A->dim(1), 1, B->dim(3)};
std::vector<size_t> c_image_shape;
CalImage2DShape(c_shape, BufferType::IN_OUT, c_image_shape);
C->ResizeImage(c_shape, c_image_shape);
const index_t batch = C->dim(0);
const index_t height = C->dim(1);
const index_t width = C->dim(3);
const index_t width_blocks = RoundUpDiv4(width);
const index_t height_blocks = RoundUpDiv4(height);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("gemm");
built_options.emplace("-Dgemm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto gemm_kernel = runtime->BuildKernel("gemm", kernel_name, built_options);
uint32_t idx = 0;
gemm_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(A->buffer())));
gemm_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(B->buffer())));
gemm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer())));
gemm_kernel.setArg(idx++, static_cast<int>(height));
gemm_kernel.setArg(idx++, static_cast<int>(height_blocks));
gemm_kernel.setArg(idx++, static_cast<int>(A->dim(3)));
const uint32_t gws[3] = {
static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height_blocks * batch),
};
const std::vector<uint32_t> lws = {16, 64};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(gemm_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(2, 0);
local_ws[0] = std::min<uint32_t>(width_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(height_blocks * batch, kwg_size / local_ws[0]);
return {{local_ws[0], local_ws[1]},
{local_ws[1], local_ws[0]},
{kwg_size / 4, 4},
{kwg_size / 16, 16},
{kwg_size / 32, 32},
{kwg_size / 64, 64},
{kwg_size / 128, 128},
{kwg_size / 256, 256},
{kwg_size / 512, 512},
{kwg_size, 1},
{1, kwg_size}
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
gemm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "gemm_opencl_kernel_"
<< C->dim(0) << "_"
<< C->dim(1) << "_"
<< C->dim(2) << "_"
<< C->dim(3);
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
};
template
struct GEMMFunctor<DeviceType::OPENCL, float>;
template
struct GEMMFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/gemm.h"
namespace mace {
void Register_GEMM(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
GEMMOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
GEMMOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
GEMMOp<DeviceType::OPENCL, half>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_GEMM_H_
#define MACE_OPS_GEMM_H_
#include "mace/core/operator.h"
#include "mace/kernels/gemm.h"
namespace mace {
template <DeviceType D, class T>
class GEMMOp : public Operator<D, T> {
public:
GEMMOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
bool Run(StatsFuture *future) override {
const Tensor *A = this->Input(0);
const Tensor *B = this->Input(1);
Tensor *C = this->Output(0);
MACE_CHECK(A->dim_size() == 4 && 4 == B->dim_size())
<< "The dimension of A and B should be 4";
MACE_CHECK(A->dim(0) == B->dim(0)) << "A and B must have same batch size";
MACE_CHECK(A->dim(3) == B->dim(1))
<< "the number of A's column must be equal to B's row";
functor_(A, B, C, future);
return true;
}
private:
kernels::GEMMFunctor<D, T> functor_;
};
} // namespace mace
#endif // MACE_OPS_GEMM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void GEMMBenchmark(
int iters, int batch, int height, int channels, int out_width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("A", {batch, height, 1, channels});
net.AddRandomInput<D, float>("B", {batch, channels, 1, out_width});
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "A", "AImage",
kernels::BufferType::IN_OUT);
BufferToImage<D, T>(net, "B", "BImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("GEMM", "GEMMBM")
.Input("AImage")
.Input("BImage")
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("GEMM", "GEMMBM")
.Input("A")
.Input("B")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_GEMM_MACRO(N, H, C, W, TYPE, DEVICE) \
static void BM_GEMM_##N##H##C##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
GEMMBenchmark<DEVICE, TYPE>(iters, N, H, C, W); \
} \
BENCHMARK(BM_GEMM_##N##H##C##W##_##TYPE##_##DEVICE)
#define BM_GEMM(N, H, C, W, TYPE) \
BM_GEMM_MACRO(N, H, C, W, TYPE, OPENCL);
BM_GEMM(16, 32, 128, 1024, half);
BM_GEMM(36, 32, 128, 256, half);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <fstream>
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
class GEMMOpTest : public OpsTestBase {};
template<DeviceType D>
void Simple(const std::vector<index_t> &A_shape,
const std::vector<float> &A_value,
const std::vector<index_t> &B_shape,
const std::vector<float> &B_value,
const std::vector<index_t> &C_shape,
const std::vector<float> &C_value) {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("A", A_shape, A_value);
net.AddInputFromArray<D, float>("B", B_shape, B_value);
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "A", "AImage",
kernels::BufferType::IN_OUT);
BufferToImage<D, float>(net, "B", "BImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("GEMM", "GEMMTest")
.Input("AImage")
.Input("BImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("GEMM", "GEMMTest")
.Input("A")
.Input("B")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
// Check
auto expected =
CreateTensor<float>(C_shape, C_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(GEMMOpTest, SimpleCPU) {
Simple<DeviceType::CPU>({1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
{1, 3, 1, 2}, {1, 2, 3, 4, 5, 6},
{1, 2, 1, 2}, {22, 28, 49, 64});
Simple<DeviceType::CPU>({1, 5, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 1, 5},
{215, 230, 245, 260, 275, 490, 530, 570, 610, 650,
765, 830, 895, 960, 1025, 1040, 1130, 1220, 1310, 1400,
1315, 1430, 1545, 1660, 1775});
}
TEST_F(GEMMOpTest, SimpleOPENCL) {
Simple<DeviceType::OPENCL>({1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
{1, 3, 1, 2}, {1, 2, 3, 4, 5, 6},
{1, 2, 1, 2}, {22, 28, 49, 64});
Simple<DeviceType::OPENCL>({1, 5, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 1, 5},
{215, 230, 245, 260, 275, 490, 530, 570, 610, 650,
765, 830, 895, 960, 1025, 1040, 1130, 1220, 1310, 1400,
1315, 1430, 1545, 1660, 1775});
}
template <typename T>
void Complex(const index_t batch,
const index_t height,
const index_t channels,
const index_t out_width) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
OpDefBuilder("GEMM", "GEMMTest")
.Input("A")
.Input("B")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"A", {batch, height, 1, channels});
net.AddRandomInput<DeviceType::OPENCL, float>(
"B", {batch, channels, 1, out_width});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, T>(net, "A", "AImage",
kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, T>(net, "B", "BImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("GEMM", "GEMMTest")
.Input("AImage")
.Input("BImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
net.Sync();
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT);
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-1);
} else {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-4);
}
}
TEST_F(GEMMOpTest, OPENCLAlignedWithoutBatch) {
Complex<float>(1, 64, 128, 32);
Complex<float>(1, 64, 32, 128);
}
TEST_F(GEMMOpTest, OPENCLUnAlignedWithoutBatch) {
Complex<float>(1, 31, 113, 61);
Complex<float>(1, 113, 31, 73);
}
TEST_F(GEMMOpTest, OPENCLUnAlignedWithBatch) {
Complex<float>(2, 31, 113, 61);
Complex<float>(16, 32, 64, 64);
Complex<float>(31, 31, 61, 67);
}
TEST_F(GEMMOpTest, OPENCLHalfAlignedWithoutBatch) {
Complex<half>(1, 64, 128, 32);
Complex<half>(1, 64, 32, 128);
}
TEST_F(GEMMOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(2, 31, 113, 61);
Complex<half>(16, 32, 64, 64);
Complex<half>(31, 31, 61, 67);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册