提交 961dddd9 编写于 作者: 李寅

Replace gemm with sgemm

上级 9e79fef3
......@@ -469,6 +469,7 @@ class ScratchBuffer: public Buffer {
MaceStatus GrowSize(index_t size) {
if (size > size_) {
MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size");
return Resize(size);
}
return MaceStatus::MACE_SUCCESS;
......@@ -487,8 +488,12 @@ class ScratchBuffer: public Buffer {
return slice;
}
void Rewind() {
offset_ = 0;
// Reset the scratch cursor to a given position (default: the start).
// Passing a saved offset() value lets callers reuse the same scratch
// region repeatedly (e.g. once per batch) without growing the buffer.
void Rewind(index_t offset = 0) {
offset_ = offset;
}
// Current allocation cursor within the scratch buffer; save this before
// carving slices so the position can be restored later via Rewind(offset).
index_t offset() const {
return offset_;
}
private:
......
......@@ -16,6 +16,7 @@
#define MACE_KERNELS_ARM_CONV_2D_NEON_H_
#include "mace/core/types.h"
#include "mace/kernels/sgemm.h"
namespace mace {
namespace kernels {
......@@ -27,7 +28,9 @@ void Conv2dNeonK1x1S1(const float *input,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output);
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer);
void Conv2dNeonK3x3S1(const float *input,
const float *filter,
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "mace/kernels/arm/conv_2d_neon.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
......@@ -25,11 +24,23 @@ void Conv2dNeonK1x1S1(const float *input,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output) {
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
for (index_t b = 0; b < batch; ++b) {
Gemm(filter, input + b * in_channels * height * width, 1, out_channels,
in_channels, height * width,
output + b * out_channels * height * width);
sgemm->Run(filter,
input + b * in_channels * height * width,
1,
out_channels,
in_channels,
in_channels,
height * width,
false,
false,
true,
false,
output + b * out_channels * height * width,
scratch_buffer);
}
}
......
......@@ -12,13 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/logging.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
......@@ -247,30 +244,38 @@ void BatchGemm(const float *input,
index_t out_channels,
index_t tile_count,
int out_tile_size,
float *output) {
const index_t filter_stride = out_channels * in_channels;
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
const index_t in_batch_size = in_tile_area * in_channels * tile_count;
const index_t in_stride = in_channels * tile_count;
const index_t out_batch_size = in_tile_area * out_channels * tile_count;
const index_t out_stride = out_channels * tile_count;
if (batch == 1) {
Gemm(filter, input, in_tile_area, out_channels, in_channels, tile_count,
output);
} else {
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int i = 0; i < in_tile_area; ++i) {
const float *in_ptr = input + b * in_batch_size + i * in_stride;
const float *filter_ptr = filter + i * filter_stride;
float *out_ptr = output + b * out_batch_size + i * out_stride;
Gemm(filter_ptr, in_ptr, 1, out_channels, /* rows */
in_channels, /* K */
tile_count, /* cols */
out_ptr);
}
index_t scratch_buffer_offset = 0;
if (scratch_buffer) {
scratch_buffer_offset = scratch_buffer->offset();
}
// 'batch' is not the gemm batch dimension; 'in_tile_area' is. The gemm is
// not thread safe, so we loop over 'batch' on a single thread.
// The scratch buffer must be rewound to its initial position so that each
// batch reuses the same scratch memory.
for (int b = 0; b < batch; ++b) {
if (scratch_buffer) {
scratch_buffer->Rewind(scratch_buffer_offset);
}
sgemm->Run(filter,
input + b * in_batch_size,
in_tile_area,
out_channels,
in_channels,
in_channels,
tile_count,
false,
false,
true,
false,
output + b * out_batch_size,
scratch_buffer);
}
}
......@@ -613,7 +618,9 @@ void WinoGradConv3x3s1(const float *input,
const int out_tile_size,
float *transformed_input,
float *transformed_output,
float *output) {
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count =
......@@ -636,7 +643,8 @@ void WinoGradConv3x3s1(const float *input,
}
BatchGemm(transformed_input, transformed_filter, batch, in_channels,
out_channels, tile_count, out_tile_size, transformed_output);
out_channels, tile_count, out_tile_size, transformed_output,
sgemm, scratch_buffer);
switch (out_tile_size) {
case 2:
......@@ -660,7 +668,9 @@ void WinoGradConv3x3s1(const float *input,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *output) {
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count =
......@@ -692,7 +702,7 @@ void WinoGradConv3x3s1(const float *input,
WinoGradConv3x3s1(input, transformed_filter, batch, in_height, in_width,
in_channels, out_channels, out_tile_size, transformed_input,
transformed_output, output);
transformed_output, output, sgemm, scratch_buffer);
delete[] transformed_input;
delete[] transformed_filter;
......
......@@ -20,6 +20,7 @@
#endif
#include "mace/core/types.h"
#include "mace/kernels/sgemm.h"
namespace mace {
namespace kernels {
......@@ -42,7 +43,9 @@ void WinoGradConv3x3s1(const float *input,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *output);
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer);
void WinoGradConv3x3s1(const float *input,
const float *transformed_filter,
......@@ -54,7 +57,9 @@ void WinoGradConv3x3s1(const float *input,
const int out_tile_size,
float *transformed_input,
float *transformed_output,
float *output);
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer);
void ConvRef3x3s1(const float *input,
const float *filter,
......
......@@ -65,9 +65,10 @@ TEST(ConvWinogradTest, winograd) {
kernels::ConvRef3x3s1(input_data, filter_data, batch, in_height, in_width,
in_channels, out_channels, output_data_ref);
SGemm sgemm;
kernels::WinoGradConv3x3s1(input_data, filter_data, batch, in_height,
in_width, in_channels, out_channels, 6,
output_data);
output_data, &sgemm, nullptr);
// test
for (index_t i = 0; i < output_size; ++i) {
......
......@@ -483,6 +483,16 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
* sizeof(float);
total_scratch_size += padded_output_size;
}
// scratch for sgemm
if (use_neon_1x1_s1) {
total_scratch_size +=
(input_batch * input_height * input_width
* (input_channels + channels)) * sizeof(float);
} else if (use_winograd) {
total_scratch_size +=
(transformed_input_size + transformed_output_size) * sizeof(float);
}
// Init scratch buffer
scratch_->Rewind();
scratch_->GrowSize(total_scratch_size);
......@@ -547,7 +557,9 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
winograd_out_tile_size,
transformed_input_data,
transformed_output_data,
pad_output);
pad_output,
&sgemm_,
scratch_);
};
} else if (use_neon_3x3_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
......@@ -574,7 +586,9 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
extra_input_width,
input_channels,
channels,
pad_output);
pad_output,
&sgemm_,
scratch_);
};
} else if (use_neon_5x5_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
......@@ -722,6 +736,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
Tensor transformed_filter_;
bool is_filter_transformed_;
ScratchBuffer *scratch_;
SGemm sgemm_;
};
template<>
......
......@@ -89,6 +89,9 @@ struct MatMulFunctor : OpKernel {
const index_t height_b = B->dim(rank - 2);
const index_t width_b = B->dim(rank - 1);
auto scratch_buffer = context_->workspace()->GetScratchBuffer(D);
scratch_buffer->Rewind();
sgemm_.Run(a_ptr_base,
b_ptr_base,
batch,
......@@ -101,7 +104,7 @@ struct MatMulFunctor : OpKernel {
A->is_weight(),
B->is_weight(),
c_ptr_base,
context_->workspace()->GetScratchBuffer(D));
scratch_buffer);
return MACE_SUCCESS;
}
......
......@@ -44,7 +44,6 @@ void SGemm::operator()(const MatrixMap<const float> &lhs,
}
if (scratch_buffer != nullptr) {
scratch_buffer->Rewind();
index_t total_size = result->size();
if (!lhs.is_const()) {
total_size += lhs.size();
......@@ -54,7 +53,6 @@ void SGemm::operator()(const MatrixMap<const float> &lhs,
}
scratch_buffer->GrowSize(total_size * sizeof(float));
scratch_buffer->Rewind();
if (!lhs.is_const()) {
packed_lhs_.reset(new Tensor(scratch_buffer->Scratch(
lhs.size() * sizeof(float)), DT_FLOAT));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册