提交 08533627 编写于 作者: L liuqi

Rebase winograd convolution code and rename gemm to matmul.

上级 bcd624f8
...@@ -77,7 +77,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry); ...@@ -77,7 +77,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_GEMM(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry); extern void Register_WinogradTransform(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
...@@ -100,7 +100,7 @@ OperatorRegistry::OperatorRegistry() { ...@@ -100,7 +100,7 @@ OperatorRegistry::OperatorRegistry() {
Register_ResizeBilinear(this); Register_ResizeBilinear(this);
Register_Softmax(this); Register_Softmax(this);
Register_SpaceToBatchND(this); Register_SpaceToBatchND(this);
Register_GEMM(this); Register_MatMul(this);
Register_WinogradTransform(this); Register_WinogradTransform(this);
Register_WinogradInverseTransform(this); Register_WinogradInverseTransform(this);
} }
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#ifndef MACE_KERNELS_GEMM_H_ #ifndef MACE_KERNELS_MATMUL_H_
#define MACE_KERNELS_GEMM_H_ #define MACE_KERNELS_MATMUL_H_
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
...@@ -13,7 +13,7 @@ namespace kernels { ...@@ -13,7 +13,7 @@ namespace kernels {
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct GEMMFunctor { struct MatMulFunctor {
void operator()(const Tensor *A, void operator()(const Tensor *A,
const Tensor *B, const Tensor *B,
Tensor *C, Tensor *C,
...@@ -53,7 +53,7 @@ struct GEMMFunctor { ...@@ -53,7 +53,7 @@ struct GEMMFunctor {
template <typename T> template <typename T>
struct GEMMFunctor<DeviceType::OPENCL, T> { struct MatMulFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *A, void operator()(const Tensor *A,
const Tensor *B, const Tensor *B,
Tensor *C, Tensor *C,
...@@ -63,4 +63,4 @@ struct GEMMFunctor<DeviceType::OPENCL, T> { ...@@ -63,4 +63,4 @@ struct GEMMFunctor<DeviceType::OPENCL, T> {
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_GEMM_H_ #endif // MACE_KERNELS_MATMUL_H_
...@@ -63,7 +63,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -63,7 +63,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width), static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key = std::string tuning_key =
Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1), Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1),
output->dim(2), output->dim(3)); output->dim(2), output->dim(3));
......
...@@ -48,7 +48,7 @@ static void AddN(const std::vector<const Tensor *> &input_tensors, ...@@ -48,7 +48,7 @@ static void AddN(const std::vector<const Tensor *> &input_tensors,
static_cast<uint32_t>(width_pixels), static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels) static_cast<uint32_t>(batch_height_pixels)
}; };
std::vector<uint32_t> lws = {64, 16, 1}; const std::vector<uint32_t> lws = {64, 16, 1};
std::stringstream ss; std::stringstream ss;
ss << "addn_opencl_kernel_" ss << "addn_opencl_kernel_"
<< output->dim(0) << "_" << output->dim(0) << "_"
......
...@@ -83,7 +83,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -83,7 +83,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width), static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key = std::string tuning_key =
Concat("batch_norm_opencl_kernel_", activation_, output->dim(0), Concat("batch_norm_opencl_kernel_", activation_, output->dim(0),
output->dim(1), output->dim(2), output->dim(3), folded_constant_); output->dim(1), output->dim(2), output->dim(3), folded_constant_);
......
#include <common.h> #include <common.h>
// C = A * B // C = A * B
__kernel void gemm(__read_only image2d_t A, __kernel void matmul(__read_only image2d_t A,
__read_only image2d_t B, __read_only image2d_t B,
__write_only image2d_t C, __write_only image2d_t C,
__private const int M, __private const int M,
__private const int N, __private const int N,
__private const int K, __private const int K,
__private const int height_blocks, __private const int height_blocks,
__private const int k_blocks) { __private const int k_blocks) {
const int gx = get_global_id(0) << 2; const int gx = get_global_id(0) << 2;
const int hb = get_global_id(1); const int hb = get_global_id(1);
const int batch = hb / height_blocks; const int batch = hb / height_blocks;
...@@ -17,7 +17,6 @@ __kernel void gemm(__read_only image2d_t A, ...@@ -17,7 +17,6 @@ __kernel void gemm(__read_only image2d_t A,
const int bm = mad24(batch, M, ty << 2); const int bm = mad24(batch, M, ty << 2);
const int bk = mul24(batch, k_blocks); const int bk = mul24(batch, k_blocks);
float4 a0, a1, a2, a3; float4 a0, a1, a2, a3;
float4 b0, b1, b2, b3; float4 b0, b1, b2, b3;
float4 c0 = 0, c1 = 0, c2 = 0, c3 = 0; float4 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
......
...@@ -50,7 +50,7 @@ static void Concat2(const Tensor *input0, ...@@ -50,7 +50,7 @@ static void Concat2(const Tensor *input0,
static_cast<uint32_t>(width), static_cast<uint32_t>(width),
static_cast<uint32_t>(batch * height), static_cast<uint32_t>(batch * height),
}; };
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss; std::stringstream ss;
ss << "concat_opencl_kernel_" ss << "concat_opencl_kernel_"
<< output->dim(0) << "_" << output->dim(0) << "_"
......
...@@ -96,7 +96,7 @@ void Conv1x1(const Tensor *input, ...@@ -96,7 +96,7 @@ void Conv1x1(const Tensor *input,
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks), static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {8, 15, 8, 1}; const std::vector<uint32_t> lws = {8, 15, 8, 1};
std::string tuning_key = std::string tuning_key =
Concat("conv2d_1x1_opencl_kernel_", activation, output->dim(0), Concat("conv2d_1x1_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3)); output->dim(1), output->dim(2), output->dim(3));
......
...@@ -94,7 +94,7 @@ static void Conv2d3x3S12(const Tensor *input, ...@@ -94,7 +94,7 @@ static void Conv2d3x3S12(const Tensor *input,
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks), static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {4, 15, 8, 1}; const std::vector<uint32_t> lws = {4, 15, 8, 1};
std::string tuning_key = std::string tuning_key =
Concat("conv2d_3x3_opencl_kernel_", activation, output->dim(0), Concat("conv2d_3x3_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3)); output->dim(1), output->dim(2), output->dim(3));
......
...@@ -97,7 +97,7 @@ void Conv2dOpencl(const Tensor *input, ...@@ -97,7 +97,7 @@ void Conv2dOpencl(const Tensor *input,
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks), static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key = std::string tuning_key =
Concat("conv2d_general_opencl_kernel_", activation, output->dim(0), Concat("conv2d_general_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3)); output->dim(1), output->dim(2), output->dim(3));
......
...@@ -106,7 +106,7 @@ void DepthwiseConv2d(const Tensor *input, // NHWC ...@@ -106,7 +106,7 @@ void DepthwiseConv2d(const Tensor *input, // NHWC
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks), static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key = Concat("depthwise_conv2d_ocl_kernel_", activation, std::string tuning_key = Concat("depthwise_conv2d_ocl_kernel_", activation,
batch, height, width, channels, multiplier); batch, height, width, channels, multiplier);
TuningOrRun3DKernel(dw_conv2d_kernel, tuning_key, gws, lws, future); TuningOrRun3DKernel(dw_conv2d_kernel, tuning_key, gws, lws, future);
...@@ -150,7 +150,7 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -150,7 +150,7 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
padding_, output_shape.data(), paddings.data()); padding_, output_shape.data(), paddings.data());
std::vector<size_t> output_image_shape; std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape); output->ResizeImage(output_shape, output_image_shape);
DepthwiseConv2d(input, filter, bias, strides_[0], paddings.data(), dilations_, DepthwiseConv2d(input, filter, bias, strides_[0], paddings.data(), dilations_,
......
...@@ -158,7 +158,7 @@ std::string DtToUpstreamCLCMDDt(const DataType dt) { ...@@ -158,7 +158,7 @@ std::string DtToUpstreamCLCMDDt(const DataType dt) {
void TuningOrRun3DKernel(cl::Kernel &kernel, void TuningOrRun3DKernel(cl::Kernel &kernel,
const std::string tuning_key, const std::string tuning_key,
const uint32_t *gws, const uint32_t *gws,
std::vector<uint32_t> &lws, const std::vector<uint32_t> &lws,
StatsFuture *future) { StatsFuture *future) {
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel);
...@@ -255,7 +255,7 @@ void TuningOrRun3DKernel(cl::Kernel &kernel, ...@@ -255,7 +255,7 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
void TuningOrRun2DKernel(cl::Kernel &kernel, void TuningOrRun2DKernel(cl::Kernel &kernel,
const std::string tuning_key, const std::string tuning_key,
const uint32_t *gws, const uint32_t *gws,
std::vector<uint32_t> &lws, const std::vector<uint32_t> &lws,
StatsFuture *future) { StatsFuture *future) {
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel);
......
...@@ -44,14 +44,14 @@ std::string DtToUpstreamCLDt(const DataType dt); ...@@ -44,14 +44,14 @@ std::string DtToUpstreamCLDt(const DataType dt);
void TuningOrRun3DKernel(cl::Kernel &kernel, void TuningOrRun3DKernel(cl::Kernel &kernel,
const std::string tuning_key, const std::string tuning_key,
const uint32_t *gws, const uint32_t *gws,
std::vector<uint32_t> &lws, const std::vector<uint32_t> &lws,
StatsFuture *future); StatsFuture *future);
void TuningOrRun2DKernel(cl::Kernel &kernel, void TuningOrRun2DKernel(cl::Kernel &kernel,
const std::string tuning_key, const std::string tuning_key,
const uint32_t *gws, const uint32_t *gws,
std::vector<uint32_t> &lws, const std::vector<uint32_t> &lws,
StatsFuture *future); StatsFuture *future);
inline void SetFuture(StatsFuture *future, const cl::Event &event) { inline void SetFuture(StatsFuture *future, const cl::Event &event) {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/kernels/gemm.h" #include "mace/kernels/matmul.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h" #include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h" #include "mace/utils/tuner.h"
...@@ -11,7 +11,7 @@ namespace mace { ...@@ -11,7 +11,7 @@ namespace mace {
namespace kernels { namespace kernels {
template <typename T> template <typename T>
void GEMMFunctor<DeviceType::OPENCL, T>::operator()( void MatMulFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *A, const Tensor *A,
const Tensor *B, const Tensor *B,
Tensor *C, Tensor *C,
...@@ -32,87 +32,44 @@ void GEMMFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -32,87 +32,44 @@ void GEMMFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options; std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value; auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("gemm"); std::string kernel_name = MACE_OBFUSCATE_SYMBOL("matmul");
built_options.emplace("-Dgemm=" + kernel_name); built_options.emplace("-Dmatmul=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto gemm_kernel = runtime->BuildKernel("gemm", kernel_name, built_options); auto matmul_kernel = runtime->BuildKernel("matmul", kernel_name, built_options);
uint32_t idx = 0; uint32_t idx = 0;
gemm_kernel.setArg(idx++, matmul_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(A->buffer()))); *(static_cast<const cl::Image2D *>(A->buffer())));
gemm_kernel.setArg(idx++, matmul_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(B->buffer()))); *(static_cast<const cl::Image2D *>(B->buffer())));
gemm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer()))); matmul_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer())));
gemm_kernel.setArg(idx++, static_cast<int>(height)); matmul_kernel.setArg(idx++, static_cast<int>(height));
gemm_kernel.setArg(idx++, static_cast<int>(width)); matmul_kernel.setArg(idx++, static_cast<int>(width));
gemm_kernel.setArg(idx++, static_cast<int>(A->dim(2))); matmul_kernel.setArg(idx++, static_cast<int>(A->dim(2)));
gemm_kernel.setArg(idx++, static_cast<int>(height_blocks)); matmul_kernel.setArg(idx++, static_cast<int>(height_blocks));
gemm_kernel.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2)))); matmul_kernel.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2))));
const uint32_t gws[3] = { const uint32_t gws[2] = {
static_cast<uint32_t>(width_blocks), static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height_blocks * batch), static_cast<uint32_t>(height_blocks * batch),
}; };
const std::vector<uint32_t> lws = {16, 64}; const std::vector<uint32_t> lws = {16, 64, 1};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(gemm_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(2, 0);
local_ws[0] = std::min<uint32_t>(width_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(height_blocks * batch, kwg_size / local_ws[0]);
return {{local_ws[0], local_ws[1]},
{local_ws[1], local_ws[0]},
{kwg_size / 4, 4},
{kwg_size / 8, 8},
{kwg_size / 16, 16},
{kwg_size / 32, 32},
{kwg_size / 64, 64},
{kwg_size / 128, 128},
{kwg_size / 256, 256},
{kwg_size / 512, 512},
{kwg_size, 1},
{1, kwg_size}
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
gemm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss; std::stringstream ss;
ss << "gemm_opencl_kernel_" ss << "matmul_opencl_kernel_"
<< C->dim(0) << "_" << C->dim(0) << "_"
<< C->dim(1) << "_" << C->dim(1) << "_"
<< C->dim(2) << "_" << C->dim(2) << "_"
<< C->dim(3); << C->dim(3);
OpenCLProfilingTimer timer(&event); TuningOrRun2DKernel(matmul_kernel, ss.str(), gws, lws, future);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
}; };
template template
struct GEMMFunctor<DeviceType::OPENCL, float>; struct MatMulFunctor<DeviceType::OPENCL, float>;
template template
struct GEMMFunctor<DeviceType::OPENCL, half>; struct MatMulFunctor<DeviceType::OPENCL, half>;
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -59,7 +59,7 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -59,7 +59,7 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(out_width), static_cast<uint32_t>(out_width),
static_cast<uint32_t>(out_height * batch)}; static_cast<uint32_t>(out_height * batch)};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss; std::stringstream ss;
ss << "resize_bilinear_opencl_kernel_" ss << "resize_bilinear_opencl_kernel_"
<< output->dim(0) << "_" << output->dim(0) << "_"
......
...@@ -41,7 +41,7 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits, ...@@ -41,7 +41,7 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width), static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss; std::stringstream ss;
ss << "softmax_opencl_kernel_" ss << "softmax_opencl_kernel_"
<< output->dim(0) << "_" << output->dim(0) << "_"
......
...@@ -61,7 +61,7 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor ...@@ -61,7 +61,7 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
const uint32_t gws[3] = {chan_blk, const uint32_t gws[3] = {chan_blk,
static_cast<uint32_t>(batch_tensor->dim(2)), static_cast<uint32_t>(batch_tensor->dim(2)),
static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))}; static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))};
std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss; std::stringstream ss;
ss << kernel_name << "_" ss << kernel_name << "_"
<< batch_tensor->dim(0) << "_" << batch_tensor->dim(0) << "_"
......
...@@ -51,60 +51,16 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i ...@@ -51,60 +51,16 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2)); wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2));
wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2)); wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));
const size_t gws[2] = {static_cast<size_t>(out_width), const uint32_t gws[2] = {static_cast<size_t>(out_width),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(3)))}; static_cast<size_t>(RoundUpDiv4(input_tensor->dim(3)))};
const std::vector<uint32_t> lws = {128, 8}; const std::vector<uint32_t> lws = {128, 8, 1};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(wino_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(2, 0);
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
return {{local_ws[0], local_ws[1]},
{local_ws[1], local_ws[0]},
{kwg_size / 4, 4},
{kwg_size / 8, 8},
{kwg_size / 16, 16},
{kwg_size / 32, 32},
{kwg_size / 64, 64},
{kwg_size / 128, 128},
{kwg_size / 256, 256},
{kwg_size / 512, 512},
{kwg_size, 1},
{1, kwg_size}
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
wino_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss; std::stringstream ss;
ss << "winograd_transform_kernel_" ss << "winograd_transform_kernel_"
<< input_tensor->dim(0) << "_" << input_tensor->dim(0) << "_"
<< input_tensor->dim(1) << "_" << input_tensor->dim(1) << "_"
<< input_tensor->dim(2) << "_" << input_tensor->dim(2) << "_"
<< input_tensor->dim(3); << input_tensor->dim(3);
OpenCLProfilingTimer timer(&event); TuningOrRun2DKernel(wino_kernel, ss.str(), gws, lws, future);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
} }
template<typename T> template<typename T>
...@@ -165,60 +121,17 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te ...@@ -165,60 +121,17 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
wino_kernel.setArg(idx++, relux_max_limit_); wino_kernel.setArg(idx++, relux_max_limit_);
wino_kernel.setArg(idx++, prelu_alpha_); wino_kernel.setArg(idx++, prelu_alpha_);
const size_t gws[2] = {static_cast<size_t>(input_tensor->dim(2)), const uint32_t gws[2] = {static_cast<size_t>(input_tensor->dim(2)),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(1)))}; static_cast<size_t>(RoundUpDiv4(input_tensor->dim(1)))};
const std::vector<uint32_t> lws = {128, 8}; const std::vector<uint32_t> lws = {128, 8, 1};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(wino_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(2, 0);
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
return {{local_ws[0], local_ws[1]},
{local_ws[1], local_ws[0]},
{kwg_size / 4, 4},
{kwg_size / 8, 8},
{kwg_size / 16, 16},
{kwg_size / 32, 32},
{kwg_size / 64, 64},
{kwg_size / 128, 128},
{kwg_size / 256, 256},
{kwg_size / 512, 512},
{kwg_size, 1},
{1, kwg_size}
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
wino_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss; std::stringstream ss;
ss << "winograd_inverse_transform_kernel_" ss << "winograd_inverse_transform_kernel_"
<< input_tensor->dim(0) << "_" << input_tensor->dim(0) << "_"
<< input_tensor->dim(1) << "_" << input_tensor->dim(1) << "_"
<< input_tensor->dim(2) << "_" << input_tensor->dim(2) << "_"
<< input_tensor->dim(3); << input_tensor->dim(3);
OpenCLProfilingTimer timer(&event); TuningOrRun2DKernel(wino_kernel, ss.str(), gws, lws, future);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
} }
template template
......
...@@ -26,7 +26,7 @@ void SimpleValidTest() { ...@@ -26,7 +26,7 @@ void SimpleValidTest() {
net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f}); net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage", BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", BufferToImage<D, T>(net, "Filter", "FilterImage",
kernels::BufferType::DW_CONV2D_FILTER); kernels::BufferType::DW_CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", BufferToImage<D, T>(net, "Bias", "BiasImage",
...@@ -46,7 +46,7 @@ void SimpleValidTest() { ...@@ -46,7 +46,7 @@ void SimpleValidTest() {
// Transfer output // Transfer output
ImageToBuffer<D, T>(net, "OutputImage", "Output", ImageToBuffer<D, T>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
...@@ -129,7 +129,7 @@ void ComplexValidTest() { ...@@ -129,7 +129,7 @@ void ComplexValidTest() {
{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage", BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", BufferToImage<D, T>(net, "Filter", "FilterImage",
kernels::BufferType::DW_CONV2D_FILTER); kernels::BufferType::DW_CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", BufferToImage<D, T>(net, "Bias", "BiasImage",
...@@ -149,7 +149,7 @@ void ComplexValidTest() { ...@@ -149,7 +149,7 @@ void ComplexValidTest() {
// Transfer output // Transfer output
ImageToBuffer<D, T>(net, "OutputImage", "Output", ImageToBuffer<D, T>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
...@@ -239,7 +239,7 @@ void TestNxNS12(const index_t height, const index_t width) { ...@@ -239,7 +239,7 @@ void TestNxNS12(const index_t height, const index_t width) {
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage", BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", BufferToImage<D, T>(net, "Filter", "FilterImage",
kernels::BufferType::DW_CONV2D_FILTER); kernels::BufferType::DW_CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", BufferToImage<D, T>(net, "Bias", "BiasImage",
...@@ -259,7 +259,7 @@ void TestNxNS12(const index_t height, const index_t width) { ...@@ -259,7 +259,7 @@ void TestNxNS12(const index_t height, const index_t width) {
// Transfer output // Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "DeviceOutput", ImageToBuffer<D, float>(net, "OutputImage", "DeviceOutput",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
......
...@@ -34,7 +34,7 @@ static void DepthwiseConv2d(int iters, ...@@ -34,7 +34,7 @@ static void DepthwiseConv2d(int iters,
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage", BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", BufferToImage<D, T>(net, "Filter", "FilterImage",
kernels::BufferType::DW_CONV2D_FILTER); kernels::BufferType::DW_CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", BufferToImage<D, T>(net, "Bias", "BiasImage",
......
...@@ -2,28 +2,28 @@ ...@@ -2,28 +2,28 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/ops/gemm.h" #include "mace/ops/matmul.h"
namespace mace { namespace mace {
void Register_GEMM(OperatorRegistry *op_registry) { void Register_MatMul(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM") REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::CPU) .Device(DeviceType::CPU)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
GEMMOp<DeviceType::CPU, float>); MatMulOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM") REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
GEMMOp<DeviceType::OPENCL, float>); MatMulOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GEMM") REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
GEMMOp<DeviceType::OPENCL, half>); MatMulOp<DeviceType::OPENCL, half>);
} }
} // namespace mace } // namespace mace
...@@ -2,18 +2,18 @@ ...@@ -2,18 +2,18 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#ifndef MACE_OPS_GEMM_H_ #ifndef MACE_OPS_MATMUL_H_
#define MACE_OPS_GEMM_H_ #define MACE_OPS_MATMUL_H_
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/gemm.h" #include "mace/kernels/matmul.h"
namespace mace { namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class GEMMOp : public Operator<D, T> { class MatMulOp : public Operator<D, T> {
public: public:
GEMMOp(const OperatorDef &operator_def, Workspace *ws) MatMulOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {} : Operator<D, T>(operator_def, ws) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
...@@ -32,9 +32,9 @@ class GEMMOp : public Operator<D, T> { ...@@ -32,9 +32,9 @@ class GEMMOp : public Operator<D, T> {
} }
private: private:
kernels::GEMMFunctor<D, T> functor_; kernels::MatMulFunctor<D, T> functor_;
}; };
} // namespace mace } // namespace mace
#endif // MACE_OPS_GEMM_H_ #endif // MACE_OPS_MATMUL_H_
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace mace { namespace mace {
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void GEMMBenchmark( static void MatMulBenchmark(
int iters, int batch, int height, int channels, int out_width) { int iters, int batch, int height, int channels, int out_width) {
mace::testing::StopTiming(); mace::testing::StopTiming();
...@@ -25,14 +25,14 @@ static void GEMMBenchmark( ...@@ -25,14 +25,14 @@ static void GEMMBenchmark(
BufferToImage<D, T>(net, "B", "BImage", BufferToImage<D, T>(net, "B", "BImage",
kernels::BufferType::IN_OUT_HEIGHT); kernels::BufferType::IN_OUT_HEIGHT);
OpDefBuilder("GEMM", "GEMMBM") OpDefBuilder("MatMul", "MatMulBM")
.Input("AImage") .Input("AImage")
.Input("BImage") .Input("BImage")
.Output("Output") .Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else { } else {
OpDefBuilder("GEMM", "GEMMBM") OpDefBuilder("MatMul", "MatMulBM")
.Input("A") .Input("A")
.Input("B") .Input("B")
.Output("Output") .Output("Output")
...@@ -52,19 +52,19 @@ static void GEMMBenchmark( ...@@ -52,19 +52,19 @@ static void GEMMBenchmark(
net.Sync(); net.Sync();
} }
#define BM_GEMM_MACRO(N, H, C, W, TYPE, DEVICE) \ #define BM_MATMUL_MACRO(N, H, C, W, TYPE, DEVICE) \
static void BM_GEMM_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE(int iters) { \ static void BM_MATMUL_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
GEMMBenchmark<DEVICE, TYPE>(iters, N, H, C, W); \ MatMulBenchmark<DEVICE, TYPE>(iters, N, H, C, W); \
} \ } \
BENCHMARK(BM_GEMM_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE) BENCHMARK(BM_MATMUL_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE)
#define BM_GEMM(N, H, C, W, TYPE) \ #define BM_MATMUL(N, H, C, W, TYPE) \
BM_GEMM_MACRO(N, H, C, W, TYPE, OPENCL); BM_MATMUL_MACRO(N, H, C, W, TYPE, OPENCL);
BM_GEMM(16, 32, 128, 49, half); BM_MATMUL(16, 32, 128, 49, half);
BM_GEMM(16, 32, 128, 961, half); BM_MATMUL(16, 32, 128, 961, half);
BM_GEMM(16, 32, 128, 3969, half); BM_MATMUL(16, 32, 128, 3969, half);
} // namespace mace } // namespace mace
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace mace { namespace mace {
class GEMMOpTest : public OpsTestBase {}; class MatMulOpTest : public OpsTestBase {};
template<DeviceType D> template<DeviceType D>
void Simple(const std::vector<index_t> &A_shape, void Simple(const std::vector<index_t> &A_shape,
...@@ -29,7 +29,7 @@ void Simple(const std::vector<index_t> &A_shape, ...@@ -29,7 +29,7 @@ void Simple(const std::vector<index_t> &A_shape,
BufferToImage<D, float>(net, "B", "BImage", BufferToImage<D, float>(net, "B", "BImage",
kernels::BufferType::IN_OUT_HEIGHT); kernels::BufferType::IN_OUT_HEIGHT);
OpDefBuilder("GEMM", "GEMMTest") OpDefBuilder("MatMul", "MatMulTest")
.Input("AImage") .Input("AImage")
.Input("BImage") .Input("BImage")
.Output("OutputImage") .Output("OutputImage")
...@@ -41,7 +41,7 @@ void Simple(const std::vector<index_t> &A_shape, ...@@ -41,7 +41,7 @@ void Simple(const std::vector<index_t> &A_shape,
ImageToBuffer<D, float>(net, "OutputImage", "Output", ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_HEIGHT); kernels::BufferType::IN_OUT_HEIGHT);
} else { } else {
OpDefBuilder("GEMM", "GEMMTest") OpDefBuilder("MatMul", "MatMulTest")
.Input("A") .Input("A")
.Input("B") .Input("B")
.Output("Output") .Output("Output")
...@@ -57,7 +57,7 @@ void Simple(const std::vector<index_t> &A_shape, ...@@ -57,7 +57,7 @@ void Simple(const std::vector<index_t> &A_shape,
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} }
TEST_F(GEMMOpTest, SimpleCPU) { TEST_F(MatMulOpTest, SimpleCPU) {
Simple<DeviceType::CPU>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, Simple<DeviceType::CPU>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6},
{1, 3, 2, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1}, {1, 2, 3, 4, 5, 6},
{1, 2, 2, 1}, {22, 28, 49, 64}); {1, 2, 2, 1}, {22, 28, 49, 64});
...@@ -74,13 +74,13 @@ TEST_F(GEMMOpTest, SimpleCPU) { ...@@ -74,13 +74,13 @@ TEST_F(GEMMOpTest, SimpleCPU) {
} }
TEST_F(GEMMOpTest, SimpleCPUWithBatch) { TEST_F(MatMulOpTest, SimpleCPUWithBatch) {
Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64}); {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64});
} }
TEST_F(GEMMOpTest, SimpleOPENCL) { TEST_F(MatMulOpTest, SimpleOPENCL) {
Simple<DeviceType::OPENCL>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, Simple<DeviceType::OPENCL>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6},
{1, 3, 2, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1}, {1, 2, 3, 4, 5, 6},
{1, 2, 2, 1}, {22, 28, 49, 64}); {1, 2, 2, 1}, {22, 28, 49, 64});
...@@ -96,7 +96,7 @@ TEST_F(GEMMOpTest, SimpleOPENCL) { ...@@ -96,7 +96,7 @@ TEST_F(GEMMOpTest, SimpleOPENCL) {
1315, 1430, 1545, 1660, 1775}); 1315, 1430, 1545, 1660, 1775});
} }
TEST_F(GEMMOpTest, SimpleGPUWithBatch) { TEST_F(MatMulOpTest, SimpleGPUWithBatch) {
Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64}); {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64});
...@@ -111,7 +111,7 @@ void Complex(const index_t batch, ...@@ -111,7 +111,7 @@ void Complex(const index_t batch,
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("GEMM", "GEMMTest") OpDefBuilder("MatMul", "MatMulTest")
.Input("A") .Input("A")
.Input("B") .Input("B")
.Output("Output") .Output("Output")
...@@ -136,7 +136,7 @@ void Complex(const index_t batch, ...@@ -136,7 +136,7 @@ void Complex(const index_t batch,
BufferToImage<DeviceType::OPENCL, T>(net, "B", "BImage", BufferToImage<DeviceType::OPENCL, T>(net, "B", "BImage",
kernels::BufferType::IN_OUT_HEIGHT); kernels::BufferType::IN_OUT_HEIGHT);
OpDefBuilder("GEMM", "GEMMTest") OpDefBuilder("MatMul", "MatMulTest")
.Input("AImage") .Input("AImage")
.Input("BImage") .Input("BImage")
.Output("OutputImage") .Output("OutputImage")
...@@ -155,24 +155,24 @@ void Complex(const index_t batch, ...@@ -155,24 +155,24 @@ void Complex(const index_t batch,
} }
} }
TEST_F(GEMMOpTest, OPENCLAlignedWithoutBatch) { TEST_F(MatMulOpTest, OPENCLAlignedWithoutBatch) {
Complex<float>(1, 64, 128, 32); Complex<float>(1, 64, 128, 32);
Complex<float>(1, 64, 32, 128); Complex<float>(1, 64, 32, 128);
} }
TEST_F(GEMMOpTest, OPENCLUnAlignedWithoutBatch) { TEST_F(MatMulOpTest, OPENCLUnAlignedWithoutBatch) {
Complex<float>(1, 31, 113, 61); Complex<float>(1, 31, 113, 61);
Complex<float>(1, 113, 31, 73); Complex<float>(1, 113, 31, 73);
} }
TEST_F(GEMMOpTest, OPENCLUnAlignedWithBatch) { TEST_F(MatMulOpTest, OPENCLUnAlignedWithBatch) {
Complex<float>(2, 3, 3, 3); Complex<float>(2, 3, 3, 3);
Complex<float>(16, 31, 61, 67); Complex<float>(16, 31, 61, 67);
Complex<float>(31, 31, 61, 67); Complex<float>(31, 31, 61, 67);
} }
TEST_F(GEMMOpTest, OPENCLHalfAlignedWithoutBatch) { TEST_F(MatMulOpTest, OPENCLHalfAlignedWithoutBatch) {
Complex<half>(1, 64, 128, 32); Complex<half>(1, 64, 128, 32);
Complex<half>(1, 64, 32, 128); Complex<half>(1, 64, 32, 128);
} }
TEST_F(GEMMOpTest, OPENCLHalfUnAlignedWithBatch) { TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(2, 31, 113, 61); Complex<half>(2, 31, 113, 61);
Complex<half>(16, 32, 64, 64); Complex<half>(16, 32, 64, 64);
Complex<half>(31, 31, 61, 67); Complex<half>(31, 31, 61, 67);
......
...@@ -52,7 +52,7 @@ void WinogradConvolution(const index_t batch, ...@@ -52,7 +52,7 @@ void WinogradConvolution(const index_t batch,
BufferToImage<D, T>(net, "Input", "InputImage", BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", BufferToImage<D, T>(net, "Filter", "FilterImage",
kernels::BufferType::FILTER); kernels::BufferType::CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", BufferToImage<D, T>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT); kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -92,8 +92,8 @@ void WinogradConvolution(const index_t batch, ...@@ -92,8 +92,8 @@ void WinogradConvolution(const index_t batch,
// Run on opencl // Run on opencl
net.RunOp(D); net.RunOp(D);
// GEMM // MatMul
OpDefBuilder("GEMM", "GEMMTest") OpDefBuilder("MatMul", "MatMulTest")
.Input("WinoFilter") .Input("WinoFilter")
.Input("WinoInput") .Input("WinoInput")
.Output("WinoGemm") .Output("WinoGemm")
......
...@@ -41,7 +41,7 @@ class Tuner { ...@@ -41,7 +41,7 @@ class Tuner {
template <typename RetType> template <typename RetType>
RetType TuneOrRun( RetType TuneOrRun(
const std::string param_key, const std::string param_key,
std::vector<param_type> &default_param, const std::vector<param_type> &default_param,
const std::function<std::vector<std::vector<param_type>>()> const std::function<std::vector<std::vector<param_type>>()>
&param_generator, &param_generator,
const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func, const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册