Commit c764ba25 authored by liyin

Refactor winograd

Parent 3d88cf68
......@@ -434,16 +434,11 @@ class BufferSlice : public BufferBase {
}
void *Map(index_t offset, index_t length, std::vector<size_t> *pitch) const {
MACE_UNUSED(offset);
MACE_UNUSED(length);
MACE_UNUSED(pitch);
MACE_NOT_IMPLEMENTED;
return nullptr;
return buffer_->Map(offset_ + offset, length, pitch);
}
void UnMap(void *mapped_ptr) const {
MACE_UNUSED(mapped_ptr);
MACE_NOT_IMPLEMENTED;
buffer_->UnMap(mapped_ptr);
}
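// A slice now maps and unmaps through its parent buffer, translating the
// requested offset by the slice's own offset_; presumably this is what allows
// Tensor views built on BufferSlice (e.g. the per-batch Winograd tensors
// below) to be mapped for access.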
void Map(std::vector<size_t> *pitch) {
......
......@@ -304,10 +304,14 @@ class Tensor {
if (buffer_ != nullptr) {
MACE_CHECK(!has_opencl_image(),
name_, ": Cannot resize image, use ResizeImage.");
if (raw_size() + MACE_EXTRA_BUFFER_PAD_SIZE > buffer_->size()) {
const index_t apply_size = raw_size()
+ ((buffer_ != &buffer_slice_) ? MACE_EXTRA_BUFFER_PAD_SIZE : 0);
if (apply_size > buffer_->size()) {
LOG(WARNING) << name_ << ": Resize buffer from size " << buffer_->size()
<< " to " << raw_size() + MACE_EXTRA_BUFFER_PAD_SIZE;
return buffer_->Resize(raw_size() + MACE_EXTRA_BUFFER_PAD_SIZE);
<< " to " << apply_size;
MACE_CHECK(buffer_ != &buffer_slice_,
": Cannot resize tensor with buffer slice");
return buffer_->Resize(apply_size);
}
return MaceStatus::MACE_SUCCESS;
} else {
......
......@@ -16,22 +16,10 @@
#define MACE_OPS_ARM_CONV_2D_NEON_H_
#include "mace/core/types.h"
#include "mace/ops/sgemm.h"
namespace mace {
namespace ops {
void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer);
void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/conv_2d_neon.h"
namespace mace {
namespace ops {
void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
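// A 1x1 stride-1 convolution is a plain matrix multiply per image:
//   output[out_channels x (height * width)] =
//       filter[out_channels x in_channels] * input[in_channels x (height * width)].
// The filter is passed as a constant weight (is_a_weight = true) so SGemm can
// pack it once and reuse the packed form across runs.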
for (index_t b = 0; b < batch; ++b) {
sgemm->Run(filter,
input + b * in_channels * height * width,
1,
out_channels,
in_channels,
in_channels,
height * width,
false,
false,
true,
false,
output + b * out_channels * height * width,
scratch_buffer);
}
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_CONV_WINOGRAD_H_
#define MACE_OPS_ARM_CONV_WINOGRAD_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/ops/sgemm.h"
namespace mace {
namespace ops {
void TransformFilter4x4(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output);
void TransformFilter8x8(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output);
void WinogradConv3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer);
void WinogradConv3x3s1(const float *input,
const float *transformed_filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *transformed_input,
float *transformed_output,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer);
void ConvRef3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *output);
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_CONV_WINOGRAD_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <random>
#include "mace/core/tensor.h"
#include "mace/core/types.h"
#include "mace/ops/arm/conv_winograd.h"
namespace mace {
namespace ops {
TEST(ConvWinogradTest, winograd) {
index_t batch = 1;
index_t in_height = 32;
index_t in_width = 32;
index_t in_channels = 64;
index_t out_channels = 128;
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t input_size = batch * in_channels * in_height * in_width;
index_t filter_size = 3 * 3 * in_channels * out_channels;
index_t output_size = batch * out_channels * out_height * out_width;
Tensor input(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor filter(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output_ref(GetCPUAllocator(), DataType::DT_FLOAT);
input.Resize({batch, in_channels, in_height, in_width});
filter.Resize({out_channels, in_channels, 3, 3});
output.Resize({batch, out_channels, out_height, out_width});
output_ref.Resize({batch, out_channels, out_height, out_width});
float *input_data = input.mutable_data<float>();
float *filter_data = filter.mutable_data<float>();
float *output_data = output.mutable_data<float>();
float *output_data_ref = output_ref.mutable_data<float>();
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(input_data, input_data + input_size, [&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen)));
});
std::generate(filter_data, filter_data + filter_size, [&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen)));
});
ops::ConvRef3x3s1(input_data, filter_data, batch, in_height, in_width,
in_channels, out_channels, output_data_ref);
SGemm sgemm;
ops::WinogradConv3x3s1(input_data, filter_data, batch, in_height,
in_width, in_channels, out_channels, 6,
output_data, &sgemm, nullptr);
// compare the Winograd result against the reference convolution
for (index_t i = 0; i < output_size; ++i) {
EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1) << " with index " << i;
}
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,21 +14,392 @@
#include <algorithm>
#include "mace/ops/arm/conv_winograd.h"
#include "mace/utils/memory.h"
#include "mace/ops/arm/fp32/conv_2d_3x3_winograd.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
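// Pipeline: pad the input, transform the filter (once, if needed) and the
// input tiles, multiply the transformed tensors with a per-batch gemm,
// transform the result back, and crop away any output padding at the end.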
const index_t batch = input->dim(0);
const index_t in_channels = input->dim(1);
const index_t in_height = input->dim(2);
const index_t in_width = input->dim(3);
const index_t out_channels = filter->dim(0);
index_t padded_in_height = in_height + pad_top_ + pad_bottom_;
index_t padded_in_width = in_width + pad_left_ + pad_right_;
index_t out_height = padded_in_height - 2;
index_t out_width = padded_in_width - 2;
output->Resize({batch, out_channels, out_height, out_width});
// When the input feature map is larger than 16x16, use a Winograd output
// tile size of 6 to get higher performance.
index_t out_tile_size = 2;
if (in_height > 16 && in_width > 16) {
out_tile_size = 6;
}
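// Rough arithmetic behind this heuristic: an output tile of size m needs an
// (m + 2) x (m + 2) input tile, so the transforms cost roughly
// (m + 2)^2 / m^2 transformed elements per output pixel: 16/4 = 4.0 for
// m = 2 versus 64/36 ~= 1.8 for m = 6. Larger tiles amortize the transforms
// better, but only pay off once the feature map can fill the 8x8 tiles.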
const index_t padded_out_height = RoundUp<index_t>(out_height, out_tile_size);
const index_t padded_out_width = RoundUp<index_t>(out_width, out_tile_size);
padded_in_height = std::max(padded_in_height, padded_out_height + 2);
padded_in_width = std::max(padded_in_width, padded_out_width + 2);
bool is_out_padded =
padded_out_height != out_height || padded_out_width != out_width;
const index_t
tile_height_count = padded_out_height / out_tile_size;
const index_t tile_width_count = padded_out_width / out_tile_size;
const index_t tile_count = tile_height_count * tile_width_count;
const index_t in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
// pad input and transform input
auto scratch_buffer = context->device()->scratch_buffer();
const index_t padded_in_size = PadAlignSize(
sizeof(float) * batch * in_channels * padded_in_height * padded_in_width);
const index_t padded_out_size = is_out_padded ? PadAlignSize(
sizeof(float) * batch * out_channels * padded_out_height
* padded_out_width) : 0;
const index_t transformed_in_size = PadAlignSize(
sizeof(float) * batch * in_tile_area * in_channels * tile_count);
const index_t transformed_out_size = PadAlignSize(
sizeof(float) * batch * in_tile_area * out_channels * tile_count);
const index_t transformed_filter_size =
PadAlignSize(sizeof(float) * in_tile_area * out_channels * in_channels);
const index_t gemm_pack_size =
transformed_in_size + transformed_filter_size + transformed_filter_size;
scratch_buffer->Rewind();
scratch_buffer->GrowSize(
padded_in_size + padded_out_size + transformed_in_size
+ transformed_out_size + gemm_pack_size);
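// Scratch regions are claimed below in this order: padded input, padded
// output (only when the output itself is padded), transformed input,
// transformed output; the remaining gemm_pack_size is presumably left for
// the packing done inside gemm_.Compute().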
Tensor padded_in(scratch_buffer->Scratch(padded_in_size), DataType::DT_FLOAT);
padded_in.Resize({batch, in_channels, padded_in_height, padded_in_width});
Tensor *padded_out = output;
Tensor tmp_padded_out
(scratch_buffer->Scratch(padded_out_size), DataType::DT_FLOAT);
if (is_out_padded) {
padded_out = &tmp_padded_out;
padded_out->Resize({batch, out_channels, padded_out_height,
padded_out_width});
}
auto transformed_in = scratch_buffer->Scratch(transformed_in_size);
auto transformed_out = scratch_buffer->Scratch(transformed_out_size);
auto padded_in_data = padded_in.data<float>();
auto padded_out_data = padded_out->mutable_data<float>();
auto transformed_in_data = transformed_in.mutable_data<float>();
auto transformed_out_data = transformed_out.mutable_data<float>();
const index_t padded_bottom = padded_in_height - in_height - pad_top_;
const index_t padded_right = padded_in_width - in_width - pad_left_;
ConstructNCHWInputWithSpecificPadding(input,
pad_top_,
padded_bottom,
pad_left_,
padded_right,
&padded_in);
Tensor::MappingGuard filter_guard(filter);
auto filter_data = filter->data<float>();
if (!filter->is_weight() || out_tile_size != out_tile_size_) {
out_tile_size_ = out_tile_size;
transformed_filter_.reset(new Tensor);
transformed_filter_->Resize({in_tile_area, out_channels, in_channels});
auto transformed_filter_data = transformed_filter_->mutable_data<float>();
switch (out_tile_size) {
case 2:
TransformFilter4x4(filter_data,
in_channels,
out_channels,
transformed_filter_data);
break;
case 6:
TransformFilter8x8(filter_data,
in_channels,
out_channels,
transformed_filter_data);
break;
default:MACE_NOT_IMPLEMENTED;
}
}
switch (out_tile_size) {
case 2:
TransformInput4x4(padded_in_data,
batch,
padded_in_height,
padded_in_width,
in_channels,
tile_count,
transformed_in_data);
break;
case 6:
TransformInput8x8(padded_in_data,
batch,
padded_in_height,
padded_in_width,
in_channels,
tile_count,
transformed_in_data);
break;
default:MACE_NOT_IMPLEMENTED;
}
const index_t scratch_buffer_offset = scratch_buffer->offset();
const index_t transformed_in_size_per_batch =
in_tile_area * in_channels * tile_count * sizeof(float);
const index_t transformed_out_size_per_batch =
in_tile_area * out_channels * tile_count * sizeof(float);
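// Each batch reuses the same gemm packing workspace: the offset captured
// above marks the end of the tensors allocated so far, and Rewind() below
// resets the scratch buffer to that point before every batch. Per-batch
// views of the transformed tensors are taken through BufferSlice, so no
// extra copies are made.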
for (index_t b = 0; b < batch; ++b) {
scratch_buffer->Rewind(scratch_buffer_offset);
BufferSlice transformed_in_slice(&transformed_in,
b * transformed_in_size_per_batch,
transformed_in_size_per_batch);
BufferSlice transformed_out_slice(&transformed_out,
b * transformed_out_size_per_batch,
transformed_out_size_per_batch);
Tensor transformed_in_this_batch(transformed_in_slice, DataType::DT_FLOAT);
transformed_in_this_batch.Resize({in_tile_area, in_channels, tile_count});
Tensor
transformed_out_this_batch(transformed_out_slice, DataType::DT_FLOAT);
transformed_out_this_batch.Resize({in_tile_area, out_channels, tile_count});
gemm_.Compute(context,
transformed_filter_.get(),
&transformed_in_this_batch,
in_tile_area,
out_channels,
in_channels,
in_channels,
tile_count,
false,
false,
false,
true,
true,
&transformed_out_this_batch);
}
switch (out_tile_size) {
case 2:
TransformOutput4x4(transformed_out_data,
batch,
padded_out_height,
padded_out_width,
out_channels,
tile_count,
padded_out_data);
break;
case 6:
TransformOutput8x8(transformed_out_data,
batch,
padded_out_height,
padded_out_width,
out_channels,
tile_count,
padded_out_data);
break;
default:MACE_NOT_IMPLEMENTED;
}
if (is_out_padded) {
UnPackOutput(*padded_out, output);
}
return MaceStatus::MACE_SUCCESS;
}
void Conv2dK3x3Winograd::UnPackOutput(const Tensor &src, Tensor *dst) {
const index_t batch = dst->dim(0);
const index_t channels = dst->dim(1);
const index_t height = dst->dim(2);
const index_t width = dst->dim(3);
const index_t padded_height = src.dim(2);
const index_t padded_width = src.dim(3);
if (height == padded_height && width == padded_width) {
return;
}
auto padded_out_data = src.data<float>();
auto out_data = dst->mutable_data<float>();
const index_t img_size = height * width;
const index_t padded_img_size = padded_height * padded_width;
#pragma omp parallel for collapse(3) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
memcpy(
out_data + (b * channels + c) * img_size
+ h * width,
padded_out_data
+ (b * channels + c) * padded_img_size
+ h * padded_width,
sizeof(float) * width);
} // h
} // c
} // b
}
// OCHW => TOC
void Conv2dK3x3Winograd::TransformFilter4x4(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
float g0, g1, g2, g3, g4, g5, g6, g7, g8;
float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
s15;
// load filter
index_t filter_offset = (m * in_channels + c) * 9;
g0 = filter[filter_offset];
g1 = filter[filter_offset + 1];
g2 = filter[filter_offset + 2];
g3 = filter[filter_offset + 3];
g4 = filter[filter_offset + 4];
g5 = filter[filter_offset + 5];
g6 = filter[filter_offset + 6];
g7 = filter[filter_offset + 7];
g8 = filter[filter_offset + 8];
// s = G * g * GT
s0 = g0;
s1 = (g0 + g2 + g1) * 0.5f;
s2 = (g0 + g2 - g1) * 0.5f;
s3 = g2;
s4 = (g0 + g6 + g3) * 0.5f;
s5 = ((g0 + g6 + g3) + (g2 + g8 + g5) + (g1 + g7 + g4)) * 0.25f;
s6 = ((g0 + g6 + g3) + (g2 + g8 + g5) - (g1 + g7 + g4)) * 0.25f;
s7 = (g2 + g8 + g5) * 0.5f;
s8 = (g0 + g6 - g3) * 0.5f;
s9 = ((g0 + g6 - g3) + (g2 + g8 - g5) + (g1 + g7 - g4)) * 0.25f;
s10 = ((g0 + g6 - g3) + (g2 + g8 - g5) - (g1 + g7 - g4)) * 0.25f;
s11 = (g2 + g8 - g5) * 0.5f;
s12 = g6;
s13 = (g6 + g8 + g7) * 0.5f;
s14 = (g6 + g8 - g7) * 0.5f;
s15 = g8;
// store output
index_t output_offset = m * in_channels + c;
output[output_offset + 0 * stride] = s0;
output[output_offset + 1 * stride] = s1;
output[output_offset + 2 * stride] = s2;
output[output_offset + 3 * stride] = s3;
output[output_offset + 4 * stride] = s4;
output[output_offset + 5 * stride] = s5;
output[output_offset + 6 * stride] = s6;
output[output_offset + 7 * stride] = s7;
output[output_offset + 8 * stride] = s8;
output[output_offset + 9 * stride] = s9;
output[output_offset + 10 * stride] = s10;
output[output_offset + 11 * stride] = s11;
output[output_offset + 12 * stride] = s12;
output[output_offset + 13 * stride] = s13;
output[output_offset + 14 * stride] = s14;
output[output_offset + 15 * stride] = s15;
}
}
}
// OCHW => TOC
/**
* G =
⎡ 1 0 0 ⎤
⎢ ⎥
⎢-2/9 -2/9 -2/9 ⎥
⎢ ⎥
⎢-2/9 2/9 -2/9 ⎥
⎢ ⎥
⎢1/90 1/45 2/45 ⎥
⎢ ⎥
⎢1/90 -1/45 2/45 ⎥
⎢ ⎥
⎢1/45 1/90 1/180⎥
⎢ ⎥
⎢1/45 -1/90 1/180⎥
⎢ ⎥
⎣ 0 0 1 ⎦
*/
void Conv2dK3x3Winograd::TransformFilter8x8(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
const float G[8][3] = {{1.0f, 0.0f, 0.0f},
{-2.0f / 9, -2.0f / 9, -2.0f / 9},
{-2.0f / 9, 2.0f / 9, -2.0f / 9},
{1.0f / 90, 1.0f / 45, 2.0f / 45},
{1.0f / 90, -1.0f / 45, 2.0f / 45},
{1.0f / 45, 1.0f / 90, 1.0f / 180},
{1.0f / 45, -1.0f / 90, 1.0f / 180},
{0.0f, 0.0f, 1.0f}};
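// Each 3x3 kernel g is transformed into an 8x8 tile U = G * g * G^T
// (s = g * G^T first, then U = G * s below), i.e. the F(6x6, 3x3) Winograd
// filter transform; the tile element index is the outermost (stride)
// dimension of the TOC output layout.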
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
// load filter
index_t filter_offset = (m * in_channels + c) * 9;
float g0, g1, g2, g3, g4, g5, g6, g7, g8;
g0 = filter[filter_offset];
g1 = filter[filter_offset + 1];
g2 = filter[filter_offset + 2];
g3 = filter[filter_offset + 3];
g4 = filter[filter_offset + 4];
g5 = filter[filter_offset + 5];
g6 = filter[filter_offset + 6];
g7 = filter[filter_offset + 7];
g8 = filter[filter_offset + 8];
float s[3][8];
for (int i = 0; i < 8; ++i) {
s[0][i] = g0 * G[i][0] + g1 * G[i][1] + g2 * G[i][2];
s[1][i] = g3 * G[i][0] + g4 * G[i][1] + g5 * G[i][2];
s[2][i] = g6 * G[i][0] + g7 * G[i][1] + g8 * G[i][2];
}
// store output
index_t output_offset = m * in_channels + c;
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
output[output_offset + (i * 8 + j) * stride] =
G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j];
}
}
}
}
}
namespace {
// NCHW => NTCB (T: in tile pixels, B: tile indices)
void TransformInput4x4(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
void Conv2dK3x3Winograd::TransformInput4x4(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
......@@ -47,7 +418,7 @@ void TransformInput4x4(const float *input,
// load tile data
const float *input_ptr = input + n * input_batch_size +
c * in_height_width + h * in_width + w;
d0 = input_ptr[0];
d1 = input_ptr[1];
d2 = input_ptr[2];
......@@ -134,22 +505,14 @@ void TransformInput4x4(const float *input,
⎢0 -2 4 5/2 -5 -1/2 1 0⎥
⎢ ⎥
⎣0 -1 0 21/4 0 -21/4 0 1⎦
* @param input
* @param batch
* @param in_height
* @param in_width
* @param in_channels
* @param tile_count
* @param output
*/
void TransformInput8x8(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
void Conv2dK3x3Winograd::TransformInput8x8(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
const index_t input_batch_size = in_height_width * in_channels;
......@@ -163,7 +526,7 @@ void TransformInput8x8(const float *input,
for (index_t h = 0; h < in_height - 2; h += 6) {
for (index_t w = 0; w < in_width - 2; w += 6) {
const float *input_ptr = input + n * input_batch_size +
c * in_height_width + h * in_width + w;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
......@@ -236,57 +599,14 @@ void TransformInput8x8(const float *input,
}
}
// TOC * NTCB => NTOB
void BatchGemm(const float *input,
const float *filter,
index_t batch,
index_t in_channels,
index_t out_channels,
index_t tile_count,
int out_tile_size,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
const index_t in_batch_size = in_tile_area * in_channels * tile_count;
const index_t out_batch_size = in_tile_area * out_channels * tile_count;
index_t scratch_buffer_offset = 0;
if (scratch_buffer) {
scratch_buffer_offset = scratch_buffer->offset();
}
// 'batch' here is not the gemm batch; 'in_tile_area' is. The gemm is not
// thread safe, so we loop over batches in a single thread.
// The scratch buffer is rewound to its initial position so that each batch
// reuses the same scratch memory.
for (int b = 0; b < batch; ++b) {
if (scratch_buffer) {
scratch_buffer->Rewind(scratch_buffer_offset);
}
sgemm->Run(filter,
input + b * in_batch_size,
in_tile_area,
out_channels,
in_channels,
in_channels,
tile_count,
false,
false,
true,
false,
output + b * out_batch_size,
scratch_buffer);
}
}
// NTOB => NToOB => NOHoWo
void TransformOutput4x4(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
void Conv2dK3x3Winograd::TransformOutput4x4(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 16 * stride;
const index_t out_image_size = out_height * out_width;
......@@ -340,7 +660,7 @@ void TransformOutput4x4(const float *input,
v3 = s3 - s5 - s7;
float *output_ptr = output + n * output_batch_size +
m * out_image_size + h * out_width + w;
output_ptr[0] = v0;
output_ptr[1] = v1;
output_ptr[out_width] = v2;
......@@ -367,22 +687,14 @@ void TransformOutput4x4(const float *input,
⎢0 1 1 16 16 2 2 0⎥
⎢ ⎥
⎣0 1 -1 32 -32 1 -1 1⎦
*
* @param input
* @param batch
* @param out_height
* @param out_width
* @param out_channels
* @param tile_count
* @param output
*/
void TransformOutput8x8(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
void Conv2dK3x3Winograd::TransformOutput8x8(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
const index_t stride = out_channels * tile_count;
const index_t input_batch_size = 64 * stride;
const index_t out_image_size = out_height * out_width;
......@@ -427,7 +739,7 @@ void TransformOutput8x8(const float *input,
}
float *output_ptr = output + n * output_batch_size +
m * out_image_size + h * out_width + w;
for (int i = 0; i < 6; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
......@@ -461,293 +773,8 @@ void TransformOutput8x8(const float *input,
}
}
}
} // namespace
// OCHW => TOC
// No need to optimize: this transform can be done in the model converter.
void TransformFilter4x4(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
float g0, g1, g2, g3, g4, g5, g6, g7, g8;
float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
s15;
// load filter
index_t filter_offset = (m * in_channels + c) * 9;
g0 = filter[filter_offset];
g1 = filter[filter_offset + 1];
g2 = filter[filter_offset + 2];
g3 = filter[filter_offset + 3];
g4 = filter[filter_offset + 4];
g5 = filter[filter_offset + 5];
g6 = filter[filter_offset + 6];
g7 = filter[filter_offset + 7];
g8 = filter[filter_offset + 8];
// s = G * g * GT
s0 = g0;
s1 = (g0 + g2 + g1) * 0.5f;
s2 = (g0 + g2 - g1) * 0.5f;
s3 = g2;
s4 = (g0 + g6 + g3) * 0.5f;
s5 = ((g0 + g6 + g3) + (g2 + g8 + g5) + (g1 + g7 + g4)) * 0.25f;
s6 = ((g0 + g6 + g3) + (g2 + g8 + g5) - (g1 + g7 + g4)) * 0.25f;
s7 = (g2 + g8 + g5) * 0.5f;
s8 = (g0 + g6 - g3) * 0.5f;
s9 = ((g0 + g6 - g3) + (g2 + g8 - g5) + (g1 + g7 - g4)) * 0.25f;
s10 = ((g0 + g6 - g3) + (g2 + g8 - g5) - (g1 + g7 - g4)) * 0.25f;
s11 = (g2 + g8 - g5) * 0.5f;
s12 = g6;
s13 = (g6 + g8 + g7) * 0.5f;
s14 = (g6 + g8 - g7) * 0.5f;
s15 = g8;
// store output
index_t output_offset = m * in_channels + c;
output[output_offset + 0 * stride] = s0;
output[output_offset + 1 * stride] = s1;
output[output_offset + 2 * stride] = s2;
output[output_offset + 3 * stride] = s3;
output[output_offset + 4 * stride] = s4;
output[output_offset + 5 * stride] = s5;
output[output_offset + 6 * stride] = s6;
output[output_offset + 7 * stride] = s7;
output[output_offset + 8 * stride] = s8;
output[output_offset + 9 * stride] = s9;
output[output_offset + 10 * stride] = s10;
output[output_offset + 11 * stride] = s11;
output[output_offset + 12 * stride] = s12;
output[output_offset + 13 * stride] = s13;
output[output_offset + 14 * stride] = s14;
output[output_offset + 15 * stride] = s15;
}
}
}
// OCHW => TOC
// No need to optimize: this transform can be done in the model converter.
/**
* G =
⎡ 1 0 0 ⎤
⎢ ⎥
⎢-2/9 -2/9 -2/9 ⎥
⎢ ⎥
⎢-2/9 2/9 -2/9 ⎥
⎢ ⎥
⎢1/90 1/45 2/45 ⎥
⎢ ⎥
⎢1/90 -1/45 2/45 ⎥
⎢ ⎥
⎢1/45 1/90 1/180⎥
⎢ ⎥
⎢1/45 -1/90 1/180⎥
⎢ ⎥
⎣ 0 0 1 ⎦
*
* @param filter
* @param in_channels
* @param out_channels
* @param output
*/
void TransformFilter8x8(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
const float G[8][3] = {{1.0f, 0.0f, 0.0f},
{-2.0f / 9, -2.0f / 9, -2.0f / 9},
{-2.0f / 9, 2.0f / 9, -2.0f / 9},
{1.0f / 90, 1.0f / 45, 2.0f / 45},
{1.0f / 90, -1.0f / 45, 2.0f / 45},
{1.0f / 45, 1.0f / 90, 1.0f / 180},
{1.0f / 45, -1.0f / 90, 1.0f / 180},
{0.0f, 0.0f, 1.0f}};
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
// load filter
index_t filter_offset = (m * in_channels + c) * 9;
float g0, g1, g2, g3, g4, g5, g6, g7, g8;
g0 = filter[filter_offset];
g1 = filter[filter_offset + 1];
g2 = filter[filter_offset + 2];
g3 = filter[filter_offset + 3];
g4 = filter[filter_offset + 4];
g5 = filter[filter_offset + 5];
g6 = filter[filter_offset + 6];
g7 = filter[filter_offset + 7];
g8 = filter[filter_offset + 8];
float s[3][8];
for (int i = 0; i < 8; ++i) {
s[0][i] = g0 * G[i][0] + g1 * G[i][1] + g2 * G[i][2];
s[1][i] = g3 * G[i][0] + g4 * G[i][1] + g5 * G[i][2];
s[2][i] = g6 * G[i][0] + g7 * G[i][1] + g8 * G[i][2];
}
// store output
index_t output_offset = m * in_channels + c;
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
output[output_offset + (i * 8 + j) * stride] =
G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j];
}
}
}
}
}
void WinogradConv3x3s1(const float *input,
const float *transformed_filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *transformed_input,
float *transformed_output,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count =
RoundUpDiv(out_height, static_cast<index_t>(out_tile_size));
index_t tile_width_count =
RoundUpDiv(out_width, static_cast<index_t>(out_tile_size));
index_t tile_count = tile_height_count * tile_width_count;
switch (out_tile_size) {
case 2:
TransformInput4x4(input, batch, in_height, in_width, in_channels,
tile_count, transformed_input);
break;
case 6:
TransformInput8x8(input, batch, in_height, in_width, in_channels,
tile_count, transformed_input);
break;
default:
MACE_NOT_IMPLEMENTED;
}
BatchGemm(transformed_input, transformed_filter, batch, in_channels,
out_channels, tile_count, out_tile_size, transformed_output,
sgemm, scratch_buffer);
switch (out_tile_size) {
case 2:
TransformOutput4x4(transformed_output, batch, out_height, out_width,
out_channels, tile_count, output);
break;
case 6:
TransformOutput8x8(transformed_output, batch, out_height, out_width,
out_channels, tile_count, output);
break;
default:
MACE_NOT_IMPLEMENTED;
}
}
void WinogradConv3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count =
RoundUpDiv(out_height, static_cast<index_t>(out_tile_size));
index_t tile_width_count =
RoundUpDiv(out_width, static_cast<index_t>(out_tile_size));
index_t tile_count = tile_height_count * tile_width_count;
index_t in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
index_t transformed_input_size =
in_tile_area * batch * in_channels * tile_count;
index_t transformed_filter_size = in_tile_area * out_channels * in_channels;
index_t transformed_output_size =
in_tile_area * batch * out_channels * tile_count;
auto transformed_input =
make_unique<float[]>(transformed_input_size); // TNCB NOLINT
auto transformed_filter =
make_unique<float[]>(transformed_filter_size); // TOC NOLINT
auto transformed_output =
make_unique<float[]>(transformed_output_size); // NOLINT
switch (out_tile_size) {
case 2:
TransformFilter4x4(filter, in_channels, out_channels,
transformed_filter.get());
break;
case 6:
TransformFilter8x8(filter, in_channels, out_channels,
transformed_filter.get());
break;
default:
MACE_NOT_IMPLEMENTED;
}
WinogradConv3x3s1(input, transformed_filter.get(), batch, in_height,
in_width, in_channels, out_channels, out_tile_size,
transformed_input.get(), transformed_output.get(),
output, sgemm, scratch_buffer);
}
void ConvRef3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *output) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
#pragma omp parallel for collapse(4) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
output[out_offset] = 0;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < 3; ++kh) {
for (index_t kw = 0; kw < 3; ++kw) {
index_t ih = h + kh;
index_t iw = w + kw;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((m * in_channels) + c) * 3 + kh) * 3 + kw;
output[out_offset] += input[in_offset] * filter[filter_offset];
}
}
}
}
}
}
}
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_3X3_WINOGRAD_H_
#define MACE_OPS_ARM_FP32_CONV_2D_3X3_WINOGRAD_H_
#include <memory>
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dK3x3Winograd : public Conv2dBase {
public:
Conv2dK3x3Winograd(int pad_top, int pad_bottom, int pad_left, int pad_right)
: gemm_(),
pad_top_(pad_top),
pad_bottom_(pad_bottom),
pad_left_(pad_left),
pad_right_(pad_right),
transformed_filter_(nullptr),
out_tile_size_(0) {}
virtual ~Conv2dK3x3Winograd() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
void UnPackOutput(const Tensor &padded_output,
Tensor *output);
void TransformFilter4x4(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output);
void TransformFilter8x8(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output);
void TransformInput4x4(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output);
void TransformInput8x8(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output);
void TransformOutput4x4(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output);
void TransformOutput8x8(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output);
Gemm gemm_;
int pad_top_;
int pad_bottom_;
int pad_left_;
int pad_right_;
std::unique_ptr<Tensor> transformed_filter_;
index_t out_tile_size_;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_3X3_WINOGRAD_H_
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include <algorithm>
......@@ -28,7 +28,6 @@
#include "mace/core/tensor.h"
#include "mace/ops/activation.h"
#include "mace/ops/arm/conv_2d_neon.h"
#include "mace/ops/arm/conv_winograd.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/utils/memory.h"
......@@ -37,6 +36,7 @@
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/arm/fp32/conv_2d_1x1.h"
#include "mace/ops/arm/fp32/conv_2d_3x3_winograd.h"
#else
#include "mace/ops/ref/conv_2d.h"
#endif // MACE_ENABLE_NEON
......@@ -55,21 +55,20 @@
namespace mace {
namespace ops {
template <DeviceType D, class T>
template<DeviceType D, class T>
class Conv2dOp;
template <>
template<>
class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
public:
explicit Conv2dOp(OpConstructContext *context)
: ConvPool2dOpBase(context),
activation_(ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)),
is_filter_transformed_(false),
"leakyrelu_coefficient", 0.0f)),
conv2d_delegator_(nullptr) {}
MaceStatus Run(OpContext *context) override {
......@@ -127,12 +126,26 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
index_t filter_h = filter->dim(2);
index_t filter_w = filter->dim(3);
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
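// Dispatch to the refactored fp32 delegates: 1x1/stride-1 convolutions go to
// Conv2dK1x1, and 3x3/stride-1 convolutions with at least 8 input and output
// channels go to Conv2dK3x3Winograd; everything else still falls through to
// the legacy code below.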
if (filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
if (conv2d_delegator_.get() == nullptr) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK1x1>();
}
conv2d_delegator_->Compute(context, input, filter, output);
} else if (filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1
&& input_channels >= 8 && channels >= 8) {
if (conv2d_delegator_.get() == nullptr) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK3x3Winograd>(
pad_top, pad_bottom, pad_left, pad_right);
}
conv2d_delegator_->Compute(context, input, filter, output);
} else {
// TODO(liyin): the code below needs to be refactored.
// delegate to each of the kernels instead of ruling them all
......@@ -157,11 +170,6 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
std::function<void(const float *input, float *output)> conv_func;
bool
use_winograd = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1
&& input_channels >= 8 && channels >= 8;
bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
......@@ -193,122 +201,58 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
&& stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1;
std::vector<index_t> transformed_input_shape;
std::vector<index_t> transformed_output_shape;
std::vector<index_t> transformed_filter_shape;
// When the input feature map is larger than 16x16, use a Winograd output
// tile size of 6 to get higher performance.
index_t winograd_out_tile_size = 2;
if (input_height > 16 && input_width > 16) {
winograd_out_tile_size = 6;
}
if (use_winograd) {
extra_output_height = RoundUp<index_t>(height, winograd_out_tile_size);
extra_input_height =
std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, winograd_out_tile_size);
extra_input_width =
std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
index_t
tile_height_count = extra_output_height / winograd_out_tile_size;
index_t tile_width_count = extra_output_width / winograd_out_tile_size;
index_t tile_count = tile_height_count * tile_width_count;
index_t in_tile_area =
(winograd_out_tile_size + 2) * (winograd_out_tile_size + 2);
transformed_input_shape.insert(transformed_input_shape.end(),
{in_tile_area, batch, input_channels,
tile_count});
transformed_output_shape.insert(transformed_output_shape.end(),
{in_tile_area, batch, channels,
tile_count});
transformed_filter_shape.insert(transformed_filter_shape.end(),
{in_tile_area, channels,
input_channels});
index_t tile_h, tile_w;
if (use_neon_3x3_s1) {
tile_h = 2;
tile_w = 4;
} else if (use_neon_7x1_s1 || use_neon_15x1_s1) {
tile_h = 4;
tile_w = 1;
} else {
index_t tile_h, tile_w;
if (use_neon_3x3_s1) {
tile_h = 2;
tile_w = 4;
} else if (use_neon_7x1_s1 || use_neon_15x1_s1) {
tile_h = 4;
tile_w = 1;
} else {
tile_h = 1;
tile_w = 4;
}
extra_output_height = RoundUp<index_t>(height, tile_h);
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * stride_h
+ (filter_h - 1) * dilation_h + 1);
extra_output_width = RoundUp<index_t>(width, tile_w);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * stride_w
+ (filter_w - 1) * dilation_w + 1);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
tile_h = 1;
tile_w = 4;
}
extra_output_height = RoundUp<index_t>(height, tile_h);
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * stride_h
+ (filter_h - 1) * dilation_h + 1);
extra_output_width = RoundUp<index_t>(width, tile_w);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * stride_w
+ (filter_w - 1) * dilation_w + 1);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
// decide the scratch size before allocating it
index_t total_scratch_size = 0;
index_t transformed_input_size = 0;
index_t transformed_output_size = 0;
index_t padded_input_size = 0;
index_t padded_output_size = 0;
if (use_winograd) {
transformed_input_size =
std::accumulate(transformed_input_shape.begin(),
transformed_input_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
transformed_output_size =
std::accumulate(transformed_output_shape.begin(),
transformed_output_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
total_scratch_size += transformed_input_size + transformed_output_size;
}
if (extra_input_height != input_height
|| extra_input_width != input_width) {
padded_input_size =
batch * input_channels * (input_height + pad_top + pad_bottom)
* (input_width + pad_left + pad_right) * sizeof(float) +
MACE_EXTRA_BUFFER_PAD_SIZE;
PadAlignSize(
batch * input_channels * (input_height + pad_top + pad_bottom)
* (input_width + pad_left + pad_right) * sizeof(float) +
MACE_EXTRA_BUFFER_PAD_SIZE);
total_scratch_size += padded_input_size;
}
if (extra_output_height != height || extra_output_width != width) {
padded_output_size =
batch * channels * extra_output_height * extra_output_width
* sizeof(float);
PadAlignSize(
batch * channels * extra_output_height * extra_output_width
* sizeof(float) + MACE_EXTRA_BUFFER_PAD_SIZE);
total_scratch_size += padded_output_size;
}
if (use_winograd) {
total_scratch_size += transformed_input_size + transformed_output_size;
}
// Init scratch buffer
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(total_scratch_size);
Tensor
transformed_input(scratch->Scratch(transformed_input_size), DT_FLOAT);
Tensor
transformed_output
(scratch->Scratch(transformed_output_size), DT_FLOAT);
Tensor padded_input(scratch->Scratch(padded_input_size), DT_FLOAT);
Tensor padded_output(scratch->Scratch(padded_output_size), DT_FLOAT);
const index_t extra_input_shape[4] =
......@@ -320,56 +264,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
MACE_UNUSED(extra_input_shape);
MACE_UNUSED(extra_output_shape);
Tensor transformed_filter;
// decide which convolution function to call
if (use_winograd) {
transformed_input.Reshape(transformed_input_shape);
transformed_output.Reshape(transformed_output_shape);
const float *transformed_filter_data = nullptr;
// filter only needs to be transformed once, set transformed_filter_data
// to null after the first run.
if (!is_filter_transformed_) {
transformed_filter.Resize(transformed_filter_shape);
switch (winograd_out_tile_size) {
case 2:
TransformFilter4x4(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter.mutable_data<float>());
break;
case 6:
TransformFilter8x8(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter.mutable_data<float>());
break;
default:MACE_NOT_IMPLEMENTED;
}
transformed_filter_data = transformed_filter.data<float>();
is_filter_transformed_ = true;
}
float *transformed_input_data = transformed_input.mutable_data<float>();
float
*transformed_output_data = transformed_output.mutable_data<float>();
conv_func = [=](const float *pad_input, float *pad_output) {
WinogradConv3x3s1(pad_input,
transformed_filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
winograd_out_tile_size,
transformed_input_data,
transformed_output_data,
pad_output,
&sgemm_,
scratch);
};
} else if (use_neon_3x3_s1) {
if (use_neon_3x3_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S1(pad_input,
filter_data,
......@@ -732,8 +628,6 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
bool is_filter_transformed_;
SGemm sgemm_;
#ifdef MACE_ENABLE_NEON
std::unique_ptr<arm::fp32::Conv2dBase> conv2d_delegator_;
#else
......@@ -745,7 +639,6 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#ifdef MACE_ENABLE_QUANTIZE
template <>
class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
......@@ -1052,7 +945,6 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
};
#endif // MACE_ENABLE_OPENCL
void RegisterConv2D(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Conv2D", Conv2dOp,
DeviceType::CPU, float);
......
......@@ -37,7 +37,7 @@
namespace mace {
namespace ops {
template <DeviceType D, class T>
template<DeviceType D, class T>
class DepthwiseDeconv2dOp;
template<>
......@@ -92,10 +92,11 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
const index_t pad_top = out_paddings[1] / 2;
index_t padded_out_size =
std::accumulate(padded_out_shape.begin(),
padded_out_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
PadAlignSize(std::accumulate(padded_out_shape.begin(),
padded_out_shape.end(),
1,
std::multiplies<index_t>())
* sizeof(float) + MACE_EXTRA_BUFFER_PAD_SIZE);
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(padded_out_size);
......@@ -253,7 +254,6 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
padded_out_shape.data(),
out_data);
if (!no_pad) {
CropPadOut<float>(out_data,
padded_out_shape.data(),
......@@ -384,7 +384,7 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
const index_t out_offset =
i * strides[0] * out_width + j * strides[1];
for (int q = 0; q < in_channels_g; ++q) {
const index_t in_base =
((b * group + g) * in_channels_g + q) * in_img_size;
const index_t in_offset =
in_base + i * in_width + j;
......
......@@ -21,7 +21,6 @@
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
#include "mace/ops/sgemm.h"
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_NEON
......
......@@ -21,7 +21,6 @@
#include "public/gemmlowp.h"
#include "mace/benchmark/statistics.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/sgemm.h"
#include "mace/ops/ops_test_util.h"
namespace gemmlowp {
......@@ -94,32 +93,6 @@ namespace test {
namespace {
// Matmul with (m, k) x (k, n)
void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
mace::testing::StopTiming();
std::vector<float> lhs(m * k);
std::vector<float> rhs(k * n);
std::vector<float> result(m * n);
ops::SGemmMatrixMap<const float>
matrix_lhs(1, m, k, SGemmRowMajor, lhs.data(),
true);
ops::SGemmMatrixMap<const float>
matrix_rhs(1, k, n, SGemmRowMajor, rhs.data(),
true);
ops::SGemmMatrixMap<float>
matrix_result(1, m, n, SGemmRowMajor, result.data());
ops::SGemm sgemm;
sgemm(matrix_lhs, matrix_rhs, &matrix_result);
mace::testing::StartTiming();
while (iters--) {
sgemm(matrix_lhs, matrix_rhs, &matrix_result);
}
}
void MatmulBenchmark_Eigen(int iters, int m, int k, int n) {
mace::testing::StopTiming();
Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(m, k);
......@@ -223,7 +196,6 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) {
MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC)
#define MACE_BM_MATMUL(M, K, N) \
MACE_BM_MATMUL_FUNC(M, K, N, Mace_SGemm, float); \
MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \
MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \
MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32, uint8_t);
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
#include <sys/mman.h>
#include <memory>
#include "mace/ops/sgemm.h"
#include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/utils/memory.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
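// vaddvq_f32 (horizontal sum of all four lanes of a float32x4_t) is only
// provided on AArch64; emulate it element-wise on 32-bit ARM NEON.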
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace ops {
void SGemm::operator()(const SGemmMatrixMap<const float> &lhs,
const SGemmMatrixMap<const float> &rhs,
SGemmMatrixMap<float> *result,
ScratchBuffer *scratch_buffer) {
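// If only the lhs is constant, compute the transposed product instead:
// (A * B)^T == B^T * A^T, so calling back with the transposed operands
// swapped yields the same result while placing the constant matrix on the
// rhs side of the packing path.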
if (lhs.is_const() && !rhs.is_const()) {
SGemmMatrixMap<const float> lhs_transpose = lhs.transpose();
SGemmMatrixMap<const float> rhs_transpose = rhs.transpose();
SGemmMatrixMap<float> result_transpose = result->transpose();
return operator()(rhs_transpose,
lhs_transpose,
&result_transpose,
scratch_buffer);
}
if (scratch_buffer != nullptr) {
index_t total_size = result->size();
if (!lhs.is_const()) {
total_size += lhs.size();
}
if (!rhs.is_const()) {
total_size += rhs.size();
}
scratch_buffer->GrowSize(total_size * sizeof(float));
if (!lhs.is_const()) {
packed_lhs_ = make_unique<Tensor>(scratch_buffer->Scratch(
lhs.size() * sizeof(float)), DT_FLOAT);
}
if (!rhs.is_const()) {
packed_rhs_ = make_unique<Tensor>(scratch_buffer->Scratch(
rhs.size() * sizeof(float)), DT_FLOAT);
}
packed_result_ = make_unique<Tensor>(scratch_buffer->Scratch(
result->size() * sizeof(float)), DT_FLOAT);
}
if (packed_lhs_.get() == nullptr) {
packed_lhs_ = make_unique<Tensor>(GetCPUAllocator(), DT_FLOAT);
packed_lhs_->Resize({lhs.size()});
}
if (packed_rhs_.get() == nullptr) {
packed_rhs_ = make_unique<Tensor>(GetCPUAllocator(), DT_FLOAT);
packed_rhs_->Resize({rhs.size()});
}
if (packed_result_.get() == nullptr) {
packed_result_ = make_unique<Tensor>(GetCPUAllocator(), DT_FLOAT);
packed_result_->Resize({result->size()});
}
if (!lhs.is_const() || !packed_) {
PackLhs(lhs, packed_lhs_.get());
if (lhs.is_const()) {
AdviseFree(reinterpret_cast<void *>(const_cast<float *>(lhs.data())),
lhs.size() * sizeof(float));
}
}
if (!rhs.is_const() || !packed_) {
PackRhs(rhs, packed_rhs_.get());
if (rhs.is_const()) {
AdviseFree(reinterpret_cast<void *>(const_cast<float *>(rhs.data())),
rhs.size() * sizeof(float));
}
}
packed_ = true;
RunInternal(*packed_lhs_,
*packed_rhs_,
lhs.batch(),
lhs.row(),
lhs.col(),
rhs.col(),
packed_result_.get());
UnPack(*packed_result_, result);
}
void SGemm::Run(const float *A,
const float *B,
const index_t batch,
const index_t height_a,
const index_t width_a,
const index_t height_b,
const index_t width_b,
const bool transpose_a,
const bool transpose_b,
const bool is_a_weight,
const bool is_b_weight,
float *C,
ScratchBuffer *scratch_buffer) {
index_t height_c = height_a;
index_t width_c = width_b;
if (transpose_a) {
height_c = width_a;
}
if (transpose_b) {
width_c = height_b;
}
SGemmMatrixMap<const float> matrix_a =
SGemmMatrixMap<const float>(batch,
height_a,
width_a,
ops::SGemmRowMajor,
A,
is_a_weight);
SGemmMatrixMap<const float> matrix_b =
ops::SGemmMatrixMap<const float>(batch,
height_b,
width_b,
ops::SGemmRowMajor,
B,
is_b_weight);
if (transpose_a) {
matrix_a = matrix_a.transpose();
}
if (transpose_b) {
matrix_b = matrix_b.transpose();
}
SGemmMatrixMap<float>
matrix_c(batch, height_c, width_c, ops::SGemmRowMajor, C);
operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer);
}
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
// calculate 8 rows, 4 cols for each depth
#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \
c0 = vfmaq_laneq_f32(c0, b##D, a##VD, 0); \
c1 = vfmaq_laneq_f32(c1, b##D, a##VD, 1); \
c2 = vfmaq_laneq_f32(c2, b##D, a##VD, 2); \
c3 = vfmaq_laneq_f32(c3, b##D, a##VD, 3); \
c4 = vfmaq_laneq_f32(c4, b##D, a##VDN, 0); \
c5 = vfmaq_laneq_f32(c5, b##D, a##VDN, 1); \
c6 = vfmaq_laneq_f32(c6, b##D, a##VDN, 2); \
c7 = vfmaq_laneq_f32(c7, b##D, a##VDN, 3);
// calculate 4 rows, 4 cols for each depth
#define MACE_SGEMM_PART_CAL_R4_C4_D1(D) \
c0 = vfmaq_laneq_f32(c0, b##D, a##D, 0); \
c1 = vfmaq_laneq_f32(c1, b##D, a##D, 1); \
c2 = vfmaq_laneq_f32(c2, b##D, a##D, 2); \
c3 = vfmaq_laneq_f32(c3, b##D, a##D, 3);
// calculate 4 cols for 8 depths for each row
#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \
c##R = vfmaq_laneq_f32(c##R, b0, a##VR, 0); \
c##R = vfmaq_laneq_f32(c##R, b1, a##VR, 1); \
c##R = vfmaq_laneq_f32(c##R, b2, a##VR, 2); \
c##R = vfmaq_laneq_f32(c##R, b3, a##VR, 3); \
c##R = vfmaq_laneq_f32(c##R, b4, a##VRN, 0); \
c##R = vfmaq_laneq_f32(c##R, b5, a##VRN, 1); \
c##R = vfmaq_laneq_f32(c##R, b6, a##VRN, 2); \
c##R = vfmaq_laneq_f32(c##R, b7, a##VRN, 3);
// calculate 4 cols for 4 depths for each row
#define MACE_SGEMM_PART_CAL_R1_C4_D4(R) \
c##R = vfmaq_laneq_f32(c##R, b0, a##R, 0); \
c##R = vfmaq_laneq_f32(c##R, b1, a##R, 1); \
c##R = vfmaq_laneq_f32(c##R, b2, a##R, 2); \
c##R = vfmaq_laneq_f32(c##R, b3, a##R, 3);
// calculate 8 cols for 4 depths for each row
#define MACE_SGEMM_PART_CAL_R1_C8_D4(VR, VRN, R) \
c##VR = vfmaq_laneq_f32(c##VR, b0, a##R, 0); \
c##VR = vfmaq_laneq_f32(c##VR, b2, a##R, 1); \
c##VR = vfmaq_laneq_f32(c##VR, b4, a##R, 2); \
c##VR = vfmaq_laneq_f32(c##VR, b6, a##R, 3); \
c##VRN = vfmaq_laneq_f32(c##VRN, b1, a##R, 0); \
c##VRN = vfmaq_laneq_f32(c##VRN, b3, a##R, 1); \
c##VRN = vfmaq_laneq_f32(c##VRN, b5, a##R, 2); \
c##VRN = vfmaq_laneq_f32(c##VRN, b7, a##R, 3);
#else
#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \
c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##VD), 0); \
c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##VD), 1); \
c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##VD), 0); \
c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##VD), 1); \
c4 = vmlaq_lane_f32(c4, b##D, vget_low_f32(a##VDN), 0); \
c5 = vmlaq_lane_f32(c5, b##D, vget_low_f32(a##VDN), 1); \
c6 = vmlaq_lane_f32(c6, b##D, vget_high_f32(a##VDN), 0); \
c7 = vmlaq_lane_f32(c7, b##D, vget_high_f32(a##VDN), 1);
#define MACE_SGEMM_PART_CAL_R4_C4_D1(D) \
c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##D), 0); \
c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##D), 1); \
c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##D), 0); \
c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##D), 1);
#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \
c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##VR), 0); \
c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##VR), 1); \
c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##VR), 0); \
c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##VR), 1); \
c##R = vmlaq_lane_f32(c##R, b4, vget_low_f32(a##VRN), 0); \
c##R = vmlaq_lane_f32(c##R, b5, vget_low_f32(a##VRN), 1); \
c##R = vmlaq_lane_f32(c##R, b6, vget_high_f32(a##VRN), 0); \
c##R = vmlaq_lane_f32(c##R, b7, vget_high_f32(a##VRN), 1);
#define MACE_SGEMM_PART_CAL_R1_C4_D4(R) \
c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##R), 0); \
c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##R), 1); \
c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##R), 0); \
c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##R), 1);
#endif // __aarch64__
#endif // MACE_ENABLE_NEON
void SGemm::RunInternal(const PackedBlock &lhs,
const PackedBlock &rhs,
const index_t batch,
const index_t height,
const index_t depth,
const index_t width,
PackedBlock *result) {
const float *lhs_data = lhs.data<float>();
const float *rhs_data = rhs.data<float>();
float *result_data = result->mutable_data<float>();
#define MACE_SGEMM_RUN_PER_BATCH \
for (index_t b = 0; b < batch; ++b) { \
RunPerBatch(lhs_data + b * height * depth, \
rhs_data + b * depth * width, \
height, \
depth, \
width, \
result_data + b * height * width); \
}
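// Parallelize over batches only when there are enough of them to occupy all
// OpenMP threads; otherwise run batches serially and rely on the per-column
// parallelism inside RunPerBatch.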
if (batch >= MaceOpenMPThreadCount) {
#pragma omp parallel for schedule(runtime)
MACE_SGEMM_RUN_PER_BATCH
} else {
MACE_SGEMM_RUN_PER_BATCH
}
#undef MACE_SGEMM_RUN_PER_BATCH
}
void SGemm::RunPerBatch(const float *lhs_data,
const float *rhs_data,
const index_t height,
const index_t depth,
const index_t width,
float *result_data) {
#if defined(MACE_ENABLE_NEON)
const index_t block_w = width >> 2;
const index_t remain_w = width - (block_w << 2);
#else
const index_t remain_w = width;
#endif
#if defined(MACE_ENABLE_NEON)
// TODO(liyin): make better use of the L2 (and L1) cache: try to fit as much
// lhs data into cache as possible, by tiling lhs by height and rhs by width.
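// Register blocking used below: columns are handled 4 at a time (bw), rows
// in blocks of 8 on AArch64 and 4 elsewhere (smaller blocks handle the
// remainders), and the depth dimension is unrolled in chunks of 8, then 4,
// then 1.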
// w: 4
#pragma omp parallel for schedule(runtime)
for (index_t bw = 0; bw < block_w; ++bw) {
index_t remain_h = height;
index_t block_h = 0;
const float *lhs_ptr = lhs_data;
float *res_ptr = result_data + height * (bw << 2);
#if defined(__aarch64__)
block_h = remain_h >> 3;
remain_h -= (block_h << 3);
// h: 8
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * (bw << 2);
index_t remain_d = depth;
index_t block_d = remain_d >> 3;
remain_d -= (block_d << 3);
float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
c2 = vdupq_n_f32(0.f);
c3 = vdupq_n_f32(0.f);
c4 = vdupq_n_f32(0.f);
c5 = vdupq_n_f32(0.f);
c6 = vdupq_n_f32(0.f);
c7 = vdupq_n_f32(0.f);
// d: 8
for (index_t bd = 0; bd < block_d; ++bd) {
// 8.8.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13,
a14, a15;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
a4 = vld1q_f32(lhs_ptr + 16);
a5 = vld1q_f32(lhs_ptr + 20);
a6 = vld1q_f32(lhs_ptr + 24);
a7 = vld1q_f32(lhs_ptr + 28);
a8 = vld1q_f32(lhs_ptr + 32);
a9 = vld1q_f32(lhs_ptr + 36);
a10 = vld1q_f32(lhs_ptr + 40);
a11 = vld1q_f32(lhs_ptr + 44);
a12 = vld1q_f32(lhs_ptr + 48);
a13 = vld1q_f32(lhs_ptr + 52);
a14 = vld1q_f32(lhs_ptr + 56);
a15 = vld1q_f32(lhs_ptr + 60);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
b4 = vld1q_f32(rhs_ptr + 16);
b5 = vld1q_f32(rhs_ptr + 20);
b6 = vld1q_f32(rhs_ptr + 24);
b7 = vld1q_f32(rhs_ptr + 28);
MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1
MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3); // d = 2
MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5);
MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7);
MACE_SGEMM_PART_CAL_R8_C4_D1(4, 8, 9);
MACE_SGEMM_PART_CAL_R8_C4_D1(5, 10, 11);
MACE_SGEMM_PART_CAL_R8_C4_D1(6, 12, 13);
MACE_SGEMM_PART_CAL_R8_C4_D1(7, 14, 15);
lhs_ptr += 64;
rhs_ptr += 32;
}
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 8.4.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
float32x4_t b0, b1, b2, b3;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
a4 = vld1q_f32(lhs_ptr + 16);
a5 = vld1q_f32(lhs_ptr + 20);
a6 = vld1q_f32(lhs_ptr + 24);
a7 = vld1q_f32(lhs_ptr + 28);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1
MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3); // d = 2
MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5);
MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7);
lhs_ptr += 32;
rhs_ptr += 16;
}
// TODO(liyin): handle each remain case separately
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 8.1.4
float32x4_t a0, a1;
float32x4_t b0;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1
lhs_ptr += 8;
rhs_ptr += 4;
}
vst1q_f32(res_ptr, c0);
vst1q_f32(res_ptr + 4, c1);
vst1q_f32(res_ptr + 8, c2);
vst1q_f32(res_ptr + 12, c3);
vst1q_f32(res_ptr + 16, c4);
vst1q_f32(res_ptr + 20, c5);
vst1q_f32(res_ptr + 24, c6);
vst1q_f32(res_ptr + 28, c7);
res_ptr += 32;
} // bh: 8
#endif // __aarch64__
// h: 4
block_h = remain_h >> 2;
remain_h -= (block_h << 2);
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * (bw << 2);
index_t remain_d = depth;
index_t block_d = 0;
float32x4_t c0, c1, c2, c3;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
c2 = vdupq_n_f32(0.f);
c3 = vdupq_n_f32(0.f);
// d: 8
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
#if defined(__aarch64__)
for (index_t bd = 0; bd < block_d; ++bd) {
// 4.8.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
a4 = vld1q_f32(lhs_ptr + 16);
a5 = vld1q_f32(lhs_ptr + 20);
a6 = vld1q_f32(lhs_ptr + 24);
a7 = vld1q_f32(lhs_ptr + 28);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
b4 = vld1q_f32(rhs_ptr + 16);
b5 = vld1q_f32(rhs_ptr + 20);
b6 = vld1q_f32(rhs_ptr + 24);
b7 = vld1q_f32(rhs_ptr + 28);
MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1
MACE_SGEMM_PART_CAL_R4_C4_D1(1); // d = 2
MACE_SGEMM_PART_CAL_R4_C4_D1(2);
MACE_SGEMM_PART_CAL_R4_C4_D1(3);
MACE_SGEMM_PART_CAL_R4_C4_D1(4);
MACE_SGEMM_PART_CAL_R4_C4_D1(5);
MACE_SGEMM_PART_CAL_R4_C4_D1(6);
MACE_SGEMM_PART_CAL_R4_C4_D1(7);
lhs_ptr += 32;
rhs_ptr += 32;
}
#else // arm v7
// 4.8.4
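// Each iteration of the assembly loop below consumes 8 depth steps of a
// 4-row x 4-column tile: packed lhs rows are loaded into q0-q5 (q0/q1 are
// reloaded for the last two steps), packed rhs columns into q10-q15
// (q10/q11 reloaded), and every depth step issues four vmla.f32
// accumulations into c0-c3.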
if (block_d > 0) {
asm volatile(
"0: \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n"
"vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
"vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
"vld1.f32 {d24-d25}, [%[rhs_ptr]]! \n"
"vmla.f32 %q[c0], q10, d0[0] \n"
"vmla.f32 %q[c1], q10, d0[1] \n"
"vmla.f32 %q[c2], q10, d1[0] \n"
"vmla.f32 %q[c3], q10, d1[1] \n"
"vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n"
"vld1.f32 {d26-d27}, [%[rhs_ptr]]! \n"
"vmla.f32 %q[c0], q11, d2[0] \n"
"vmla.f32 %q[c1], q11, d2[1] \n"
"vmla.f32 %q[c2], q11, d3[0] \n"
"vmla.f32 %q[c3], q11, d3[1] \n"
"vld1.f32 {d8-d9}, [%[lhs_ptr]]! \n"
"vld1.f32 {d28-d29}, [%[rhs_ptr]]! \n"
"vmla.f32 %q[c0], q12, d4[0] \n"
"vmla.f32 %q[c1], q12, d4[1] \n"
"vmla.f32 %q[c2], q12, d5[0] \n"
"vmla.f32 %q[c3], q12, d5[1] \n"
"vld1.f32 {d10-d11}, [%[lhs_ptr]]! \n"
"vld1.f32 {d30-d31}, [%[rhs_ptr]]! \n"
"vmla.f32 %q[c0], q13, d6[0] \n"
"vmla.f32 %q[c1], q13, d6[1] \n"
"vmla.f32 %q[c2], q13, d7[0] \n"
"vmla.f32 %q[c3], q13, d7[1] \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
"vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
"vmla.f32 %q[c0], q14, d8[0] \n"
"vmla.f32 %q[c1], q14, d8[1] \n"
"vmla.f32 %q[c2], q14, d9[0] \n"
"vmla.f32 %q[c3], q14, d9[1] \n"
"vmla.f32 %q[c0], q15, d10[0] \n"
"vmla.f32 %q[c1], q15, d10[1] \n"
"vmla.f32 %q[c2], q15, d11[0] \n"
"vmla.f32 %q[c3], q15, d11[1] \n"
"vmla.f32 %q[c0], q10, d0[0] \n"
"vmla.f32 %q[c1], q10, d0[1] \n"
"vmla.f32 %q[c2], q10, d1[0] \n"
"vmla.f32 %q[c3], q10, d1[1] \n"
"subs %[block_d], %[block_d], #1 \n"
"vmla.f32 %q[c0], q11, d2[0] \n"
"vmla.f32 %q[c1], q11, d2[1] \n"
"vmla.f32 %q[c2], q11, d3[0] \n"
"vmla.f32 %q[c3], q11, d3[1] \n"
"bne 0b \n"
: // outputs
[lhs_ptr] "+r"(lhs_ptr),
[rhs_ptr] "+r"(rhs_ptr),
[res_ptr] "+r"(res_ptr),
[block_d] "+r"(block_d),
[c0] "+w"(c0),
[c1] "+w"(c1),
[c2] "+w"(c2),
[c3] "+w"(c3)
: // inputs
: // clobbers
"cc", "memory",
"q0", "q1", "q2", "q3", "q4", "q5",
"q10", "q11", "q12", "q13", "q14", "q15");
}
#endif // __aarch64__
// d: 4
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
for (index_t bd = 0; bd < block_d; ++bd) {
// 4.4.4
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
a2 = vld1q_f32(lhs_ptr + 8);
a3 = vld1q_f32(lhs_ptr + 12);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1
MACE_SGEMM_PART_CAL_R4_C4_D1(1); // d = 2
MACE_SGEMM_PART_CAL_R4_C4_D1(2);
MACE_SGEMM_PART_CAL_R4_C4_D1(3);
lhs_ptr += 16;
rhs_ptr += 16;
}
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 4.1.4
float32x4_t a0;
float32x4_t b0;
a0 = vld1q_f32(lhs_ptr);
b0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1
lhs_ptr += 4;
rhs_ptr += 4;
}
vst1q_f32(res_ptr, c0);
vst1q_f32(res_ptr + 4, c1);
vst1q_f32(res_ptr + 8, c2);
vst1q_f32(res_ptr + 12, c3);
res_ptr += 16;
} // bh: 4
// h: 1
for (index_t h = 0; h < remain_h; ++h) {
const float *rhs_ptr = rhs_data + depth * (bw << 2);
index_t remain_d = depth;
index_t block_d = 0;
float32x4_t c0 = vdupq_n_f32(0.f);
// d: 8
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.8.4
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
b4 = vld1q_f32(rhs_ptr + 16);
b5 = vld1q_f32(rhs_ptr + 20);
b6 = vld1q_f32(rhs_ptr + 24);
b7 = vld1q_f32(rhs_ptr + 28);
MACE_SGEMM_PART_CAL_R1_C4_D8(0, 0, 1);
lhs_ptr += 8;
rhs_ptr += 32;
}
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.4.4
float32x4_t a0;
float32x4_t b0, b1, b2, b3;
a0 = vld1q_f32(lhs_ptr);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
b2 = vld1q_f32(rhs_ptr + 8);
b3 = vld1q_f32(rhs_ptr + 12);
MACE_SGEMM_PART_CAL_R1_C4_D4(0);
lhs_ptr += 4;
rhs_ptr += 16;
}
// d: remain
float s0 = 0;
float s1 = 0;
float s2 = 0;
float s3 = 0;
for (index_t d = 0; d < remain_d; ++d) {
// 1.1.4
s0 += lhs_ptr[0] * rhs_ptr[0];
s1 += lhs_ptr[0] * rhs_ptr[1];
s2 += lhs_ptr[0] * rhs_ptr[2];
s3 += lhs_ptr[0] * rhs_ptr[3];
lhs_ptr += 1;
rhs_ptr += 4;
}
float32x4_t c0_remain = {s0, s1, s2, s3};
c0 = vaddq_f32(c0, c0_remain);
vst1q_f32(res_ptr, c0);
res_ptr += 4;
} // bh: remain
} // bw
#endif // MACE_ENABLE_NEON
// ========================== remain width ===========================
result_data += (width - remain_w) * height;
rhs_data += (width - remain_w) * depth;
// w: 1
#pragma omp parallel for schedule(runtime)
for (index_t bw = 0; bw < remain_w; ++bw) {
index_t remain_h = height;
const float *lhs_ptr = lhs_data;
float *res_ptr = result_data + height * bw;
#if defined(MACE_ENABLE_NEON)
index_t block_h = 0;
#if defined(__aarch64__)
block_h = remain_h >> 3;
remain_h -= (block_h << 3);
// h: 8
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * bw;
index_t remain_d = depth;
float32x4_t c0, c1;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
index_t block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 8.4.1
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t a0;
b0 = vld1q_f32(lhs_ptr);
b1 = vld1q_f32(lhs_ptr + 4);
b2 = vld1q_f32(lhs_ptr + 8);
b3 = vld1q_f32(lhs_ptr + 12);
b4 = vld1q_f32(lhs_ptr + 16);
b5 = vld1q_f32(lhs_ptr + 20);
b6 = vld1q_f32(lhs_ptr + 24);
b7 = vld1q_f32(lhs_ptr + 28);
a0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R1_C8_D4(0, 1, 0);
lhs_ptr += 32;
rhs_ptr += 4;
}
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 8.1.1
float32x4_t b0, b1;
float32x4_t a0 = vdupq_n_f32(rhs_ptr[0]);
b0 = vld1q_f32(lhs_ptr);
b1 = vld1q_f32(lhs_ptr + 4);
c0 = vfmaq_laneq_f32(c0, b0, a0, 0);
c1 = vfmaq_laneq_f32(c1, b1, a0, 0);
lhs_ptr += 8;
rhs_ptr += 1;
}
vst1q_f32(res_ptr, c0);
vst1q_f32(res_ptr + 4, c1);
res_ptr += 8;
} // bh: 8
#endif
// h: 4
block_h = remain_h >> 2;
remain_h -= (block_h << 2);
for (index_t bh = 0; bh < block_h; ++bh) {
const float *rhs_ptr = rhs_data + depth * bw;
index_t remain_d = depth;
index_t block_d = 0;
float32x4_t c0 = vdupq_n_f32(0.f);
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 4.4.1
float32x4_t b0, b1, b2, b3;
float32x4_t a0;
b0 = vld1q_f32(lhs_ptr);
b1 = vld1q_f32(lhs_ptr + 4);
b2 = vld1q_f32(lhs_ptr + 8);
b3 = vld1q_f32(lhs_ptr + 12);
a0 = vld1q_f32(rhs_ptr);
MACE_SGEMM_PART_CAL_R1_C4_D4(0);
lhs_ptr += 16;
rhs_ptr += 4;
}
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 4.1.1
float32x4_t b0;
float32x2_t a0 = vdup_n_f32(rhs_ptr[0]);
b0 = vld1q_f32(lhs_ptr);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
lhs_ptr += 4;
rhs_ptr += 1;
}
vst1q_f32(res_ptr, c0);
res_ptr += 4;
} // bh: 4
#endif // MACE_ENABLE_NEON
// h: 1
for (index_t h = 0; h < remain_h; ++h) {
const float *rhs_ptr = rhs_data + depth * bw;
index_t remain_d = depth;
float sum = 0.f;
#if defined(MACE_ENABLE_NEON)
index_t block_d = 0;
float32x4_t c0, c1;
c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
// d: 8
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.8.1
float32x4_t a0, a1;
float32x4_t b0, b1;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
c0 = vmlaq_f32(c0, a0, b0);
c1 = vmlaq_f32(c1, a1, b1);
lhs_ptr += 8;
rhs_ptr += 8;
}
block_d = remain_d >> 2;
remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.4.1
float32x4_t a0;
float32x4_t b0;
a0 = vld1q_f32(lhs_ptr);
b0 = vld1q_f32(rhs_ptr);
c0 = vmlaq_f32(c0, a0, b0);
lhs_ptr += 4;
rhs_ptr += 4;
}
sum += vaddvq_f32(c0);
sum += vaddvq_f32(c1);
#endif // MACE_ENABLE_NEON
// d: remain
for (index_t d = 0; d < remain_d; ++d) {
// 1.1.1
sum += lhs_ptr[0] * rhs_ptr[0];
lhs_ptr += 1;
rhs_ptr += 1;
}
*res_ptr = sum;
++res_ptr;
} // bh: remain
} // bw
}
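// Packed layouts produced by PackPerBatch below and consumed by RunPerBatch
// and UnPackPerBatch: the lhs is packed column-major within row panels of 8
// (AArch64) or 4 rows, i.e. for every depth index the rows of a panel are
// stored contiguously; the rhs and the result are packed row-major within
// column panels of 4, i.e. for every depth (or row) index the 4 columns of a
// panel are stored contiguously.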
void SGemm::PackLhs(const SGemmMatrixMap<const float> &lhs,
PackedBlock *packed_block) {
Pack(lhs, PackOrder::SGemmColMajor, packed_block);
}
void SGemm::PackRhs(const SGemmMatrixMap<const float> &rhs,
PackedBlock *packed_block) {
Pack(rhs, PackOrder::SGemmRowMajor, packed_block);
}
void SGemm::Pack(const SGemmMatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block) {
MACE_CHECK_NOTNULL(packed_block);
const index_t height = src.row();
const index_t width = src.col();
auto packed_data = packed_block->mutable_data<float>();
#define MACE_SGEMM_PACK_PER_BATCH \
for (index_t b = 0; b < src.batch(); ++b) { \
PackPerBatch(src, order, b, packed_data + b * height * width); \
}
if (src.batch() >= MaceOpenMPThreadCount) {
#pragma omp parallel for schedule(runtime)
MACE_SGEMM_PACK_PER_BATCH
} else {
MACE_SGEMM_PACK_PER_BATCH
}
#undef MACE_SGEMM_PACK_PER_BATCH
}
void SGemm::UnPack(const PackedBlock &packed_result,
SGemmMatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
const index_t width = matrix_map->col();
auto packed_data = packed_result.data<float>();
#define MACE_SGEMM_UNPACK_PER_BATCH \
for (index_t b = 0; b < matrix_map->batch(); ++b) { \
UnPackPerBatch(packed_data + b * height * width, b, matrix_map); \
}
if (matrix_map->batch() >= MaceOpenMPThreadCount) {
#pragma omp parallel for schedule(runtime)
MACE_SGEMM_UNPACK_PER_BATCH
} else {
MACE_SGEMM_UNPACK_PER_BATCH
}
#undef MACE_SGEMM_UNPACK_PER_BATCH
}
void SGemm::PackPerBatch(const SGemmMatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data) {
MACE_CHECK_NOTNULL(packed_data);
const index_t height = src.row();
const index_t width = src.col();
auto src_data = src.batch_data(batch_index);
if (src.map_major() == Major::SGemmRowMajor
&& order == PackOrder::SGemmColMajor) {
// This is for packing no-transpose lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#pragma omp parallel for schedule(runtime)
for (index_t ih = h; ih <= height - 8; ih += 8) {
const float *src_data_ptr = src_data + ih * width;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w;
const index_t packed_offset = w * 8;
float32x4_t vs0 = {src_data_ptr[src_offset],
src_data_ptr[src_offset + width],
src_data_ptr[src_offset + 2 * width],
src_data_ptr[src_offset + 3 * width]};
float32x4_t vs1 = {src_data_ptr[src_offset + 4 * width],
src_data_ptr[src_offset + 5 * width],
src_data_ptr[src_offset + 6 * width],
src_data_ptr[src_offset + 7 * width]};
vst1q_f32(packed_data_ptr + packed_offset, vs0);
vst1q_f32(packed_data_ptr + packed_offset + 4, vs1);
}
}
h += (height - h) / 8 * 8;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t ih = h; ih <= height - 4; ih += 4) {
const float *src_data_ptr = src_data + ih * width;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w;
const index_t packed_offset = w * 4;
float32x4_t vs = {src_data_ptr[src_offset],
src_data_ptr[src_offset + width],
src_data_ptr[src_offset + 2 * width],
src_data_ptr[src_offset + 3 * width]};
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
h += (height - h) / 4 * 4;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t ih = h; ih < height; ++ih) {
std::copy_n(src_data + ih * width, width, packed_data + ih * width);
}
} else if (src.map_major() == Major::SGemmColMajor &&
order == PackOrder::SGemmColMajor) {
// This is for packing transpose-needed lhs.
index_t h = 0;
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#pragma omp parallel for schedule(runtime)
for (index_t ih = h; ih <= height - 8; ih += 8) {
const float *src_data_ptr = src_data + ih;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w * height;
const index_t packed_offset = w * 8;
float32x4_t vs0 = vld1q_f32(src_data_ptr + src_offset);
float32x4_t vs1 = vld1q_f32(src_data_ptr + src_offset + 4);
vst1q_f32(packed_data_ptr + packed_offset, vs0);
vst1q_f32(packed_data_ptr + packed_offset + 4, vs1);
}
}
h += (height - h) / 8 * 8;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t ih = h; ih <= height - 4; ih += 4) {
const float *src_data_ptr = src_data + ih;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
const index_t src_offset = w * height;
const index_t packed_offset = w * 4;
float32x4_t vs = vld1q_f32(src_data_ptr + src_offset);
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
h += (height - h) / 4 * 4;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t ih = h; ih < height; ++ih) {
const float *src_data_ptr = src_data + ih;
float *packed_data_ptr = packed_data + ih * width;
for (index_t w = 0; w < width; ++w) {
packed_data_ptr[w] = src_data_ptr[w * height];
}
}
} else if (src.map_major() == Major::SGemmRowMajor &&
order == PackOrder::SGemmRowMajor) {
// This is for packing no-transpose rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *src_data_ptr = src_data + iw;
float *packed_data_ptr = packed_data + iw * height;
for (index_t h = 0; h < height; ++h) {
const index_t src_offset = h * width;
const index_t packed_offset = h * 4;
float32x4_t vs = vld1q_f32(src_data_ptr + src_offset);
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw < width; ++iw) {
const float *src_data_ptr = src_data + iw;
float *packed_data_ptr = packed_data + iw * height;
for (index_t h = 0; h < height; ++h) {
packed_data_ptr[h] = src_data_ptr[h * width];
}
}
} else if (src.map_major() == Major::SGemmColMajor &&
order == PackOrder::SGemmRowMajor) {
// This is for packing transpose-needed rhs.
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *src_data_ptr = src_data + iw * height;
float *packed_data_ptr = packed_data + iw * height;
for (index_t h = 0; h < height; ++h) {
const index_t src_offset = h;
const index_t packed_offset = h * 4;
float32x4_t vs = {src_data_ptr[src_offset],
src_data_ptr[src_offset + height],
src_data_ptr[src_offset + 2 * height],
src_data_ptr[src_offset + 3 * height]};
vst1q_f32(packed_data_ptr + packed_offset, vs);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw < width; ++iw) {
std::copy_n(src_data + iw * height, height, packed_data + iw * height);
}
}
}
void SGemm::UnPackPerBatch(const float *packed_data,
const index_t batch_index,
SGemmMatrixMap<float> *matrix_map) {
MACE_CHECK_NOTNULL(matrix_map);
const index_t height = matrix_map->row();
const index_t width = matrix_map->col();
auto unpacked_data = matrix_map->batch_data(batch_index);
if (matrix_map->map_major() == Major::SGemmRowMajor) {
// This is for non-transposed result
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw;
for (index_t h = 0; h < height; ++h) {
const index_t packed_offset = h * 4;
const index_t unpacked_offset = h * width;
float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
vst1q_f32(unpacked_data_ptr + unpacked_offset, vs);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw < width; ++iw) {
const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw;
for (index_t h = 0; h < height; ++h) {
unpacked_data_ptr[h * width] = packed_data_ptr[h];
}
}
} else {
// This is for transposed result
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw * height;
for (index_t h = 0; h < height; ++h) {
const index_t packed_offset = h * 4;
const index_t unpacked_offset = h;
float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
unpacked_data_ptr[unpacked_offset] = vgetq_lane_f32(vs, 0);
unpacked_data_ptr[unpacked_offset + height] = vgetq_lane_f32(vs, 1);
unpacked_data_ptr[unpacked_offset + 2 * height] = vgetq_lane_f32(vs, 2);
unpacked_data_ptr[unpacked_offset + 3 * height] = vgetq_lane_f32(vs, 3);
}
}
w += (width - w) / 4 * 4;
#endif
#pragma omp parallel for schedule(runtime)
for (index_t iw = w; iw < width; ++iw) {
std::copy_n(
packed_data + iw * height, height, unpacked_data + iw * height);
}
}
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This implementation is deprecated. Use mace/ops/arm/fp32/gemm.h instead.
#ifndef MACE_OPS_SGEMM_H_
#define MACE_OPS_SGEMM_H_
#include <memory>
#include <utility>
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/core/allocator.h"
#include "mace/core/tensor.h"
namespace mace {
namespace ops {
enum Major {
SGemmRowMajor,
SGemmColMajor
};
template<typename T>
class SGemmMatrixMap {
public:
SGemmMatrixMap() {}
SGemmMatrixMap(const index_t batch,
const index_t row,
const index_t col,
const Major major,
T *data,
const bool is_const = false) :
batch_(batch),
row_(row),
col_(col),
stride_(major == SGemmRowMajor ? col : row),
major_(major),
data_(data),
is_const_(is_const) {}
SGemmMatrixMap transpose() const {
Major transpose_major =
major_ == SGemmRowMajor ? SGemmColMajor : SGemmRowMajor;
return SGemmMatrixMap(batch_,
col_,
row_,
transpose_major,
data_,
is_const_);
}
index_t batch() const {
return batch_;
}
index_t row() const {
return row_;
}
index_t col() const {
return col_;
}
index_t stride() const {
return stride_;
}
Major map_major() const {
return major_;
}
T *data() const {
return data_;
}
T *batch_data(index_t batch) const {
return data_ + batch * row_ * col_;
}
index_t size() const {
return batch_ * row_ * col_;
}
bool is_const() const {
return is_const_;
}
private:
index_t batch_;
index_t row_;
index_t col_;
index_t stride_;
Major major_;
T *data_;
bool is_const_;
};
typedef Major PackOrder;
typedef Tensor PackedBlock;
class SGemm {
public:
SGemm()
: packed_lhs_(nullptr),
packed_rhs_(nullptr),
packed_(false) {}
void operator()(const SGemmMatrixMap<const float> &lhs,
const SGemmMatrixMap<const float> &rhs,
SGemmMatrixMap<float> *result,
ScratchBuffer *scratch_buffer = nullptr);
void Run(const float *A,
const float *B,
const index_t batch,
const index_t height_a,
const index_t width_a,
const index_t height_b,
const index_t width_b,
const bool transpose_a,
const bool transpose_b,
const bool is_a_weight,
const bool is_b_weight,
float *C,
ScratchBuffer *scratch_buffer = nullptr);
void PackLhs(const SGemmMatrixMap<const float> &lhs,
PackedBlock *packed_block);
void PackRhs(const SGemmMatrixMap<const float> &rhs,
PackedBlock *packed_block);
void UnPack(const PackedBlock &packed_result,
SGemmMatrixMap<float> *matrix_map);
private:
void Pack(const SGemmMatrixMap<const float> &src,
const PackOrder order,
PackedBlock *packed_block);
void PackPerBatch(const SGemmMatrixMap<const float> &src,
const PackOrder order,
const index_t batch_index,
float *packed_data);
void UnPackPerBatch(const float *packed_data,
const index_t batch_index,
SGemmMatrixMap<float> *matrix_map);
void RunInternal(const PackedBlock &lhs,
const PackedBlock &rhs,
const index_t batch,
const index_t height,
const index_t depth,
const index_t width,
PackedBlock *result);
void RunPerBatch(const float *lhs,
const float *rhs,
const index_t height,
const index_t depth,
const index_t width,
float *result);
std::unique_ptr<Tensor> packed_lhs_;
std::unique_ptr<Tensor> packed_rhs_;
std::unique_ptr<Tensor> packed_result_;
bool packed_;
};
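// A minimal usage sketch (illustrative only; M, K, N and the A/B/C buffers are
// hypothetical, not part of this header): computing C = A * B for row-major
// float matrices.
//
//   SGemm sgemm;
//   SGemmMatrixMap<const float> a(1, M, K, SGemmRowMajor, A);
//   SGemmMatrixMap<const float> b(1, K, N, SGemmRowMajor, B);
//   SGemmMatrixMap<float> c(1, M, N, SGemmRowMajor, C);
//   sgemm(a, b, &c);  // the ScratchBuffer argument is optional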
} // namespace ops
} // namespace mace
#endif // MACE_OPS_SGEMM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <random>
#include <vector>
#include "mace/ops/sgemm.h"
namespace mace {
namespace ops {
namespace test {
namespace {
void TestPack(const std::vector<float> &data,
const std::vector<float> &expected_data,
const index_t height,
const index_t width,
Major src_order,
PackOrder pack_order) {
SGemm sg;
SGemmMatrixMap<const float>
src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
if (pack_order == PackOrder::SGemmColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
}
auto packed_data = packed.data<float>();
for (index_t i = 0; i < packed.size(); ++i) {
EXPECT_EQ(expected_data[i], packed_data[i]);
}
}
void TestUnPack(const index_t height,
const index_t width,
Major src_order,
PackOrder pack_order) {
static auto seed = static_cast<unsigned int>(time(nullptr));
const index_t matrix_size = height * width;
std::vector<float> data(matrix_size);
for (int i = 0; i < matrix_size; ++i) {
data[i] = rand_r(&seed);
}
SGemmMatrixMap<const float>
src_matrix(1, height, width, src_order, data.data());
PackedBlock packed;
packed.Resize({height, width});
SGemm sg;
if (pack_order == PackOrder::SGemmColMajor) {
sg.PackLhs(src_matrix, &packed);
} else {
sg.PackRhs(src_matrix, &packed);
}
std::vector<float> unpacked(matrix_size);
SGemmMatrixMap<float>
unpacked_matrix(1, height, width, src_order, unpacked.data());
sg.UnPack(packed, &unpacked_matrix);
auto unpacked_data = unpacked.data();
for (index_t i = 0; i < packed.size(); ++i) {
EXPECT_EQ(data[i], unpacked_data[i]);
}
}
} // namespace
TEST(SGemmPackTest, Pack) {
std::vector<float> data =
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36};
// For no-transpose lhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
5, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11,
15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32, 33, 34, 35, 36},
9, 4, Major::SGemmRowMajor, PackOrder::SGemmColMajor);
#endif
#endif
// For transpose-needed lhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
3, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
5, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#if defined(__aarch64__)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21,
22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 9, 18, 27, 36},
9, 4, Major::SGemmColMajor, PackOrder::SGemmColMajor);
#endif
#endif
// For no-transpose rhs
TestPack(data,
{1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestPack(data,
{1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15,
20},
4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
#endif
// For transpose-needed rhs
TestPack(data,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
#if defined(MACE_ENABLE_NEON)
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestPack(data,
{1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19,
20},
4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
#endif
}
TEST(SGemmPackTest, UnPack) {
TestUnPack(4, 3, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 4, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 5, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 100, Major::SGemmRowMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 3, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 4, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 5, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
TestUnPack(4, 100, Major::SGemmColMajor, PackOrder::SGemmRowMajor);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -124,12 +124,11 @@ TEST_F(MaceAPITest, MultipleInputOutput) {
}
TEST_F(MaceAPITest, VariableInputShape) {
// TODO(liyin): there is a bug of cpu convolution
// MaceRun<CPU, float>(1,
// {1, 32, 64, 16},
// {{1, 16, 32, 16}, {1, 32, 64, 16}},
// {{1, 16, 32, 16}, {1, 32, 64, 16}},
// {16, 16, 3, 3});
MaceRun<CPU, float>(1,
{1, 32, 64, 16},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{16, 16, 3, 3});
MaceRun<GPU, float>(1,
{1, 32, 64, 16},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
......