提交 a310ca31 编写于 作者: 李寅

Refactor tiling code

上级 56742d51
......@@ -197,6 +197,56 @@ struct Conv2dFunctorBase {
const float relux_max_limit_;
};
#define MACE_DO_CONV2D(CC, CH, CW) \
Conv2dKernelFunc<T, inc_tile_size, CC, CH, CW>( \
input_ptr, filter_data, bias_data, output_ptr, \
h_offset, w_offset, c_offset, kernel_h, kernel_w, \
stride_h, stride_w, dilation_h, dilation_w, \
channels, input_channels, width, padded_width);
#define MACE_CASE_W_CONV2D(CC, CH) \
switch (w_count) { \
case 1: \
MACE_DO_CONV2D(CC, CH, 1); \
break; \
case 2: \
MACE_DO_CONV2D(CC, CH, 2); \
break; \
default: \
LOG(FATAL) << "Unsupported w tile: " << w_count; \
}
#define MACE_CASE_H_CONV2D(CC) \
switch (h_count) { \
case 1: \
MACE_CASE_W_CONV2D(CC, 1); \
break; \
case 2: \
MACE_CASE_W_CONV2D(CC, 2); \
break; \
default: \
LOG(FATAL) << "Unsupported h tile: " << h_count; \
}
#define MACE_CASE_C_CONV2D \
switch (c_count) { \
case 1: \
MACE_CASE_H_CONV2D(1); \
break; \
case 2: \
MACE_CASE_H_CONV2D(2); \
break; \
case 3: \
MACE_CASE_H_CONV2D(3); \
break; \
case 4: \
MACE_CASE_H_CONV2D(4); \
break; \
default: \
LOG(FATAL) << "Unsupported c tile: " << c_count; \
}
template <DeviceType D, typename T>
struct Conv2dFunctor : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
......@@ -312,306 +362,7 @@ struct Conv2dFunctor : Conv2dFunctorBase {
const int w_count = std::min(w_tile_size, width - w_offset);
const int c_count = std::min(c_tile_size, channels - c_offset);
switch (c_count) {
case 1:
switch (h_count) {
case 1:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 1, 1, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 1, 1, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 1, 1, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 1, 1, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
case 2:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 1, 2, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 1, 2, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 1, 2, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 1, 2, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
default:
LOG(FATAL) << "Unsupported height tile: " << h_count;
}
break;
case 2:
switch (h_count) {
case 1:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 2, 1, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 2, 1, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 2, 1, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 2, 1, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
case 2:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 2, 2, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 2, 2, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 2, 2, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 2, 2, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
default:
LOG(FATAL) << "Unsupported height tile: " << h_count;
}
break;
case 3:
switch (h_count) {
case 1:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 3, 1, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 3, 1, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 3, 1, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 3, 1, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
case 2:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 3, 2, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 3, 2, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 3, 2, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 3, 2, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
default:
LOG(FATAL) << "Unsupported height tile: " << h_count;
}
break;
case 4:
switch (h_count) {
case 1:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 4, 1, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 4, 1, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 4, 1, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 4, 1, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
case 2:
switch (w_count) {
case 1:
Conv2dKernelFunc<T, inc_tile_size, 4, 2, 1>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 2:
Conv2dKernelFunc<T, inc_tile_size, 4, 2, 2>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 3:
Conv2dKernelFunc<T, inc_tile_size, 4, 2, 3>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
case 4:
Conv2dKernelFunc<T, inc_tile_size, 4, 2, 4>(
input_ptr, filter_data, bias_data, output_ptr,
h_offset, w_offset, c_offset, kernel_h, kernel_w,
stride_h, stride_w, dilation_h, dilation_w,
channels, input_channels, width, padded_width);
break;
default:
LOG(FATAL) << "Unsupported width tile: " << w_count;
}
break;
default:
LOG(FATAL) << "Unsupported height tile: " << h_count;
}
break;
default:
LOG(FATAL) << "Unsupported channel tile: " << c_count;
}
MACE_CASE_C_CONV2D;
}
}
}
......
......@@ -89,42 +89,71 @@ inline void MatMulKernelFunc(const T *A,
}
} // namespace
#define CASE_K_MATMUL(HC, WC, KC) \
case KC: \
MatMulKernelFunc<T, register_tile_size, HC, WC, KC>(a_ptr_batch_base, \
b_ptr_batch_base, \
c_ptr_batch_base, \
ih, \
iw, \
ik, \
height, \
width, \
K); \
break;
#define CASE_W_MATMUL(HC, WC) \
case WC: \
switch (k_count) { \
CASE_K_MATMUL(HC, WC, 1); \
CASE_K_MATMUL(HC, WC, 2); \
CASE_K_MATMUL(HC, WC, 3); \
CASE_K_MATMUL(HC, WC, 4); \
default: \
LOG(FATAL) << "Unsupported k tile: " << k_count; \
} \
break;
#define CASE_H_MATMUL(HC) \
case HC: \
switch (w_count) { \
CASE_W_MATMUL(HC, 1); \
CASE_W_MATMUL(HC, 2); \
CASE_W_MATMUL(HC, 3); \
CASE_W_MATMUL(HC, 4); \
default: \
LOG(FATAL) << "Unsupported w tile: " << k_count; \
} \
break;
#define MACE_DO_MATMUL(HC, WC, KC) \
MatMulKernelFunc<T, register_tile_size, HC, WC, KC>(a_ptr_batch_base, \
b_ptr_batch_base, \
c_ptr_batch_base, \
ih, \
iw, \
ik, \
height, \
width, \
K);
#define MACE_CASE_K_MATMUL(HC, WC) \
switch (k_count) { \
case 1: \
MACE_DO_MATMUL(HC, WC, 1); \
break; \
case 2: \
MACE_DO_MATMUL(HC, WC, 2); \
break; \
case 3: \
MACE_DO_MATMUL(HC, WC, 3); \
break; \
case 4: \
MACE_DO_MATMUL(HC, WC, 4); \
break; \
default: \
LOG(FATAL) << "Unsupported k tile: " << k_count; \
}
#define MACE_CASE_W_MATMUL(HC) \
switch (w_count) { \
case 1: \
MACE_CASE_K_MATMUL(HC, 1); \
break; \
case 2: \
MACE_CASE_K_MATMUL(HC, 2); \
break; \
case 3: \
MACE_CASE_K_MATMUL(HC, 3); \
break; \
case 4: \
MACE_CASE_K_MATMUL(HC, 4); \
break; \
default: \
LOG(FATAL) << "Unsupported w tile: " << w_count; \
}
#define MACE_CASE_H_MATMUL \
switch (h_count) { \
case 1: \
MACE_CASE_W_MATMUL(1); \
break; \
case 2: \
MACE_CASE_W_MATMUL(2); \
break; \
case 3: \
MACE_CASE_W_MATMUL(3); \
break; \
case 4: \
MACE_CASE_W_MATMUL(4); \
break; \
default: \
LOG(FATAL) << "Unsupported h tile: " << h_count; \
}
template<DeviceType D, typename T>
struct MatMulFunctor {
......@@ -196,14 +225,7 @@ struct MatMulFunctor {
const int w_count = std::min(register_tile_size, iw_end - iw);
const int k_count = std::min(register_tile_size, ik_end - ik);
switch (h_count) {
CASE_H_MATMUL(1);
CASE_H_MATMUL(2);
CASE_H_MATMUL(3);
CASE_H_MATMUL(4);
default:LOG(FATAL) << "Unsupported height tile: "
<< h_count;
}
MACE_CASE_H_MATMUL;
} // ik
} // iw
} // ih
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册