Commit 5f2001d2 authored by 李寅

Merge branch 'conv1x1' into 'master'

Optimize conv1x1

See merge request !498
@@ -61,6 +61,8 @@ class BufferBase {
   virtual void Clear() = 0;
 
+  virtual void Clear(index_t size) = 0;
 
   virtual index_t offset() const { return 0; }
 
   template <typename T>
@@ -198,7 +200,11 @@ class Buffer : public BufferBase {
   bool OnHost() const { return allocator_->OnHost(); }
 
   void Clear() {
-    memset(reinterpret_cast<char*>(raw_mutable_data()), 0, size_);
+    Clear(size_);
+  }
+
+  void Clear(index_t size) {
+    memset(reinterpret_cast<char*>(raw_mutable_data()), 0, size);
   }
 
  protected:
@@ -312,6 +318,11 @@ class Image : public BufferBase {
     MACE_NOT_IMPLEMENTED;
   }
 
+  void Clear(index_t size) {
+    MACE_UNUSED(size);
+    MACE_NOT_IMPLEMENTED;
+  }
+
  private:
   Allocator *allocator_;
   std::vector<size_t> shape_;
@@ -431,7 +442,11 @@ class BufferSlice : public BufferBase {
   bool OnHost() const { return buffer_->OnHost(); }
 
   void Clear() {
-    memset(raw_mutable_data(), 0, size_);
+    Clear(size_);
+  }
+
+  void Clear(index_t size) {
+    memset(raw_mutable_data(), 0, size);
   }
 
  private:
......
@@ -208,7 +208,7 @@ class Tensor {
   inline void Clear() {
     MACE_CHECK_NOTNULL(buffer_);
-    buffer_->Clear();
+    buffer_->Clear(raw_size());
   }
 
   inline void Reshape(const std::vector<index_t> &shape) {
......
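The effect of the new overload: Tensor::Clear() now zeroes only the raw_size() bytes the tensor actually uses rather than the whole underlying buffer. A minimal sketch of that relationship, using toy stand-ins rather than the MACE classes themselves:

#include <cstdint>
#include <cstring>
#include <vector>

using index_t = int64_t;

// Toy stand-ins for Buffer/Tensor, only to illustrate the bounded Clear(size).
struct ToyBuffer {
  std::vector<char> storage;  // capacity may be larger than any one tensor needs
  void Clear() { Clear(static_cast<index_t>(storage.size())); }
  void Clear(index_t size) {
    std::memset(storage.data(), 0, static_cast<size_t>(size));
  }
};

struct ToyTensor {
  ToyBuffer *buffer;
  index_t raw_size;  // bytes owned by this tensor
  void Clear() { buffer->Clear(raw_size); }  // mirrors buffer_->Clear(raw_size())
};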
@@ -297,7 +297,6 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
           output_shape.data());
     }
     output->Resize(output_shape);
-    output->Clear();
 
     index_t batch = output->dim(0);
     index_t channels = output->dim(1);
@@ -415,7 +414,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
       if (extra_input_width != padded_input_width) {
         pad_right += (extra_input_width - padded_input_width);
       }
-    } else {
+    } else if (!use_neon_1x1_s1) {
       extra_output_height = height;
       extra_input_height =
           std::max(padded_input_height, (extra_output_height - 1) * stride_h
@@ -602,7 +601,6 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
     const Tensor *pad_input_ptr = input;
     if (extra_input_height != input_height
         || extra_input_width != input_width) {
-      padded_input.Clear();
       ConstructNCHWInputWithSpecificPadding(input,
                                             pad_top,
                                             pad_bottom,
@@ -612,13 +610,17 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
       pad_input_ptr = &padded_input;
     }
 
+    // TODO(libin): don't need clear after bias is integrated in each conv
     Tensor *pad_output_ptr = output;
     if (extra_output_height != height || extra_output_width != width) {
       padded_output.Reshape({batch, channels, extra_output_height,
                              extra_output_width});
       padded_output.Clear();
       pad_output_ptr = &padded_output;
+    } else if (!use_neon_1x1_s1) {
+      output->Clear();
     }
 
     const float *pad_input_data = pad_input_ptr->data<float>();
     float *pad_output_data = pad_output_ptr->mutable_data<float>();
......
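For context on why the 1x1/stride-1 NEON path (use_neon_1x1_s1) can skip the padded buffers and the pre-clear: a 1x1 convolution in NCHW layout is a plain per-batch matrix multiplication, and that multiplication writes every output element it touches. A hedged scalar sketch of the mapping, with illustrative names rather than the actual MACE kernel:

// output[m][p] = sum_c filter[m][c] * input[c][p], where p indexes the
// height*width plane; the assignment below overwrites output, so no separate
// Clear() pass is required on this path.
void Conv1x1AsGemm(const float *input,   // [channels][height * width]
                   const float *filter,  // [out_channels][channels]
                   float *output,        // [out_channels][height * width]
                   int out_channels, int channels, int spatial) {
  for (int m = 0; m < out_channels; ++m) {
    for (int p = 0; p < spatial; ++p) {
      float sum = 0.f;
      for (int c = 0; c < channels; ++c) {
        sum += filter[m * channels + c] * input[c * spatial + p];
      }
      output[m * spatial + p] = sum;
    }
  }
}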
@@ -377,6 +377,7 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
   std::vector<index_t> output_shape(
       {batch, channels, height + pad_height, width + pad_width});
   output_tensor->Resize(output_shape);
+  output_tensor->Clear();
   Tensor::MappingGuard padded_output_mapper(output_tensor);
   float *output_data = output_tensor->mutable_data<float>();
......
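The Clear() added above matters because the padding helper only copies the interior region of each plane; the border bytes are never written and would otherwise keep stale data. A simplified, hypothetical sketch of that copy pattern:

#include <cstring>

// Copy one channel plane into the interior of an already-zeroed padded plane.
void CopyPlaneIntoPadded(const float *src, float *dst,
                         int height, int width,
                         int pad_top, int pad_left, int padded_width) {
  for (int h = 0; h < height; ++h) {
    std::memcpy(dst + (h + pad_top) * padded_width + pad_left,
                src + h * width,
                sizeof(float) * width);  // only the interior row is written
  }
}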
@@ -140,8 +140,8 @@ inline void GemmTile(const float *A,
 #endif
 #if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
-  for (h = 0; h + 7 < height; h += 8) {
-    for (k = 0; k + 7 < K; k += 8) {
+  for (h = 0; h < height - 7; h += 8) {
+    for (k = 0; k < K - 7; k += 8) {
       const float *a_ptr = A + (h * stride_k + k);
 #ifdef __clang__
       int nw = width >> 2;
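The hunk below rewrites the body of this 8x8 block. As a reading aid for the assembly, a scalar sketch (illustrative only, strides assumed to match the surrounding code) of what one iteration of the unrolled kernel computes: an 8x8 tile of A times an 8x4 strip of B, accumulated into an 8x4 strip of C, matching the eight accumulator registers v18-v25 in the new code:

// Each asm loop iteration advances 4 columns of B and C (one 128-bit vector).
void MicroKernel8x8x4(const float *A, const float *B, float *C,
                      int stride_k, int stride_w) {
  for (int i = 0; i < 8; ++i) {        // 8 rows of A and C
    for (int kk = 0; kk < 8; ++kk) {   // 8 unrolled k steps
      const float a = A[i * stride_k + kk];
      for (int j = 0; j < 4; ++j) {    // one 4-wide vector of C per row
        C[i * stride_w + j] += a * B[kk * stride_w + j];
      }
    }
  }
}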
@@ -185,153 +185,147 @@ inline void GemmTile(const float *A,
         float *c_ptr7 = C + (h + 7) * stride_w;
         asm volatile(
-          "0: \n"
-          "prfm pldl1keep, [%1, #128] \n"
-          "ld1 {v24.4s}, [%1] \n"
-          // load b: 0-7
-          "prfm pldl1keep, [%9, #128] \n"
-          "ld1 {v16.4s}, [%9], #16 \n"
-          "prfm pldl1keep, [%10, #128] \n"
-          "ld1 {v17.4s}, [%10], #16 \n"
-          "prfm pldl1keep, [%11, #128] \n"
-          "ld1 {v18.4s}, [%11], #16 \n"
-          "prfm pldl1keep, [%12, #128] \n"
-          "ld1 {v19.4s}, [%12], #16 \n"
-          "prfm pldl1keep, [%2, #128] \n"
-          "ld1 {v25.4s}, [%2] \n"
-          "prfm pldl1keep, [%13, #128] \n"
-          "ld1 {v20.4s}, [%13], #16 \n"
-          "prfm pldl1keep, [%14, #128] \n"
-          "ld1 {v21.4s}, [%14], #16 \n"
-          "prfm pldl1keep, [%15, #128] \n"
-          "ld1 {v22.4s}, [%15], #16 \n"
-          "prfm pldl1keep, [%16, #128] \n"
-          "ld1 {v23.4s}, [%16], #16 \n"
-          "prfm pldl1keep, [%3, #128] \n"
-          "ld1 {v26.4s}, [%3] \n"
-          "fmla v24.4s, v16.4s, %34.s[0] \n"
-          "fmla v24.4s, v17.4s, %34.s[1] \n"
-          "fmla v24.4s, v18.4s, %34.s[2] \n"
-          "fmla v24.4s, v19.4s, %34.s[3] \n"
-          "fmla v24.4s, v20.4s, %35.s[0] \n"
-          "fmla v24.4s, v21.4s, %35.s[1] \n"
-          "fmla v24.4s, v22.4s, %35.s[2] \n"
-          "fmla v24.4s, v23.4s, %35.s[3] \n"
-          "st1 {v24.4s}, [%1], #16 \n"
-          "fmla v25.4s, v16.4s, %36.s[0] \n"
-          "fmla v25.4s, v17.4s, %36.s[1] \n"
-          "fmla v25.4s, v18.4s, %36.s[2] \n"
-          "fmla v25.4s, v19.4s, %36.s[3] \n"
-          "fmla v25.4s, v20.4s, %37.s[0] \n"
-          "fmla v25.4s, v21.4s, %37.s[1] \n"
-          "fmla v25.4s, v22.4s, %37.s[2] \n"
-          "fmla v25.4s, v23.4s, %37.s[3] \n"
-          "prfm pldl1keep, [%4, #128] \n"
-          "ld1 {v24.4s}, [%4] \n"
-          "st1 {v25.4s}, [%2], #16 \n"
-          "fmla v26.4s, v16.4s, %38.s[0] \n"
-          "fmla v26.4s, v17.4s, %38.s[1] \n"
-          "fmla v26.4s, v18.4s, %38.s[2] \n"
-          "fmla v26.4s, v19.4s, %38.s[3] \n"
-          "fmla v26.4s, v20.4s, %39.s[0] \n"
-          "fmla v26.4s, v21.4s, %39.s[1] \n"
-          "fmla v26.4s, v22.4s, %39.s[2] \n"
-          "fmla v26.4s, v23.4s, %39.s[3] \n"
-          "prfm pldl1keep, [%5, #128] \n"
-          "ld1 {v25.4s}, [%5] \n"
-          "st1 {v26.4s}, [%3], #16 \n"
-          "fmla v24.4s, v16.4s, %40.s[0] \n"
-          "fmla v24.4s, v17.4s, %40.s[1] \n"
-          "fmla v24.4s, v18.4s, %40.s[2] \n"
-          "fmla v24.4s, v19.4s, %40.s[3] \n"
-          "fmla v24.4s, v20.4s, %41.s[0] \n"
-          "fmla v24.4s, v21.4s, %41.s[1] \n"
-          "fmla v24.4s, v22.4s, %41.s[2] \n"
-          "fmla v24.4s, v23.4s, %41.s[3] \n"
-          "prfm pldl1keep, [%6, #128] \n"
-          "ld1 {v26.4s}, [%6] \n"
-          "st1 {v24.4s}, [%4], #16 \n"
-          "fmla v25.4s, v16.4s, %42.s[0] \n"
-          "fmla v25.4s, v17.4s, %42.s[1] \n"
-          "fmla v25.4s, v18.4s, %42.s[2] \n"
-          "fmla v25.4s, v19.4s, %42.s[3] \n"
-          "fmla v25.4s, v20.4s, %43.s[0] \n"
-          "fmla v25.4s, v21.4s, %43.s[1] \n"
-          "fmla v25.4s, v22.4s, %43.s[2] \n"
-          "fmla v25.4s, v23.4s, %43.s[3] \n"
-          "prfm pldl1keep, [%7, #128] \n"
-          "ld1 {v24.4s}, [%7] \n"
-          "st1 {v25.4s}, [%5], #16 \n"
-          "fmla v26.4s, v16.4s, %44.s[0] \n"
-          "fmla v26.4s, v17.4s, %44.s[1] \n"
-          "fmla v26.4s, v18.4s, %44.s[2] \n"
-          "fmla v26.4s, v19.4s, %44.s[3] \n"
-          "fmla v26.4s, v20.4s, %45.s[0] \n"
-          "fmla v26.4s, v21.4s, %45.s[1] \n"
-          "fmla v26.4s, v22.4s, %45.s[2] \n"
-          "fmla v26.4s, v23.4s, %45.s[3] \n"
-          "prfm pldl1keep, [%8, #128] \n"
-          "ld1 {v25.4s}, [%8] \n"
-          "st1 {v26.4s}, [%6], #16 \n"
-          "fmla v24.4s, v16.4s, %46.s[0] \n"
-          "fmla v24.4s, v17.4s, %46.s[1] \n"
-          "fmla v24.4s, v18.4s, %46.s[2] \n"
-          "fmla v24.4s, v19.4s, %46.s[3] \n"
-          "fmla v24.4s, v20.4s, %47.s[0] \n"
-          "fmla v24.4s, v21.4s, %47.s[1] \n"
-          "fmla v24.4s, v22.4s, %47.s[2] \n"
-          "fmla v24.4s, v23.4s, %47.s[3] \n"
-          "st1 {v24.4s}, [%7], #16 \n"
-          "fmla v25.4s, v16.4s, %48.s[0] \n"
-          "fmla v25.4s, v17.4s, %48.s[1] \n"
-          "fmla v25.4s, v18.4s, %48.s[2] \n"
-          "fmla v25.4s, v19.4s, %48.s[3] \n"
-          "fmla v25.4s, v20.4s, %49.s[0] \n"
-          "fmla v25.4s, v21.4s, %49.s[1] \n"
-          "fmla v25.4s, v22.4s, %49.s[2] \n"
-          "fmla v25.4s, v23.4s, %49.s[3] \n"
-          "st1 {v25.4s}, [%8], #16 \n"
+          "prfm pldl1keep, [%9, #128] \n"
+          "ld1 {v16.4s}, [%9], #16 \n"
+          "prfm pldl1keep, [%1, #128] \n"
+          "ld1 {v18.4s}, [%1] \n"
+          "prfm pldl1keep, [%2, #128] \n"
+          "ld1 {v19.4s}, [%2] \n"
+          "0: \n"
+          "prfm pldl1keep, [%3, #128] \n"
+          "ld1 {v20.4s}, [%3] \n"
+          "prfm pldl1keep, [%4, #128] \n"
+          "ld1 {v21.4s}, [%4] \n"
+          "prfm pldl1keep, [%5, #128] \n"
+          "ld1 {v22.4s}, [%5] \n"
+          "prfm pldl1keep, [%6, #128] \n"
+          "ld1 {v23.4s}, [%6] \n"
+          "prfm pldl1keep, [%7, #128] \n"
+          "ld1 {v24.4s}, [%7] \n"
+          "prfm pldl1keep, [%8, #128] \n"
+          "ld1 {v25.4s}, [%8] \n"
+          "prfm pldl1keep, [%10, #128] \n"
+          "ld1 {v17.4s}, [%10], #16 \n"
+          "fmla v18.4s, v16.4s, %34.s[0] \n"
+          "fmla v19.4s, v16.4s, %35.s[0] \n"
+          "fmla v20.4s, v16.4s, %36.s[0] \n"
+          "fmla v21.4s, v16.4s, %37.s[0] \n"
+          "fmla v22.4s, v16.4s, %38.s[0] \n"
+          "fmla v23.4s, v16.4s, %39.s[0] \n"
+          "fmla v24.4s, v16.4s, %40.s[0] \n"
+          "fmla v25.4s, v16.4s, %41.s[0] \n"
+          "fmla v18.4s, v17.4s, %34.s[1] \n"
+          "fmla v19.4s, v17.4s, %35.s[1] \n"
+          "fmla v20.4s, v17.4s, %36.s[1] \n"
+          "fmla v21.4s, v17.4s, %37.s[1] \n"
+          "prfm pldl1keep, [%11, #128] \n"
+          "ld1 {v16.4s}, [%11], #16 \n"
+          "fmla v22.4s, v17.4s, %38.s[1] \n"
+          "fmla v23.4s, v17.4s, %39.s[1] \n"
+          "fmla v24.4s, v17.4s, %40.s[1] \n"
+          "fmla v25.4s, v17.4s, %41.s[1] \n"
+          "fmla v18.4s, v16.4s, %34.s[2] \n"
+          "fmla v19.4s, v16.4s, %35.s[2] \n"
+          "fmla v20.4s, v16.4s, %36.s[2] \n"
+          "fmla v21.4s, v16.4s, %37.s[2] \n"
+          "prfm pldl1keep, [%12, #128] \n"
+          "ld1 {v17.4s}, [%12], #16 \n"
+          "fmla v22.4s, v16.4s, %38.s[2] \n"
+          "fmla v23.4s, v16.4s, %39.s[2] \n"
+          "fmla v24.4s, v16.4s, %40.s[2] \n"
+          "fmla v25.4s, v16.4s, %41.s[2] \n"
+          "fmla v18.4s, v17.4s, %34.s[3] \n"
+          "fmla v19.4s, v17.4s, %35.s[3] \n"
+          "fmla v20.4s, v17.4s, %36.s[3] \n"
+          "fmla v21.4s, v17.4s, %37.s[3] \n"
+          "prfm pldl1keep, [%13, #128] \n"
+          "ld1 {v16.4s}, [%13], #16 \n"
+          "fmla v22.4s, v17.4s, %38.s[3] \n"
+          "fmla v23.4s, v17.4s, %39.s[3] \n"
+          "fmla v24.4s, v17.4s, %40.s[3] \n"
+          "fmla v25.4s, v17.4s, %41.s[3] \n"
+          "fmla v18.4s, v16.4s, %42.s[0] \n"
+          "fmla v19.4s, v16.4s, %43.s[0] \n"
+          "fmla v20.4s, v16.4s, %44.s[0] \n"
+          "fmla v21.4s, v16.4s, %45.s[0] \n"
+          "prfm pldl1keep, [%14, #128] \n"
+          "ld1 {v17.4s}, [%14], #16 \n"
+          "fmla v22.4s, v16.4s, %46.s[0] \n"
+          "fmla v23.4s, v16.4s, %47.s[0] \n"
+          "fmla v24.4s, v16.4s, %48.s[0] \n"
+          "fmla v25.4s, v16.4s, %49.s[0] \n"
+          "fmla v18.4s, v17.4s, %42.s[1] \n"
+          "fmla v19.4s, v17.4s, %43.s[1] \n"
+          "fmla v20.4s, v17.4s, %44.s[1] \n"
+          "fmla v21.4s, v17.4s, %45.s[1] \n"
+          "prfm pldl1keep, [%15, #128] \n"
+          "ld1 {v16.4s}, [%15], #16 \n"
+          "fmla v22.4s, v17.4s, %46.s[1] \n"
+          "fmla v23.4s, v17.4s, %47.s[1] \n"
+          "fmla v24.4s, v17.4s, %48.s[1] \n"
+          "fmla v25.4s, v17.4s, %49.s[1] \n"
+          "fmla v18.4s, v16.4s, %42.s[2] \n"
+          "fmla v19.4s, v16.4s, %43.s[2] \n"
+          "fmla v20.4s, v16.4s, %44.s[2] \n"
+          "fmla v21.4s, v16.4s, %45.s[2] \n"
+          "prfm pldl1keep, [%16, #128] \n"
+          "ld1 {v17.4s}, [%16], #16 \n"
+          "fmla v22.4s, v16.4s, %46.s[2] \n"
+          "fmla v23.4s, v16.4s, %47.s[2] \n"
+          "fmla v24.4s, v16.4s, %48.s[2] \n"
+          "fmla v25.4s, v16.4s, %49.s[2] \n"
+          "fmla v18.4s, v17.4s, %42.s[3] \n"
+          "fmla v19.4s, v17.4s, %43.s[3] \n"
+          "fmla v20.4s, v17.4s, %44.s[3] \n"
+          "fmla v21.4s, v17.4s, %45.s[3] \n"
+          "st1 {v18.4s}, [%1], #16 \n"
+          "st1 {v19.4s}, [%2], #16 \n"
+          "st1 {v20.4s}, [%3], #16 \n"
+          "st1 {v21.4s}, [%4], #16 \n"
+          "fmla v22.4s, v17.4s, %46.s[3] \n"
+          "fmla v23.4s, v17.4s, %47.s[3] \n"
+          "fmla v24.4s, v17.4s, %48.s[3] \n"
+          "fmla v25.4s, v17.4s, %49.s[3] \n"
+          "st1 {v22.4s}, [%5], #16 \n"
+          "st1 {v23.4s}, [%6], #16 \n"
+          "st1 {v24.4s}, [%7], #16 \n"
+          "st1 {v25.4s}, [%8], #16 \n"
+          "prfm pldl1keep, [%9, #128] \n"
+          "ld1 {v16.4s}, [%9], #16 \n"
+          "prfm pldl1keep, [%1, #128] \n"
+          "ld1 {v18.4s}, [%1] \n"
+          "prfm pldl1keep, [%2, #128] \n"
+          "ld1 {v19.4s}, [%2] \n"
           "subs %w0, %w0, #1 \n"
           "bne 0b \n"
           : "=r"(nw), // 0
@@ -369,20 +363,20 @@ inline void GemmTile(const float *A,
           "15"(b_ptr6), // 32
           "16"(b_ptr7), // 33
           "w"(a0), // 34
-          "w"(a1), // 35
-          "w"(a2), // 36
-          "w"(a3), // 37
-          "w"(a4), // 38
-          "w"(a5), // 39
-          "w"(a6), // 40
-          "w"(a7), // 41
-          "w"(a8), // 42
-          "w"(a9), // 43
-          "w"(a10), // 44
-          "w"(a11), // 45
-          "w"(a12), // 46
-          "w"(a13), // 47
-          "w"(a14), // 48
+          "w"(a2), // 35
+          "w"(a4), // 36
+          "w"(a6), // 37
+          "w"(a8), // 38
+          "w"(a10), // 39
+          "w"(a12), // 40
+          "w"(a14), // 41
+          "w"(a1), // 42
+          "w"(a3), // 43
+          "w"(a5), // 44
+          "w"(a7), // 45
+          "w"(a9), // 46
+          "w"(a11), // 47
+          "w"(a13), // 48
           "w"(a15) // 49
           : "cc", "memory",
             "v16",
@@ -585,7 +579,6 @@ void Gemm(const float *A,
   }
 
   memset(C, 0, sizeof(float) * batch * height * width);
   // It is better to use large block size if it fits for fast cache.
   // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
   // the block size should be sqrt(32k / sizeof(T) / 3).
......
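For concreteness, with float (4 bytes) and a 32 KB L1 cache, that rule of thumb gives a block edge of roughly sqrt(32 * 1024 / 4 / 3) ≈ 52 elements.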