diff --git a/mace/ops/arm/base/depthwise_conv_2d_3x3.cc b/mace/ops/arm/base/depthwise_conv_2d_3x3.cc
index 778386f144bafd457e83f3651eea37cd63780fff..e17245e612923bd619cd53863f1c3ac60c34aca7 100644
--- a/mace/ops/arm/base/depthwise_conv_2d_3x3.cc
+++ b/mace/ops/arm/base/depthwise_conv_2d_3x3.cc
@@ -20,16 +20,6 @@
 namespace mace {
 namespace ops {
 namespace arm {
-extern template
-MaceStatus DepthwiseConv2dK3x3S1<float16_t>::DoCompute(
-    const DepthwiseConvComputeParam &p, const float16_t *filter_data,
-    const float16_t *input_data, float16_t *output_data);
-
-extern template
-MaceStatus DepthwiseConv2dK3x3S2<float16_t>::DoCompute(
-    const DepthwiseConvComputeParam &p, const float16_t *filter_data,
-    const float16_t *input_data, float16_t *output_data);
-
 namespace {
 template<typename T>
 void DepthwiseConv2d3x3Pixel(const T *in_base,
@@ -474,16 +464,6 @@ void RegisterDepthwiseConv2dK3x3Delegator(OpDelegatorRegistry *registry) {
       delegator::DepthwiseConv2dParam,
       MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
                             BFloat16, ImplType::NEON, K3x3S2));
-  MACE_REGISTER_FP16_DELEGATOR(
-      registry, DepthwiseConv2dK3x3S1<float16_t>,
-      delegator::DepthwiseConv2dParam,
-      MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
-                            float16_t, ImplType::NEON, K3x3S1));
-  MACE_REGISTER_FP16_DELEGATOR(
-      registry, DepthwiseConv2dK3x3S2<float16_t>,
-      delegator::DepthwiseConv2dParam,
-      MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
-                            float16_t, ImplType::NEON, K3x3S2));
 }
 
 }  // namespace arm
diff --git a/mace/ops/arm/base/gemm.cc b/mace/ops/arm/base/gemm.cc
index cba09ecb6f93a57aa48581c03b7c8df871feb250..935ce6e7e1fe31748f39040d4716aa46aec4cf31 100644
--- a/mace/ops/arm/base/gemm.cc
+++ b/mace/ops/arm/base/gemm.cc
@@ -23,24 +23,6 @@
 namespace mace {
 namespace ops {
 namespace arm {
-extern template void Gemm<float16_t>::Pack8x4(
-    const MatrixMap<const float16_t> &matrix,
-    MatrixMajor dst_major, float16_t *packed_matrix);
-extern template void Gemm<float16_t>::Unpack8x8(
-    const float16_t *packed_output, MatrixMap<float16_t> *output);
-extern template void Gemm<float16_t>::PackLhs(
-    const MatrixMap<const float16_t> &lhs, float16_t *packed_lhs);
-extern template void Gemm<float16_t>::PackRhs(
-    const MatrixMap<const float16_t> &rhs, float16_t *packed_rhs);
-extern template void Gemm<float16_t>::UnpackOutput(
-    const float16_t *packed_output, MatrixMap<float16_t> *output);
-extern template MaceStatus Gemm<float16_t>::Compute(
-    const OpContext *context, const Tensor *lhs, const Tensor *rhs,
-    const index_t batch, const index_t rows, const index_t cols,
-    const index_t depth, const MatrixMajor lhs_major,
-    const MatrixMajor rhs_major, const MatrixMajor output_major,
-    const bool lhs_batched, const bool rhs_batched, Tensor *output);
-
 template<typename T>
 void Gemm<T>::Pack4x4(const MatrixMap<const T> &matrix,
                       MatrixMajor dst_major, T *packed_matrix) {
@@ -719,12 +701,7 @@ void RegisterGemmDelegator(OpDelegatorRegistry *registry) {
   MACE_REGISTER_BF16_DELEGATOR(
       registry, Gemm<BFloat16>, delegator::GemmParam,
       MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, BFloat16, ImplType::NEON));
-
-  MACE_REGISTER_FP16_DELEGATOR(
-      registry, Gemm<float16_t>, delegator::GemmParam,
-      MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float16_t, ImplType::NEON));
 }
-
 }  // namespace arm
 }  // namespace ops
 }  // namespace mace
diff --git a/mace/ops/arm/fp16/depthwise_conv_2d_3x3.cc b/mace/ops/arm/fp16/depthwise_conv_2d_3x3.cc
index 7455bce8f4c1a833de1e08b3af708a4d63ee1f2b..868c57d49811ed6eb3455ed5a0426fa404df67df 100644
--- a/mace/ops/arm/fp16/depthwise_conv_2d_3x3.cc
+++ b/mace/ops/arm/fp16/depthwise_conv_2d_3x3.cc
@@ -403,6 +403,18 @@ MaceStatus DepthwiseConv2dK3x3S2<float16_t>::DoCompute(
   return MaceStatus::MACE_SUCCESS;
 }
 
+void RegisterFP16DepthwiseConv2dK3x3Delegator(OpDelegatorRegistry *registry) {
+  MACE_REGISTER_FP16_DELEGATOR(
+      registry, DepthwiseConv2dK3x3S1<float16_t>,
+      delegator::DepthwiseConv2dParam,
+      MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
+                            float16_t, ImplType::NEON, K3x3S1));
+  MACE_REGISTER_FP16_DELEGATOR(
+      registry, DepthwiseConv2dK3x3S2<float16_t>,
+      delegator::DepthwiseConv2dParam,
+      MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
+                            float16_t, ImplType::NEON, K3x3S2));
+}
 
 }  // namespace arm
 }  // namespace ops
diff --git a/mace/ops/arm/fp16/gemm.cc b/mace/ops/arm/fp16/gemm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..84b786f1dafa3097857ba98fa4ef56dd5876a920
--- /dev/null
+++ b/mace/ops/arm/fp16/gemm.cc
@@ -0,0 +1,780 @@
+// Copyright 2020 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arm_neon.h>
+#include <algorithm>
+#include <utility>
+
+#include "mace/ops/arm/base/gemm.h"
+#include "mace/port/env.h"
+
+namespace mace {
+namespace ops {
+namespace arm {
+
+template<>
+void Gemm<float16_t>::Pack8x4(const MatrixMap<const float16_t> &matrix,
+                              MatrixMajor dst_major,
+                              float16_t *packed_matrix) {
+  const index_t rows = matrix.rows();
+  const index_t cols = matrix.cols();
+
+  // use the same terminology as GemmLowp:
+  // depth is depth, width is the opposite dim other than depth
+  // lhs
+  index_t width = rows;
+  index_t depth = cols;
+  index_t width_stride = matrix.rows_stride();
+  index_t depth_stride = matrix.cols_stride();
+  if (dst_major == RowMajor) {
+    // rhs
+    std::swap(width, depth);
+    std::swap(width_stride, depth_stride);
+  }
+  const float16_t *data = matrix.data();
+  float16_t *packed_ptr = packed_matrix;
+
+  const index_t block_size = 8;
+  const index_t depth_padded = RoundUp(depth, static_cast<index_t>(8));
+
+  if (depth_padded > depth) {
+    memset(packed_ptr + depth * block_size,
+           0,
+           sizeof(float16_t) * (depth_padded - depth) * block_size);
+  }
+
+  if (dst_major == matrix.matrix_major()) {
+    if (width < block_size) {
+      const index_t width_remain = block_size - width;
+      for (index_t d = 0; d < depth; ++d) {
+        memcpy(packed_ptr, data, sizeof(float16_t) * width);
+        memset(packed_ptr + width, 0, sizeof(float16_t) * width_remain);
+        data += depth_stride;
+        packed_ptr += block_size;
+      }
+    } else {
+      for (index_t d = 0; d < depth; ++d) {
+        float16x8_t vi = vld1q_f16(data);
+        vst1q_f16(packed_ptr, vi);
+        data += depth_stride;
+        packed_ptr += block_size;
+      }
+    }
+  } else {
+    if (width < block_size) {
+      const index_t width_remain = block_size - width;
+      for (index_t d = 0; d < depth; ++d) {
+        for (index_t w = 0; w < width; ++w) {
+          packed_ptr[w] = data[w * width_stride + d];
+        }  // w
+        memset(packed_ptr + width, 0, sizeof(float16_t) * width_remain);
+        packed_ptr += block_size;
+      }  // d
+    } else {
+      const float16_t *data0 = data;
+      const float16_t *data1 = data + width_stride;
+      const float16_t *data2 = data1 + width_stride;
+      const float16_t *data3 = data2 + width_stride;
+      const float16_t *data4 = data3 + width_stride;
+      const float16_t *data5 = data4 + width_stride;
+      const float16_t *data6 = data5 + width_stride;
+      const float16_t *data7 = data6 + width_stride;
+
+      const index_t depth_block = depth / 8;
+      const index_t depth_remain = depth - depth_block * 8;
+      for (index_t depth_block_idx = 0; depth_block_idx < depth_block;
+           ++depth_block_idx) {
+        float16x8_t v0 = vld1q_f16(data0);
+        float16x8_t v1 = vld1q_f16(data1);
+        float16x8_t v2 = vld1q_f16(data2);
+        float16x8_t v3 = vld1q_f16(data3);
+        float16x8_t v4 = vld1q_f16(data4);
+        float16x8_t v5 = vld1q_f16(data5);
+        float16x8_t v6 = vld1q_f16(data6);
+        float16x8_t v7 = vld1q_f16(data7);
+        float16x8x2_t v02_intertwined = vzipq_f16(v0, v2);
+        float16x8x2_t v13_intertwined = vzipq_f16(v1, v3);
+        float16x8x2_t v46_intertwined = vzipq_f16(v4, v6);
+        float16x8x2_t v57_intertwined = vzipq_f16(v5, v7);
+        float16x8x2_t v0246_intertwined =
+            vzipq_f16(v02_intertwined.val[0], v46_intertwined.val[0]);
+        float16x8x2_t v0246n_intertwined =
+            vzipq_f16(v02_intertwined.val[1], v46_intertwined.val[1]);
+        float16x8x2_t v1357_intertwined =
+            vzipq_f16(v13_intertwined.val[0], v57_intertwined.val[0]);
+        float16x8x2_t v1357n_intertwined =
+            vzipq_f16(v13_intertwined.val[1], v57_intertwined.val[1]);
+
+        float16x8x2_t v01234567_intertwined =
+            vzipq_f16(v0246_intertwined.val[0], v1357_intertwined.val[0]);
+        float16x8x2_t v01234567n1_intertwined =
+            vzipq_f16(v0246_intertwined.val[1], v1357_intertwined.val[1]);
+        float16x8x2_t v01234567n2_intertwined =
+            vzipq_f16(v0246n_intertwined.val[0], v1357n_intertwined.val[0]);
+        float16x8x2_t v01234567n3_intertwined =
+            vzipq_f16(v0246n_intertwined.val[1], v1357n_intertwined.val[1]);
+
+        vst1q_f16(packed_ptr, v01234567_intertwined.val[0]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567_intertwined.val[1]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567n1_intertwined.val[0]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567n1_intertwined.val[1]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567n2_intertwined.val[0]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567n2_intertwined.val[1]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567n3_intertwined.val[0]);
+        packed_ptr += 8;
+        vst1q_f16(packed_ptr, v01234567n3_intertwined.val[1]);
+        packed_ptr += 8;
+
+        data0 += 8;
+        data1 += 8;
+        data2 += 8;
+        data3 += 8;
+        data4 += 8;
+        data5 += 8;
+        data6 += 8;
+        data7 += 8;
+      }
+      for (index_t d = 0; d < depth_remain; ++d) {
+        float16x8_t vi =
+            {*data0, *data1, *data2, *data3, *data4, *data5, *data6, *data7};
+        vst1q_f16(packed_ptr, vi);
+        packed_ptr += 8;
+
+        ++data0;
+        ++data1;
+        ++data2;
+        ++data3;
+        ++data4;
+        ++data5;
+        ++data6;
+        ++data7;
+      }  // d
+    }
+  }
+}
+
+template<>
+void Gemm<float16_t>::Unpack8x8(const float16_t *packed_output,
+                                MatrixMap<float16_t> *output) {
+  const index_t rows = output->rows();
+  const index_t cols = output->cols();
+  index_t row_stride = output->rows_stride();
+  index_t col_stride = output->cols_stride();
+
+  float16_t *output_ptr = output->data();
+  const float16_t *packed_ptr = packed_output;
+
+  const index_t block_size = 8;
+
+  // packed_output is always row-major
+  if (output->matrix_major() == RowMajor) {
+    if (cols < block_size) {
+      for (index_t r = 0; r < rows; ++r) {
+        memcpy(output_ptr, packed_ptr, sizeof(float16_t) * cols);
+        output_ptr += row_stride;
+        packed_ptr += block_size;
+      }
+    } else {
+      for (index_t r = 0; r < rows; ++r) {
+        float16x8_t vi = vld1q_f16(packed_ptr);
+        vst1q_f16(output_ptr, vi);
+        output_ptr += row_stride;
+        packed_ptr += block_size;
+      }
+    }
+  } else {
+    // ColMajor
+    if (rows < block_size) {
+      for (index_t c = 0; c < cols; ++c) {
+        for (index_t r = 0; r < rows; ++r) {
+          output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c];
+        }  // r
+      }  // c
+    } else {
+      const float16_t *data0 = packed_ptr;
+      const float16_t *data1 = data0 + block_size;
+      const float16_t *data2 = data1 + block_size;
+      const float16_t *data3 = data2 + block_size;
+      const float16_t *data4 = data3 + block_size;
+      const float16_t *data5 = data4 + block_size;
+      const float16_t *data6 = data5 + block_size;
+      const float16_t *data7 = data6 + block_size;
+
+      index_t col_block = cols / 8;
+      index_t col_remain = cols - col_block * 8;
+      for (index_t col_block_idx = 0; col_block_idx < col_block;
+           ++col_block_idx) {
+        float16x8_t v0 = vld1q_f16(data0);
+        float16x8_t v1 = vld1q_f16(data1);
+        float16x8_t v2 = vld1q_f16(data2);
+        float16x8_t v3 = vld1q_f16(data3);
+        float16x8_t v4 = vld1q_f16(data4);
+        float16x8_t v5 = vld1q_f16(data5);
+        float16x8_t v6 = vld1q_f16(data6);
+        float16x8_t v7 = vld1q_f16(data7);
+        float16x8x2_t v02_intertwined = vzipq_f16(v0, v2);
+        float16x8x2_t v13_intertwined = vzipq_f16(v1, v3);
+        float16x8x2_t v46_intertwined = vzipq_f16(v4, v6);
+        float16x8x2_t v57_intertwined = vzipq_f16(v5, v7);
+        float16x8x2_t v0246_intertwined =
+            vzipq_f16(v02_intertwined.val[0], v46_intertwined.val[0]);
+        float16x8x2_t v0246n_intertwined =
+            vzipq_f16(v02_intertwined.val[1], v46_intertwined.val[1]);
+        float16x8x2_t v1357_intertwined =
+            vzipq_f16(v13_intertwined.val[0], v57_intertwined.val[0]);
+        float16x8x2_t v1357n_intertwined =
+            vzipq_f16(v13_intertwined.val[1], v57_intertwined.val[1]);
+
+        float16x8x2_t v01234567_intertwined =
+            vzipq_f16(v0246_intertwined.val[0], v1357_intertwined.val[0]);
+        float16x8x2_t v01234567n1_intertwined =
+            vzipq_f16(v0246_intertwined.val[1], v1357_intertwined.val[1]);
+        float16x8x2_t v01234567n2_intertwined =
+            vzipq_f16(v0246n_intertwined.val[0], v1357n_intertwined.val[0]);
+        float16x8x2_t v01234567n3_intertwined =
+            vzipq_f16(v0246n_intertwined.val[1], v1357n_intertwined.val[1]);
+
+        vst1q_f16(output_ptr, v01234567_intertwined.val[0]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567_intertwined.val[1]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567n1_intertwined.val[0]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567n1_intertwined.val[1]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567n2_intertwined.val[0]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567n2_intertwined.val[1]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567n3_intertwined.val[0]);
+        output_ptr += col_stride;
+        vst1q_f16(output_ptr, v01234567n3_intertwined.val[1]);
+        output_ptr += col_stride;
+
+        data0 += 8;
+        data1 += 8;
+        data2 += 8;
+        data3 += 8;
+        data4 += 8;
+        data5 += 8;
+        data6 += 8;
+        data7 += 8;
+      }
+      for (index_t c = 0; c < col_remain; ++c) {
+        float16x8_t vi =
+            {*data0, *data1, *data2, *data3, *data4, *data5, *data6, *data7};
+        vst1q_f16(output_ptr, vi);
+        output_ptr += col_stride;
+
+        ++data0;
+        ++data1;
+        ++data2;
+        ++data3;
+        ++data4;
+        ++data5;
+        ++data6;
+        ++data7;
+      }  // c
+    }
+  }
+}
+
+template<>
+void Gemm<float16_t>::PackLhs(const MatrixMap<const float16_t> &lhs,
+                              float16_t *packed_lhs) {
+  Pack8x4(lhs, ColMajor, packed_lhs);
+}
+
+template<>
+void Gemm<float16_t>::PackRhs(const MatrixMap<const float16_t> &rhs,
+                              float16_t *packed_rhs) {
+  Pack8x4(rhs, RowMajor, packed_rhs);
+}
+
+template<>
+void Gemm<float16_t>::UnpackOutput(const float16_t *packed_output,
+                                   MatrixMap<float16_t> *output) {
+  Unpack8x8(packed_output, output);
+}
+
+template<>
+void Gemm<float16_t>::ComputeBlock(const float16_t *packed_lhs_data,
+                                   const float16_t *packed_rhs_data,
+                                   const index_t depth_padded,
+                                   float16_t *packed_output_data) {
+  /* Ref:
+  for (index_t r = 0; r < block_size; ++r) {
+    for (index_t c = 0; c < block_size; ++c) {
+      float16_t sum = 0;
+      for (index_t d = 0; d < depth; ++d) {
+        // (r, d) * (d, c)
+        sum += packed_lhs_data[d * r_block_size + r]
+            * packed_rhs_data[d * c_block_size + c];
+      }
+      packed_output_data[r * c_block_size + c] = sum;
+    }
+  }
+  */
+  const float16_t *lhs_ptr = packed_lhs_data;
+  const float16_t *rhs_ptr = packed_rhs_data;
+
+  const index_t depth_block_count = depth_padded / 8;
+
+  if (depth_block_count > 0) {
+    index_t r_depth_block_count = depth_block_count;
+    // just make compiler happy
+    MACE_UNUSED(r_depth_block_count);
+
+    asm volatile(
+        "dup v16.8h, wzr \n"
+        "dup v17.8h, wzr \n"
+        "dup v18.8h, wzr \n"
+        "dup v19.8h, wzr \n"
+        "dup v20.8h, wzr \n"
+        "dup v21.8h, wzr \n"
+        "dup v22.8h, wzr \n"
+        "dup v23.8h, wzr \n"
+        "dup v24.8h, wzr \n"
+        "dup v25.8h, wzr \n"
+        "dup v26.8h, wzr \n"
+        "dup v27.8h, wzr \n"
+        "dup v28.8h, wzr \n"
+        "dup v29.8h, wzr \n"
+        "dup v30.8h, wzr \n"
+        "dup v31.8h, wzr \n"
+
+        // prologue
+        "ld1 {v0.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v1.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v2.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v3.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v4.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v5.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v6.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v7.8h}, [%[lhs_ptr]], #16 \n"
+
+        "ld1 {v8.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v9.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v10.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v11.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v12.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v13.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v14.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v15.8h}, [%[rhs_ptr]], #16 \n"
+
+        "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n"
+        "beq 1f\n"
+
+        "0: \n"
+        "fmla v16.8h, v8.8h, v0.h[0] \n"
+        "fmla v17.8h, v8.8h, v0.h[1] \n"
+        "fmla v18.8h, v8.8h, v0.h[2] \n"
+        "fmla v19.8h, v8.8h, v0.h[3] \n"
+        "fmla v20.8h, v8.8h, v0.h[4] \n"
+        "fmla v21.8h, v8.8h, v0.h[5] \n"
+        "fmla v22.8h, v8.8h, v0.h[6] \n"
+        "fmla v23.8h, v8.8h, v0.h[7] \n"
+
+        "ld1 {v0.8h}, [%[lhs_ptr]], #16 \n"
+
+        "fmla v24.8h, v9.8h, v1.h[0] \n"
+        "fmla v25.8h, v9.8h, v1.h[1] \n"
+        "fmla v26.8h, v9.8h, v1.h[2] \n"
+        "fmla v27.8h, v9.8h, v1.h[3] \n"
+        "fmla v28.8h, v9.8h, v1.h[4] \n"
+        "fmla v29.8h, v9.8h, v1.h[5] \n"
+        "fmla v30.8h, v9.8h, v1.h[6] \n"
+        "fmla v31.8h, v9.8h, v1.h[7] \n"
+
+        "ld1 {v1.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v8.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v9.8h}, [%[rhs_ptr]], #16 \n"
+
+        "fmla v16.8h, v10.8h, v2.h[0] \n"
+        "fmla v17.8h, v10.8h, v2.h[1] \n"
+        "fmla v18.8h, v10.8h, v2.h[2] \n"
+        "fmla v19.8h, v10.8h, v2.h[3] \n"
+        "fmla v20.8h, v10.8h, v2.h[4] \n"
+        "fmla v21.8h, v10.8h, v2.h[5] \n"
+        "fmla v22.8h, v10.8h, v2.h[6] \n"
+        "fmla v23.8h, v10.8h, v2.h[7] \n"
+
+        "ld1 {v2.8h}, [%[lhs_ptr]], #16 \n"
+
+        "fmla v24.8h, v11.8h, v3.h[0] \n"
+        "fmla v25.8h, v11.8h, v3.h[1] \n"
+        "fmla v26.8h, v11.8h, v3.h[2] \n"
+        "fmla v27.8h, v11.8h, v3.h[3] \n"
+        "fmla v28.8h, v11.8h, v3.h[4] \n"
+        "fmla v29.8h, v11.8h, v3.h[5] \n"
+        "fmla v30.8h, v11.8h, v3.h[6] \n"
+        "fmla v31.8h, v11.8h, v3.h[7] \n"
+
+        "ld1 {v3.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v10.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v11.8h}, [%[rhs_ptr]], #16 \n"
+
+        "fmla v16.8h, v12.8h, v4.h[0] \n"
+        "fmla v17.8h, v12.8h, v4.h[1] \n"
+        "fmla v18.8h, v12.8h, v4.h[2] \n"
+        "fmla v19.8h, v12.8h, v4.h[3] \n"
+        "fmla v20.8h, v12.8h, v4.h[4] \n"
+        "fmla v21.8h, v12.8h, v4.h[5] \n"
+        "fmla v22.8h, v12.8h, v4.h[6] \n"
+        "fmla v23.8h, v12.8h, v4.h[7] \n"
+        "ld1 {v4.8h}, [%[lhs_ptr]], #16 \n"
+
+        "fmla v24.8h, v13.8h, v5.h[0] \n"
+        "fmla v25.8h, v13.8h, v5.h[1] \n"
+        "fmla v26.8h, v13.8h, v5.h[2] \n"
+        "fmla v27.8h, v13.8h, v5.h[3] \n"
+        "fmla v28.8h, v13.8h, v5.h[4] \n"
+        "fmla v29.8h, v13.8h, v5.h[5] \n"
+        "fmla v30.8h, v13.8h, v5.h[6] \n"
+        "fmla v31.8h, v13.8h, v5.h[7] \n"
+
+        "ld1 {v5.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v12.8h}, [%[rhs_ptr]], #16 \n"
+        "ld1 {v13.8h}, [%[rhs_ptr]], #16 \n"
+
+        "fmla v16.8h, v14.8h, v6.h[0] \n"
+        "fmla v17.8h, v14.8h, v6.h[1] \n"
+        "fmla v18.8h, v14.8h, v6.h[2] \n"
+        "fmla v19.8h, v14.8h, v6.h[3] \n"
+        "fmla v20.8h, v14.8h, v6.h[4] \n"
+        "fmla v21.8h, v14.8h, v6.h[5] \n"
+        "fmla v22.8h, v14.8h, v6.h[6] \n"
+        "fmla v23.8h, v14.8h, v6.h[7] \n"
+
+        "ld1 {v6.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v14.8h}, [%[rhs_ptr]], #16 \n"
+
+        "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n"
+
+        "fmla v24.8h, v15.8h, v7.h[0] \n"
+        "fmla v25.8h, v15.8h, v7.h[1] \n"
+        "fmla v26.8h, v15.8h, v7.h[2] \n"
+        "fmla v27.8h, v15.8h, v7.h[3] \n"
+        "fmla v28.8h, v15.8h, v7.h[4] \n"
+        "fmla v29.8h, v15.8h, v7.h[5] \n"
+        "fmla v30.8h, v15.8h, v7.h[6] \n"
+        "fmla v31.8h, v15.8h, v7.h[7] \n"
+
+        "ld1 {v7.8h}, [%[lhs_ptr]], #16 \n"
+        "ld1 {v15.8h}, [%[rhs_ptr]], #16 \n"
+
+        "bne 0b \n"
+
+        // epilogue
+        "1:\n"
+        "fmla v16.8h, v8.8h, v0.h[0] \n"
+        "fmla v17.8h, v8.8h, v0.h[1] \n"
+        "fmla v18.8h, v8.8h, v0.h[2] \n"
+        "fmla v19.8h, v8.8h, v0.h[3] \n"
+        "fmla v20.8h, v8.8h, v0.h[4] \n"
+        "fmla v21.8h, v8.8h, v0.h[5] \n"
+        "fmla v22.8h, v8.8h, v0.h[6] \n"
+        "fmla v23.8h, v8.8h, v0.h[7] \n"
+
+        "fmla v24.8h, v9.8h, v1.h[0] \n"
+        "fmla v25.8h, v9.8h, v1.h[1] \n"
+        "fmla v26.8h, v9.8h, v1.h[2] \n"
+        "fmla v27.8h, v9.8h, v1.h[3] \n"
+        "fmla v28.8h, v9.8h, v1.h[4] \n"
+        "fmla v29.8h, v9.8h, v1.h[5] \n"
+        "fmla v30.8h, v9.8h, v1.h[6] \n"
+        "fmla v31.8h, v9.8h, v1.h[7] \n"
+
+        "fmla v16.8h, v10.8h, v2.h[0] \n"
+        "fmla v17.8h, v10.8h, v2.h[1] \n"
+        "fmla v18.8h, v10.8h, v2.h[2] \n"
+        "fmla v19.8h, v10.8h, v2.h[3] \n"
+        "fmla v20.8h, v10.8h, v2.h[4] \n"
+        "fmla v21.8h, v10.8h, v2.h[5] \n"
+        "fmla v22.8h, v10.8h, v2.h[6] \n"
+        "fmla v23.8h, v10.8h, v2.h[7] \n"
+
+        "fmla v24.8h, v11.8h, v3.h[0] \n"
+        "fmla v25.8h, v11.8h, v3.h[1] \n"
+        "fmla v26.8h, v11.8h, v3.h[2] \n"
+        "fmla v27.8h, v11.8h, v3.h[3] \n"
+        "fmla v28.8h, v11.8h, v3.h[4] \n"
+        "fmla v29.8h, v11.8h, v3.h[5] \n"
+        "fmla v30.8h, v11.8h, v3.h[6] \n"
+        "fmla v31.8h, v11.8h, v3.h[7] \n"
+
+        "fmla v16.8h, v12.8h, v4.h[0] \n"
+        "fmla v17.8h, v12.8h, v4.h[1] \n"
+        "fmla v18.8h, v12.8h, v4.h[2] \n"
+        "fmla v19.8h, v12.8h, v4.h[3] \n"
+        "fmla v20.8h, v12.8h, v4.h[4] \n"
+        "fmla v21.8h, v12.8h, v4.h[5] \n"
+        "fmla v22.8h, v12.8h, v4.h[6] \n"
+        "fmla v23.8h, v12.8h, v4.h[7] \n"
+
+        "fmla v24.8h, v13.8h, v5.h[0] \n"
+        "fmla v25.8h, v13.8h, v5.h[1] \n"
+        "fmla v26.8h, v13.8h, v5.h[2] \n"
+        "fmla v27.8h, v13.8h, v5.h[3] \n"
+        "fmla v28.8h, v13.8h, v5.h[4] \n"
+        "fmla v29.8h, v13.8h, v5.h[5] \n"
+        "fmla v30.8h, v13.8h, v5.h[6] \n"
+        "fmla v31.8h, v13.8h, v5.h[7] \n"
+
+        "fmla v16.8h, v14.8h, v6.h[0] \n"
+        "fmla v17.8h, v14.8h, v6.h[1] \n"
+        "fmla v18.8h, v14.8h, v6.h[2] \n"
+        "fmla v19.8h, v14.8h, v6.h[3] \n"
+        "fmla v20.8h, v14.8h, v6.h[4] \n"
+        "fmla v21.8h, v14.8h, v6.h[5] \n"
+        "fmla v22.8h, v14.8h, v6.h[6] \n"
+        "fmla v23.8h, v14.8h, v6.h[7] \n"
+
+        "fmla v24.8h, v15.8h, v7.h[0] \n"
+        "fmla v25.8h, v15.8h, v7.h[1] \n"
+        "fmla v26.8h, v15.8h, v7.h[2] \n"
+        "fmla v27.8h, v15.8h, v7.h[3] \n"
+        "fmla v28.8h, v15.8h, v7.h[4] \n"
+        "fmla v29.8h, v15.8h, v7.h[5] \n"
+        "fmla v30.8h, v15.8h, v7.h[6] \n"
+        "fmla v31.8h, v15.8h, v7.h[7] \n"
+
+        "st1 {v16.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v17.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v18.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v19.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v20.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v21.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v22.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v23.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v24.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v25.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v26.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v27.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v28.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v29.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v30.8h}, [%[packed_output_data]], #16 \n"
+        "st1 {v31.8h}, [%[packed_output_data]], #16 \n"
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr),
+        [rhs_ptr] "+r"(rhs_ptr),
+        [packed_output_data] "+r"(packed_output_data),
+        [r_depth_block_count] "+r"(r_depth_block_count)
+        :  // inputs
+        :  // clobbers
+        "cc", "memory",
+        "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+        "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
+        "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+  }
+}
+
+template<>
+MaceStatus Gemm<float16_t>::Compute(const OpContext *context,
+                                    const Tensor *lhs,
+                                    const Tensor *rhs,
+                                    const index_t batch,
+                                    const index_t rows,
+                                    const index_t cols,
+                                    const index_t depth,
+                                    const MatrixMajor lhs_major,
+                                    const MatrixMajor rhs_major,
+                                    const MatrixMajor output_major,
+                                    const bool lhs_batched,
+                                    const bool rhs_batched,
+                                    Tensor *output) {
+  MACE_CHECK(output->size() == batch * rows * cols,
+             "Need resize output tensor before call gemm.");
+  Tensor::MappingGuard lhs_guard(lhs);
+  Tensor::MappingGuard rhs_guard(rhs);
+  Tensor::MappingGuard output_guard(output);
+  const float16_t *lhs_data = lhs->data<float16_t>();
+  const float16_t *rhs_data = rhs->data<float16_t>();
+  float16_t *output_data = output->mutable_data<float16_t>();
+
+  const index_t row_block_size = 8;
+  const index_t col_block_size = 8;
+  const index_t depth_block_size = 8;
+  const index_t row_block_count = RoundUpDiv(rows, row_block_size);
+  const index_t col_block_count = RoundUpDiv(cols, col_block_size);
+  const index_t rows_padded = RoundUp(rows, row_block_size);
+  const index_t cols_padded = RoundUp(cols, col_block_size);
+  const index_t depth_padded = RoundUp(depth, depth_block_size);
+
+  ScratchBuffer *scratch = context->device()->scratch_buffer();
+
+  index_t packed_lhs_size =
+      PadAlignSize(sizeof(float16_t) * rows_padded * depth_padded);
+  index_t packed_rhs_size =
+      PadAlignSize(sizeof(float16_t) * depth_padded * cols_padded);
+  index_t packed_output_size =
+      PadAlignSize(sizeof(float16_t) * rows_padded * cols_padded);
+  // resize to the total size of lhs & rhs & output anyway,
+  // in case we do not cache const tensor for saving memory
+  scratch->Rewind();
+  MACE_RETURN_IF_ERROR(scratch->GrowSize(
+      packed_lhs_size + packed_rhs_size + packed_output_size));
+  float16_t *packed_lhs_data =
+      scratch->Scratch(packed_lhs_size).mutable_data<float16_t>();
+  float16_t *packed_rhs_data =
+      scratch->Scratch(packed_rhs_size).mutable_data<float16_t>();
+  float16_t *packed_output_data =
+      scratch->Scratch(packed_output_size).mutable_data<float16_t>();
+
+  int cache_side = kNoCache;
+  if (cached_ == kCacheLhs) {
+    packed_lhs_data = pack_cache_.mutable_data<float16_t>();
+  } else if (cached_ == kCacheRhs) {
+    packed_rhs_data = pack_cache_.mutable_data<float16_t>();
+  } else if (should_cache_pack_) {
+    if (lhs->is_weight() && (!lhs_batched || batch == 1)) {
+      cache_side = kCacheLhs;
+      pack_cache_.Resize(packed_lhs_size);
+      packed_lhs_data = pack_cache_.mutable_data<float16_t>();
+    } else if (rhs->is_weight() && (!rhs_batched || batch == 1)) {
+      cache_side = kCacheRhs;
+      pack_cache_.Resize(packed_rhs_size);
+      packed_rhs_data = pack_cache_.mutable_data<float16_t>();
+    }
+  }
+
+  utils::ThreadPool
+      &thread_pool = context->device()->cpu_runtime()->thread_pool();
+
+  for (index_t b = 0; b < batch; ++b) {
+    MatrixMap<const float16_t>
+        lhs_matrix
+        (lhs_data + static_cast<index_t>(lhs_batched) * b * rows * depth,
+         lhs_major,
+         rows,
+         depth);
+    MatrixMap<const float16_t>
+        rhs_matrix
+        (rhs_data + static_cast<index_t>(rhs_batched) * b * depth * cols,
+         rhs_major,
+         depth,
+         cols);
+    MatrixMap<float16_t> output_matrix
+        (output_data + b * rows * cols, output_major, rows, cols);
+
+    // pack lhs
+    if (cached_ != kCacheLhs) {
+      thread_pool.Compute1D([=, &lhs_matrix](index_t start,
+                                             index_t end,
+                                             index_t step) {
+        for (index_t row_block_idx = start; row_block_idx < end;
+             row_block_idx += step) {
+          const index_t start_row = row_block_idx * row_block_size;
+          const index_t
+              row_block_len = std::min(row_block_size, rows - start_row);
+          float16_t *packed_lhs_data_block =
+              packed_lhs_data + row_block_idx * row_block_size * depth_padded;
+          PackLhs(lhs_matrix.block(start_row, 0, row_block_len, depth),
+                  packed_lhs_data_block);
+        }
+      }, 0, row_block_count, 1);
+
+      if (cache_side == kCacheLhs) {
+        cached_ = kCacheLhs;
+        if (lhs->UnderlyingBuffer()->OnHost()) {
+          AdviseFree(reinterpret_cast<void *>(const_cast<float16_t *>(
+                         lhs->data<float16_t>())),
+                     lhs->raw_size());
+        }
+      }
+    }
+
+    // pack rhs
+    if (cached_ != kCacheRhs) {
+      thread_pool.Compute1D([=, &rhs_matrix](index_t start,
+                                             index_t end,
+                                             index_t step) {
+        for (index_t col_block_idx = start; col_block_idx < end;
+             col_block_idx += step) {
+          const index_t start_col = col_block_idx * col_block_size;
+          const index_t
+              col_block_len = std::min(col_block_size, cols - start_col);
+          float16_t *packed_rhs_data_block =
+              packed_rhs_data + col_block_idx * col_block_size * depth_padded;
+          PackRhs(rhs_matrix.block(0, start_col, depth, col_block_len),
+                  packed_rhs_data_block);
+        }
+      }, 0, col_block_count, 1);
+
+      if (cache_side == kCacheRhs) {
+        cached_ = kCacheRhs;
+        if (rhs->UnderlyingBuffer()->OnHost()) {
+          AdviseFree(reinterpret_cast<void *>(const_cast<float16_t *>(
+                         rhs->data<float16_t>())),
+                     rhs->raw_size());
+        }
+      }
+    }
+
+    // multiply lhs and rhs
+    thread_pool.Compute1D([=, &output_matrix](index_t start,
+                                              index_t end,
+                                              index_t step) {
+      for (index_t row_block_idx = start; row_block_idx < end;
+           row_block_idx += step) {
+        const index_t start_row = row_block_idx * row_block_size;
+        const index_t
+            row_block_len = std::min(row_block_size, rows - start_row);
+        const float16_t *packed_lhs_data_block =
+            packed_lhs_data + row_block_idx * row_block_size * depth_padded;
+
+        for (index_t col_block_idx = 0; col_block_idx < col_block_count;
+             ++col_block_idx) {
+          const index_t start_col = col_block_idx * col_block_size;
+          const index_t
+              col_block_len = std::min(col_block_size, cols - start_col);
+          const float16_t *packed_rhs_data_block =
+              packed_rhs_data + col_block_idx * col_block_size * depth_padded;
+          float16_t *packed_output_data_block =
+              packed_output_data + row_block_idx * row_block_size * cols_padded +
+                  col_block_idx * col_block_size;
+          ComputeBlock(packed_lhs_data_block,
+                       packed_rhs_data_block,
+                       depth_padded,
+                       packed_output_data_block);
+          MatrixMap<float16_t> output_block =
+              output_matrix.block(start_row,
+                                  start_col,
+                                  row_block_len,
+                                  col_block_len);
+          UnpackOutput(packed_output_data_block, &output_block);
+        }  // col_block_idx
+      }  // row_block_idx
+    }, 0, row_block_count, 1);
+  }  // b
+
+  return MaceStatus::MACE_SUCCESS;
+}
+
+void RegisterFP16GemmDelegator(OpDelegatorRegistry *registry) {
+  MACE_REGISTER_FP16_DELEGATOR(
+      registry, Gemm<float16_t>, delegator::GemmParam,
+      MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float16_t, ImplType::NEON));
+}
+}  // namespace arm
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/registry/op_delegators_registry.cc b/mace/ops/registry/op_delegators_registry.cc
index b7c86133d27f73d0daa2e7a9532e9a7c407020d8..e7f119ca2308c2158a125966c3f583ca0e45f450 100644
--- a/mace/ops/registry/op_delegators_registry.cc
+++ b/mace/ops/registry/op_delegators_registry.cc
@@ -69,7 +69,11 @@ extern void RegisterGroupDeconv2dGeneralDelegator(
     OpDelegatorRegistry *registry);
 extern void RegisterGemmDelegator(OpDelegatorRegistry *registry);
 extern void RegisterGemvDelegator(OpDelegatorRegistry *registry);
-
+#ifdef MACE_ENABLE_FP16
+extern void RegisterFP16DepthwiseConv2dK3x3Delegator(
+    OpDelegatorRegistry *registry);
+extern void RegisterFP16GemmDelegator(OpDelegatorRegistry *registry);
+#endif
 #ifdef MACE_ENABLE_QUANTIZE
 namespace q8 {
 extern void RegisterEltwiseDelegator(OpDelegatorRegistry *registry);
@@ -89,7 +93,6 @@ void RegisterAllOpDelegators(OpDelegatorRegistry *registry) {
   ref::RegisterDepthwiseDeconv2dDelegator(registry);
   ref::RegisterGemmDelegator(registry);
   ref::RegisterGemvDelegator(registry);
-
 #ifdef MACE_ENABLE_QUANTIZE
   ref::q8::RegisterEltwiseDelegator(registry);
   ref::q8::RegisterGemvDelegator(registry);
@@ -123,7 +126,10 @@ void RegisterAllOpDelegators(OpDelegatorRegistry *registry) {
 
   arm::RegisterGemmDelegator(registry);
   arm::RegisterGemvDelegator(registry);
-
+#ifdef MACE_ENABLE_FP16
+  arm::RegisterFP16DepthwiseConv2dK3x3Delegator(registry);
+  arm::RegisterFP16GemmDelegator(registry);
+#endif
 #ifdef MACE_ENABLE_QUANTIZE
   arm::q8::RegisterEltwiseDelegator(registry);
   arm::q8::RegisterGemvDelegator(registry);
diff --git a/tools/bazel_build_standalone_lib.sh b/tools/bazel_build_standalone_lib.sh
index f6bcbaa192d8724b4d5f6766ffef312dd252c8f4..3fa74f34dc71262a5a32e8feb7c1b84ed4f6a179 100755
--- a/tools/bazel_build_standalone_lib.sh
+++ b/tools/bazel_build_standalone_lib.sh
@@ -115,7 +115,7 @@ $(echo "$1" | cut -d '=' -f -1)"
   esac
 done
 
-if [[ "${enable_apu}" == true ]];then
+if [[ "${enable_apu}" == true || ( "${abi}" != armeabi-v7a && "${abi}" != arm64-v8a ) ]];then
   enable_rpcmem=false
 fi
 
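
A note on the rpcmem guard in the last hunk: inside [[ ]] the two ABI tests have to be grouped with && (either test alone is true for any single ABI, which would disable rpcmem unconditionally). A minimal standalone sketch of the intended check follows; the variable values are illustrative only and not taken from the script.

    #!/bin/bash
    # Illustrative values; in the real script these come from the parsed CLI flags.
    abi=arm64-v8a
    enable_apu=false
    enable_rpcmem=true

    # Disable rpcmem when building for APU, or when the ABI is neither ARM ABI.
    if [[ "${enable_apu}" == true || ( "${abi}" != armeabi-v7a && "${abi}" != arm64-v8a ) ]]; then
      enable_rpcmem=false
    fi

    echo "enable_rpcmem=${enable_rpcmem}"  # arm64-v8a keeps rpcmem enabled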