Commit 5f2001d2 authored by: 李寅

Merge branch 'conv1x1' into 'master'

Optimize conv1x1

See merge request !498
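
The optimization special-cases 1x1, stride-1 convolution (the `use_neon_1x1_s1` path in the diff below) so the padding buffer and output pre-clearing can be skipped and the data fed straight into GEMM. A minimal sketch of why that reduction works for a single image in NCHW layout; the function and parameter names are illustrative, not the MACE API:

```cpp
#include <cstddef>

// Illustrative only: a 1x1, stride-1 convolution in NCHW layout is a plain
// matrix multiply  output[m][i] = sum_c filter[m][c] * input[c][i]
// over the flattened spatial index i = h * width + w, so no im2col,
// padded scratch buffer, or output pre-clearing is needed.
void Conv1x1S1AsGemm(const float *input,   // [in_ch,  height * width]
                     const float *filter,  // [out_ch, in_ch]
                     float *output,        // [out_ch, height * width]
                     std::size_t out_ch, std::size_t in_ch,
                     std::size_t spatial) {
  for (std::size_t m = 0; m < out_ch; ++m) {
    for (std::size_t i = 0; i < spatial; ++i) {
      float sum = 0.f;
      for (std::size_t c = 0; c < in_ch; ++c) {
        sum += filter[m * in_ch + c] * input[c * spatial + i];
      }
      output[m * spatial + i] = sum;  // overwrite, not accumulate
    }
  }
}
```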
@@ -61,6 +61,8 @@ class BufferBase {
virtual void Clear() = 0;
virtual void Clear(index_t size) = 0;
virtual index_t offset() const { return 0; }
template <typename T>
@@ -198,7 +200,11 @@ class Buffer : public BufferBase {
bool OnHost() const { return allocator_->OnHost(); }
void Clear() {
memset(reinterpret_cast<char*>(raw_mutable_data()), 0, size_);
Clear(size_);
}
void Clear(index_t size) {
memset(reinterpret_cast<char*>(raw_mutable_data()), 0, size);
}
protected:
@@ -312,6 +318,11 @@ class Image : public BufferBase {
MACE_NOT_IMPLEMENTED;
}
void Clear(index_t size) {
MACE_UNUSED(size);
MACE_NOT_IMPLEMENTED;
}
private:
Allocator *allocator_;
std::vector<size_t> shape_;
@@ -431,7 +442,11 @@ class BufferSlice : public BufferBase {
bool OnHost() const { return buffer_->OnHost(); }
void Clear() {
memset(raw_mutable_data(), 0, size_);
Clear(size_);
}
void Clear(index_t size) {
memset(raw_mutable_data(), 0, size);
}
private:
......
@@ -208,7 +208,7 @@ class Tensor {
inline void Clear() {
MACE_CHECK_NOTNULL(buffer_);
buffer_->Clear();
buffer_->Clear(raw_size());
}
inline void Reshape(const std::vector<index_t> &shape) {
......
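
The new `Clear(index_t size)` overload lets a caller zero only the bytes it logically uses; `Tensor::Clear()` above now passes `raw_size()` instead of wiping the buffer's whole capacity. A rough sketch of the pattern with a simplified stand-in type (not the real MACE classes):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

using index_t = int64_t;

// Simplified stand-in: Clear() forwards to the sized overload, so callers
// that know their logical size avoid touching unused capacity.
struct SimpleBuffer {
  std::vector<char> storage;

  void Clear() { Clear(static_cast<index_t>(storage.size())); }
  void Clear(index_t size) {
    std::memset(storage.data(), 0, static_cast<std::size_t>(size));
  }
};

int main() {
  SimpleBuffer buf{std::vector<char>(1024, 1)};
  buf.Clear(256);  // analogous to Tensor::Clear() passing raw_size()
  return 0;
}
```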
@@ -297,7 +297,6 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
@@ -415,7 +414,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else {
} else if (!use_neon_1x1_s1) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * stride_h
@@ -602,7 +601,6 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
const Tensor *pad_input_ptr = input;
if (extra_input_height != input_height
|| extra_input_width != input_width) {
padded_input.Clear();
ConstructNCHWInputWithSpecificPadding(input,
pad_top,
pad_bottom,
@@ -612,13 +610,17 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
pad_input_ptr = &padded_input;
}
// TODO(libin): don't need clear after bias is integrated in each conv
Tensor *pad_output_ptr = output;
if (extra_output_height != height || extra_output_width != width) {
padded_output.Reshape({batch, channels, extra_output_height,
extra_output_width});
padded_output.Clear();
pad_output_ptr = &padded_output;
} else if (!use_neon_1x1_s1) {
output->Clear();
}
const float *pad_input_data = pad_input_ptr->data<float>();
float *pad_output_data = pad_output_ptr->mutable_data<float>();
......
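
The branch above zeroes an output buffer only when the following kernel accumulates into it; the NEON 1x1 stride-1 path overwrites every element, so the clear is skipped. A small self-contained illustration of that decision (the helper name and `main` are hypothetical):

```cpp
#include <cstdio>

// Hypothetical helper, not MACE code: decide whether the output buffer must
// be zeroed before the convolution kernel runs. Kernels that accumulate
// partial sums need a cleared buffer; the 1x1 stride-1 NEON path writes
// every element directly, so the memset can be dropped.
bool NeedsOutputClear(bool uses_padded_scratch_output, bool use_neon_1x1_s1) {
  if (uses_padded_scratch_output) {
    return true;  // freshly reshaped scratch output must start at zero
  }
  return !use_neon_1x1_s1;
}

int main() {
  std::printf("1x1 s1, exact output: clear=%d\n", NeedsOutputClear(false, true));
  std::printf("general kernel:       clear=%d\n", NeedsOutputClear(false, false));
  return 0;
}
```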
@@ -377,6 +377,7 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
std::vector<index_t> output_shape(
{batch, channels, height + pad_height, width + pad_width});
output_tensor->Resize(output_shape);
output_tensor->Clear();
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
......
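
Adding `output_tensor->Clear()` right after `Resize()` guarantees the padded border is zero before the interior rows are copied in. A minimal sketch of the same idea on one flat, row-major float plane; this is not the MACE helper itself:

```cpp
#include <cstring>

// Illustrative only: zero the whole padded plane first (the Clear() above),
// then copy each source row into the interior, leaving the border at zero.
void PadPlaneWithZeros(const float *src, int height, int width,
                       int pad_top, int pad_left,
                       float *dst, int padded_h, int padded_w) {
  std::memset(dst, 0, sizeof(float) * padded_h * padded_w);
  for (int h = 0; h < height; ++h) {
    std::memcpy(dst + (h + pad_top) * padded_w + pad_left,
                src + h * width,
                sizeof(float) * width);
  }
}
```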
@@ -140,8 +140,8 @@ inline void GemmTile(const float *A,
#endif
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
for (h = 0; h + 7 < height; h += 8) {
for (k = 0; k + 7 < K; k += 8) {
for (h = 0; h < height - 7; h += 8) {
for (k = 0; k < K - 7; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
#ifdef __clang__
int nw = width >> 2;
@@ -185,153 +185,147 @@ inline void GemmTile(const float *A,
float *c_ptr7 = C + (h + 7) * stride_w;
asm volatile(
"0: \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v24.4s}, [%1] \n"
// load b: 0-7
"prfm pldl1keep, [%9, #128] \n"
"ld1 {v16.4s}, [%9], #16 \n"
"prfm pldl1keep, [%10, #128] \n"
"ld1 {v17.4s}, [%10], #16 \n"
"prfm pldl1keep, [%11, #128] \n"
"ld1 {v18.4s}, [%11], #16 \n"
"prfm pldl1keep, [%12, #128] \n"
"ld1 {v19.4s}, [%12], #16 \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v18.4s}, [%1] \n"
"prfm pldl1keep, [%2, #128] \n"
"ld1 {v25.4s}, [%2] \n"
"prfm pldl1keep, [%13, #128] \n"
"ld1 {v20.4s}, [%13], #16 \n"
"prfm pldl1keep, [%14, #128] \n"
"ld1 {v21.4s}, [%14], #16 \n"
"ld1 {v19.4s}, [%2] \n"
"prfm pldl1keep, [%15, #128] \n"
"ld1 {v22.4s}, [%15], #16 \n"
"prfm pldl1keep, [%16, #128] \n"
"ld1 {v23.4s}, [%16], #16 \n"
"0: \n"
"prfm pldl1keep, [%3, #128] \n"
"ld1 {v26.4s}, [%3] \n"
"fmla v24.4s, v16.4s, %34.s[0] \n"
"fmla v24.4s, v17.4s, %34.s[1] \n"
"fmla v24.4s, v18.4s, %34.s[2] \n"
"fmla v24.4s, v19.4s, %34.s[3] \n"
"fmla v24.4s, v20.4s, %35.s[0] \n"
"fmla v24.4s, v21.4s, %35.s[1] \n"
"fmla v24.4s, v22.4s, %35.s[2] \n"
"fmla v24.4s, v23.4s, %35.s[3] \n"
"st1 {v24.4s}, [%1], #16 \n"
"fmla v25.4s, v16.4s, %36.s[0] \n"
"fmla v25.4s, v17.4s, %36.s[1] \n"
"fmla v25.4s, v18.4s, %36.s[2] \n"
"fmla v25.4s, v19.4s, %36.s[3] \n"
"fmla v25.4s, v20.4s, %37.s[0] \n"
"fmla v25.4s, v21.4s, %37.s[1] \n"
"fmla v25.4s, v22.4s, %37.s[2] \n"
"fmla v25.4s, v23.4s, %37.s[3] \n"
"ld1 {v20.4s}, [%3] \n"
"prfm pldl1keep, [%4, #128] \n"
"ld1 {v24.4s}, [%4] \n"
"st1 {v25.4s}, [%2], #16 \n"
"fmla v26.4s, v16.4s, %38.s[0] \n"
"fmla v26.4s, v17.4s, %38.s[1] \n"
"fmla v26.4s, v18.4s, %38.s[2] \n"
"fmla v26.4s, v19.4s, %38.s[3] \n"
"fmla v26.4s, v20.4s, %39.s[0] \n"
"fmla v26.4s, v21.4s, %39.s[1] \n"
"fmla v26.4s, v22.4s, %39.s[2] \n"
"fmla v26.4s, v23.4s, %39.s[3] \n"
"ld1 {v21.4s}, [%4] \n"
"prfm pldl1keep, [%5, #128] \n"
"ld1 {v25.4s}, [%5] \n"
"ld1 {v22.4s}, [%5] \n"
"prfm pldl1keep, [%6, #128] \n"
"ld1 {v23.4s}, [%6] \n"
"prfm pldl1keep, [%7, #128] \n"
"ld1 {v24.4s}, [%7] \n"
"prfm pldl1keep, [%8, #128] \n"
"ld1 {v25.4s}, [%8] \n"
"prfm pldl1keep, [%10, #128] \n"
"ld1 {v17.4s}, [%10], #16 \n"
"st1 {v26.4s}, [%3], #16 \n"
"fmla v18.4s, v16.4s, %34.s[0] \n"
"fmla v19.4s, v16.4s, %35.s[0] \n"
"fmla v20.4s, v16.4s, %36.s[0] \n"
"fmla v21.4s, v16.4s, %37.s[0] \n"
"fmla v22.4s, v16.4s, %38.s[0] \n"
"fmla v23.4s, v16.4s, %39.s[0] \n"
"fmla v24.4s, v16.4s, %40.s[0] \n"
"fmla v24.4s, v17.4s, %40.s[1] \n"
"fmla v24.4s, v18.4s, %40.s[2] \n"
"fmla v24.4s, v19.4s, %40.s[3] \n"
"fmla v25.4s, v16.4s, %41.s[0] \n"
"fmla v24.4s, v20.4s, %41.s[0] \n"
"fmla v24.4s, v21.4s, %41.s[1] \n"
"fmla v24.4s, v22.4s, %41.s[2] \n"
"fmla v24.4s, v23.4s, %41.s[3] \n"
"fmla v18.4s, v17.4s, %34.s[1] \n"
"fmla v19.4s, v17.4s, %35.s[1] \n"
"fmla v20.4s, v17.4s, %36.s[1] \n"
"fmla v21.4s, v17.4s, %37.s[1] \n"
"prfm pldl1keep, [%6, #128] \n"
"ld1 {v26.4s}, [%6] \n"
"prfm pldl1keep, [%11, #128] \n"
"ld1 {v16.4s}, [%11], #16 \n"
"st1 {v24.4s}, [%4], #16 \n"
"fmla v22.4s, v17.4s, %38.s[1] \n"
"fmla v23.4s, v17.4s, %39.s[1] \n"
"fmla v24.4s, v17.4s, %40.s[1] \n"
"fmla v25.4s, v17.4s, %41.s[1] \n"
"fmla v25.4s, v16.4s, %42.s[0] \n"
"fmla v25.4s, v17.4s, %42.s[1] \n"
"fmla v25.4s, v18.4s, %42.s[2] \n"
"fmla v25.4s, v19.4s, %42.s[3] \n"
"fmla v18.4s, v16.4s, %34.s[2] \n"
"fmla v19.4s, v16.4s, %35.s[2] \n"
"fmla v20.4s, v16.4s, %36.s[2] \n"
"fmla v21.4s, v16.4s, %37.s[2] \n"
"fmla v25.4s, v20.4s, %43.s[0] \n"
"fmla v25.4s, v21.4s, %43.s[1] \n"
"fmla v25.4s, v22.4s, %43.s[2] \n"
"fmla v25.4s, v23.4s, %43.s[3] \n"
"prfm pldl1keep, [%12, #128] \n"
"ld1 {v17.4s}, [%12], #16 \n"
"prfm pldl1keep, [%7, #128] \n"
"ld1 {v24.4s}, [%7] \n"
"fmla v22.4s, v16.4s, %38.s[2] \n"
"fmla v23.4s, v16.4s, %39.s[2] \n"
"fmla v24.4s, v16.4s, %40.s[2] \n"
"fmla v25.4s, v16.4s, %41.s[2] \n"
"st1 {v25.4s}, [%5], #16 \n"
"fmla v18.4s, v17.4s, %34.s[3] \n"
"fmla v19.4s, v17.4s, %35.s[3] \n"
"fmla v20.4s, v17.4s, %36.s[3] \n"
"fmla v21.4s, v17.4s, %37.s[3] \n"
"fmla v26.4s, v16.4s, %44.s[0] \n"
"fmla v26.4s, v17.4s, %44.s[1] \n"
"fmla v26.4s, v18.4s, %44.s[2] \n"
"fmla v26.4s, v19.4s, %44.s[3] \n"
"prfm pldl1keep, [%13, #128] \n"
"ld1 {v16.4s}, [%13], #16 \n"
"fmla v26.4s, v20.4s, %45.s[0] \n"
"fmla v26.4s, v21.4s, %45.s[1] \n"
"fmla v26.4s, v22.4s, %45.s[2] \n"
"fmla v26.4s, v23.4s, %45.s[3] \n"
"fmla v22.4s, v17.4s, %38.s[3] \n"
"fmla v23.4s, v17.4s, %39.s[3] \n"
"fmla v24.4s, v17.4s, %40.s[3] \n"
"fmla v25.4s, v17.4s, %41.s[3] \n"
"prfm pldl1keep, [%8, #128] \n"
"ld1 {v25.4s}, [%8] \n"
"fmla v18.4s, v16.4s, %42.s[0] \n"
"fmla v19.4s, v16.4s, %43.s[0] \n"
"fmla v20.4s, v16.4s, %44.s[0] \n"
"fmla v21.4s, v16.4s, %45.s[0] \n"
"st1 {v26.4s}, [%6], #16 \n"
"prfm pldl1keep, [%14, #128] \n"
"ld1 {v17.4s}, [%14], #16 \n"
"fmla v24.4s, v16.4s, %46.s[0] \n"
"fmla v24.4s, v17.4s, %46.s[1] \n"
"fmla v24.4s, v18.4s, %46.s[2] \n"
"fmla v24.4s, v19.4s, %46.s[3] \n"
"fmla v22.4s, v16.4s, %46.s[0] \n"
"fmla v23.4s, v16.4s, %47.s[0] \n"
"fmla v24.4s, v16.4s, %48.s[0] \n"
"fmla v25.4s, v16.4s, %49.s[0] \n"
"fmla v24.4s, v20.4s, %47.s[0] \n"
"fmla v24.4s, v21.4s, %47.s[1] \n"
"fmla v24.4s, v22.4s, %47.s[2] \n"
"fmla v24.4s, v23.4s, %47.s[3] \n"
"fmla v18.4s, v17.4s, %42.s[1] \n"
"fmla v19.4s, v17.4s, %43.s[1] \n"
"fmla v20.4s, v17.4s, %44.s[1] \n"
"fmla v21.4s, v17.4s, %45.s[1] \n"
"st1 {v24.4s}, [%7], #16 \n"
"prfm pldl1keep, [%15, #128] \n"
"ld1 {v16.4s}, [%15], #16 \n"
"fmla v25.4s, v16.4s, %48.s[0] \n"
"fmla v25.4s, v17.4s, %48.s[1] \n"
"fmla v25.4s, v18.4s, %48.s[2] \n"
"fmla v25.4s, v19.4s, %48.s[3] \n"
"fmla v22.4s, v17.4s, %46.s[1] \n"
"fmla v23.4s, v17.4s, %47.s[1] \n"
"fmla v24.4s, v17.4s, %48.s[1] \n"
"fmla v25.4s, v17.4s, %49.s[1] \n"
"fmla v25.4s, v20.4s, %49.s[0] \n"
"fmla v25.4s, v21.4s, %49.s[1] \n"
"fmla v25.4s, v22.4s, %49.s[2] \n"
"fmla v25.4s, v23.4s, %49.s[3] \n"
"fmla v18.4s, v16.4s, %42.s[2] \n"
"fmla v19.4s, v16.4s, %43.s[2] \n"
"fmla v20.4s, v16.4s, %44.s[2] \n"
"fmla v21.4s, v16.4s, %45.s[2] \n"
"prfm pldl1keep, [%16, #128] \n"
"ld1 {v17.4s}, [%16], #16 \n"
"fmla v22.4s, v16.4s, %46.s[2] \n"
"fmla v23.4s, v16.4s, %47.s[2] \n"
"fmla v24.4s, v16.4s, %48.s[2] \n"
"fmla v25.4s, v16.4s, %49.s[2] \n"
"fmla v18.4s, v17.4s, %42.s[3] \n"
"fmla v19.4s, v17.4s, %43.s[3] \n"
"fmla v20.4s, v17.4s, %44.s[3] \n"
"fmla v21.4s, v17.4s, %45.s[3] \n"
"st1 {v18.4s}, [%1], #16 \n"
"st1 {v19.4s}, [%2], #16 \n"
"st1 {v20.4s}, [%3], #16 \n"
"st1 {v21.4s}, [%4], #16 \n"
"fmla v22.4s, v17.4s, %46.s[3] \n"
"fmla v23.4s, v17.4s, %47.s[3] \n"
"fmla v24.4s, v17.4s, %48.s[3] \n"
"fmla v25.4s, v17.4s, %49.s[3] \n"
"st1 {v22.4s}, [%5], #16 \n"
"st1 {v23.4s}, [%6], #16 \n"
"st1 {v24.4s}, [%7], #16 \n"
"st1 {v25.4s}, [%8], #16 \n"
"prfm pldl1keep, [%9, #128] \n"
"ld1 {v16.4s}, [%9], #16 \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v18.4s}, [%1] \n"
"prfm pldl1keep, [%2, #128] \n"
"ld1 {v19.4s}, [%2] \n"
"subs %w0, %w0, #1 \n"
"bne 0b \n"
: "=r"(nw), // 0
@@ -369,20 +363,20 @@ inline void GemmTile(const float *A,
"15"(b_ptr6), // 32
"16"(b_ptr7), // 33
"w"(a0), // 34
"w"(a1), // 35
"w"(a2), // 36
"w"(a3), // 37
"w"(a4), // 38
"w"(a5), // 39
"w"(a6), // 40
"w"(a7), // 41
"w"(a8), // 42
"w"(a9), // 43
"w"(a10), // 44
"w"(a11), // 45
"w"(a12), // 46
"w"(a13), // 47
"w"(a14), // 48
"w"(a2), // 35
"w"(a4), // 36
"w"(a6), // 37
"w"(a8), // 38
"w"(a10), // 39
"w"(a12), // 40
"w"(a14), // 41
"w"(a1), // 42
"w"(a3), // 43
"w"(a5), // 44
"w"(a7), // 45
"w"(a9), // 46
"w"(a11), // 47
"w"(a13), // 48
"w"(a15) // 49
: "cc", "memory",
"v16",
@@ -585,7 +579,6 @@ void Gemm(const float *A,
}
memset(C, 0, sizeof(float) * batch * height * width);
// It is better to use large block size if it fits for fast cache.
// Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
// the block size should be sqrt(32k / sizeof(T) / 3).
......
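
The cache-blocking comment above implies a concrete tile size. A quick arithmetic check of that formula, taking the comment's 32 KB L1 figure at face value:

```cpp
#include <cmath>
#include <cstdio>

// With a 32 KB L1 cache holding three float tiles (A, B, C) at once, each
// square tile should be roughly sqrt(32768 / sizeof(float) / 3) elements
// per side, i.e. about 52.
int main() {
  const double l1_bytes = 32 * 1024;
  const double tiles = 3;
  const double elems_per_tile = l1_bytes / sizeof(float) / tiles;
  std::printf("tile side ~= %d elements\n",
              static_cast<int>(std::sqrt(elems_per_tile)));
  return 0;
}
```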