diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 99a2eaa32a9cded20bef97f2428dba25bd5dc096..a4a24eedae4b7c92f2ff1b1841c78b6f99d566bc 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -197,6 +197,56 @@ struct Conv2dFunctorBase { const float relux_max_limit_; }; +#define MACE_DO_CONV2D(CC, CH, CW) \ +Conv2dKernelFunc( \ + input_ptr, filter_data, bias_data, output_ptr, \ + h_offset, w_offset, c_offset, kernel_h, kernel_w, \ + stride_h, stride_w, dilation_h, dilation_w, \ + channels, input_channels, width, padded_width); + +#define MACE_CASE_W_CONV2D(CC, CH) \ +switch (w_count) { \ + case 1: \ + MACE_DO_CONV2D(CC, CH, 1); \ + break; \ + case 2: \ + MACE_DO_CONV2D(CC, CH, 2); \ + break; \ + default: \ + LOG(FATAL) << "Unsupported w tile: " << w_count; \ +} + +#define MACE_CASE_H_CONV2D(CC) \ +switch (h_count) { \ + case 1: \ + MACE_CASE_W_CONV2D(CC, 1); \ + break; \ + case 2: \ + MACE_CASE_W_CONV2D(CC, 2); \ + break; \ + default: \ + LOG(FATAL) << "Unsupported h tile: " << h_count; \ +} + +#define MACE_CASE_C_CONV2D \ +switch (c_count) { \ + case 1: \ + MACE_CASE_H_CONV2D(1); \ + break; \ + case 2: \ + MACE_CASE_H_CONV2D(2); \ + break; \ + case 3: \ + MACE_CASE_H_CONV2D(3); \ + break; \ + case 4: \ + MACE_CASE_H_CONV2D(4); \ + break; \ + default: \ + LOG(FATAL) << "Unsupported c tile: " << c_count; \ +} + + template struct Conv2dFunctor : Conv2dFunctorBase { Conv2dFunctor(const int *strides, @@ -312,306 +362,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { const int w_count = std::min(w_tile_size, width - w_offset); const int c_count = std::min(c_tile_size, channels - c_offset); - switch (c_count) { - case 1: - switch (h_count) { - case 1: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - case 2: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - default: - LOG(FATAL) << "Unsupported height tile: " << h_count; - } - break; - case 2: - switch (h_count) { - case 1: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - case 2: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - default: - LOG(FATAL) << "Unsupported height tile: " << h_count; - } - break; - case 3: - switch (h_count) { - case 1: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - case 2: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - default: - LOG(FATAL) << "Unsupported height tile: " << h_count; - } - break; - case 4: - switch (h_count) { - case 1: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - case 2: - switch (w_count) { - case 1: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 2: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 3: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - case 4: - Conv2dKernelFunc( - input_ptr, filter_data, bias_data, output_ptr, - h_offset, w_offset, c_offset, kernel_h, kernel_w, - stride_h, stride_w, dilation_h, dilation_w, - channels, input_channels, width, padded_width); - break; - default: - LOG(FATAL) << "Unsupported width tile: " << w_count; - } - break; - default: - LOG(FATAL) << "Unsupported height tile: " << h_count; - } - break; - default: - LOG(FATAL) << "Unsupported channel tile: " << c_count; - } + MACE_CASE_C_CONV2D; } } } diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h index 7610c0dadd50f80da2b6bf297f86eb0e4c529577..88452bfe83dde8c0d05e1ff61a55410863c1b31a 100644 --- a/mace/kernels/matmul.h +++ b/mace/kernels/matmul.h @@ -89,42 +89,71 @@ inline void MatMulKernelFunc(const T *A, } } // namespace -#define CASE_K_MATMUL(HC, WC, KC) \ - case KC: \ - MatMulKernelFunc(a_ptr_batch_base, \ - b_ptr_batch_base, \ - c_ptr_batch_base, \ - ih, \ - iw, \ - ik, \ - height, \ - width, \ - K); \ - break; - -#define CASE_W_MATMUL(HC, WC) \ - case WC: \ - switch (k_count) { \ - CASE_K_MATMUL(HC, WC, 1); \ - CASE_K_MATMUL(HC, WC, 2); \ - CASE_K_MATMUL(HC, WC, 3); \ - CASE_K_MATMUL(HC, WC, 4); \ - default: \ - LOG(FATAL) << "Unsupported k tile: " << k_count; \ - } \ - break; - -#define CASE_H_MATMUL(HC) \ - case HC: \ - switch (w_count) { \ - CASE_W_MATMUL(HC, 1); \ - CASE_W_MATMUL(HC, 2); \ - CASE_W_MATMUL(HC, 3); \ - CASE_W_MATMUL(HC, 4); \ - default: \ - LOG(FATAL) << "Unsupported w tile: " << k_count; \ - } \ - break; +#define MACE_DO_MATMUL(HC, WC, KC) \ +MatMulKernelFunc(a_ptr_batch_base, \ + b_ptr_batch_base, \ + c_ptr_batch_base, \ + ih, \ + iw, \ + ik, \ + height, \ + width, \ + K); + +#define MACE_CASE_K_MATMUL(HC, WC) \ +switch (k_count) { \ + case 1: \ + MACE_DO_MATMUL(HC, WC, 1); \ + break; \ + case 2: \ + MACE_DO_MATMUL(HC, WC, 2); \ + break; \ + case 3: \ + MACE_DO_MATMUL(HC, WC, 3); \ + break; \ + case 4: \ + MACE_DO_MATMUL(HC, WC, 4); \ + break; \ + default: \ + LOG(FATAL) << "Unsupported k tile: " << k_count; \ +} + + +#define MACE_CASE_W_MATMUL(HC) \ +switch (w_count) { \ + case 1: \ + MACE_CASE_K_MATMUL(HC, 1); \ + break; \ + case 2: \ + MACE_CASE_K_MATMUL(HC, 2); \ + break; \ + case 3: \ + MACE_CASE_K_MATMUL(HC, 3); \ + break; \ + case 4: \ + MACE_CASE_K_MATMUL(HC, 4); \ + break; \ + default: \ + LOG(FATAL) << "Unsupported w tile: " << w_count; \ +} + +#define MACE_CASE_H_MATMUL \ +switch (h_count) { \ + case 1: \ + MACE_CASE_W_MATMUL(1); \ + break; \ + case 2: \ + MACE_CASE_W_MATMUL(2); \ + break; \ + case 3: \ + MACE_CASE_W_MATMUL(3); \ + break; \ + case 4: \ + MACE_CASE_W_MATMUL(4); \ + break; \ + default: \ + LOG(FATAL) << "Unsupported h tile: " << h_count; \ +} template struct MatMulFunctor { @@ -196,14 +225,7 @@ struct MatMulFunctor { const int w_count = std::min(register_tile_size, iw_end - iw); const int k_count = std::min(register_tile_size, ik_end - ik); - switch (h_count) { - CASE_H_MATMUL(1); - CASE_H_MATMUL(2); - CASE_H_MATMUL(3); - CASE_H_MATMUL(4); - default:LOG(FATAL) << "Unsupported height tile: " - << h_count; - } + MACE_CASE_H_MATMUL; } // ik } // iw } // ih