Commit 80a3f082 authored by dolphin8

Merge remote-tracking branch 'upstream/develop' into develop

@@ -85,6 +85,12 @@ class Tensor {
    }
  }

+  Tensor(const Tensor &inTensor) {
+    this->dims_ = inTensor.dims_;
+    this->holder_ = inTensor.holder_;
+    this->offset_ = inTensor.offset_;
+  }
+
  /*! Return a pointer to mutable memory block. */
  template <typename T>
  inline T *data() {
@@ -169,7 +175,9 @@ class Tensor {
  /*! The internal of two tensors share the same memory block. */
  inline Tensor &ShareDataWith(const Tensor &src) {
    src.check_memory_size();
-    *this = src;
+    if (holder_.get() != src.holder_.get()) {
+      *this = src;
+    }
    return *this;
  }
@@ -198,7 +206,6 @@ class Tensor {
    size_t base = numel() / dims_[0];
    Tensor dst;
    dst.holder_ = holder_;
-    dst.set_layout(layout_);
    DDim dst_dims = dims_;
    dst_dims[0] = end_idx - begin_idx;
    dst.Resize(dst_dims);
@@ -227,10 +234,6 @@ class Tensor {
                   "Tensor's dims_ is out of bound. ");
  }

-  inline DataLayout layout() const { return layout_; }
-
-  inline void set_layout(const DataLayout layout) { layout_ = layout; }
-
 private:
  /**
   * @note Placeholder hides type T, so it doesn't appear as a
@@ -288,21 +291,6 @@ class Tensor {
  DDim dims_;

-  /**
-   * @brief the layout of the memory block; the default is NHWC.
-   *
-   * @note the memory allocation order, i.e. how the weights/data are
-   *       stored. For example, a 4-D Tensor (rank = 4) commonly uses one
-   *       of three layouts: NCHW, NHWC, or CHWN, where N, C, H, W stand
-   *       for the batch size, the number of feature maps, the height,
-   *       and the width, respectively.
-   */
-  DataLayout layout_ = DataLayout::kNHWC;
-
  /**
   * @brief A PlaceHolder may be shared by more than one tensor.
   *
......
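Note on the tensor.h changes above: the new copy constructor copies only dims_, holder_, and offset_, so a copied Tensor aliases the same underlying allocation rather than owning a fresh one (holder_ is a shared pointer), and the layout_ member plus its accessors are removed outright. A minimal sketch of the resulting aliasing behavior, assuming the Resize / mutable_data<T> / data<T> API visible in this diff:

    using paddle_mobile::framework::Tensor;

    Tensor a;
    a.Resize({2, 3});
    float *pa = a.mutable_data<float>();  // allocates the shared holder_
    pa[0] = 1.0f;

    Tensor b(a);                // new copy constructor: b shares a's holder_
    b.data<float>()[0] = 2.0f;  // visible through pa as well; pa[0] == 2.0f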
@@ -20,7 +20,6 @@ namespace framework {
void TensorCopy(const Tensor &src, Tensor *dst) {
  src.check_memory_size();
  dst->Resize(src.dims());
-  dst->set_layout(src.layout());
  auto src_ptr = src.data<void>();
  auto dst_ptr = dst->mutable_data(src.type());
  auto size = src.numel() * SizeOfType(src.type());
......
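For contrast, TensorCopy stays a deep copy even after these changes: it resizes dst, allocates dst's own block through mutable_data, and memcpy's the bytes, so source and destination remain independent, unlike ShareDataWith, which only aliases. A hedged sketch of the two behaviors side by side, using the same assumed API as above:

    Tensor src, dst, view;
    src.Resize({4});
    src.mutable_data<float>();

    TensorCopy(src, &dst);    // deep copy: dst gets its own holder_
    view.ShareDataWith(src);  // shallow: view reuses src's holder_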
@@ -477,7 +477,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
  printf("====================[---------]======================\n");
#endif
-  return std::shared_ptr<framework::Tensor>(output_tensor);
+  return std::make_shared<framework::Tensor>(framework::Tensor(*output_tensor));
}

template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
......
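The Predict change above is an ownership fix: constructing a shared_ptr directly from the raw output_tensor pointer claims ownership of an object the Executor's scope may still own, risking a double delete (or a dangling pointer if the other owner frees it first). Copying into a freshly make_shared'd Tensor avoids both, and thanks to the new copy constructor the copy is shallow and cheap. A sketch of the hazard, with the surrounding code paraphrased:

    framework::Tensor *output_tensor = /* owned by the executor's scope */;

    // Before: the shared_ptr assumes ownership it was never given.
    // auto out = std::shared_ptr<framework::Tensor>(output_tensor);

    // After: an independent, shallow-copied object with clear ownership.
    auto out = std::make_shared<framework::Tensor>(*output_tensor);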
@@ -14,17 +14,17 @@ limitations under the License. */
#pragma once

-#include "t_malloc.h"
+#include "memory/t_malloc.h"
#include <cstdlib>
#include <cstring>

namespace paddle_mobile {
namespace memory {
-const int MALLOC_ALIGN = 16;
+const int MALLOC_ALIGN = 64;

void Copy(void *dst, const void *src, size_t num) {
  std::memcpy(dst, src, num);
-};
+}

void *Alloc(size_t size) {
  size_t offset = sizeof(void *) + MALLOC_ALIGN - 1;
......
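MALLOC_ALIGN rises from 16 to 64 bytes here, a full cache line on most ARM cores, which matters for the NEON kernels added below. The Alloc body is truncated in this view, but the offset of sizeof(void *) + MALLOC_ALIGN - 1 implies the usual over-allocate-and-stash pattern; a minimal sketch under that assumption (not necessarily the exact shipped body):

    void *Alloc(size_t size) {
      size_t offset = sizeof(void *) + MALLOC_ALIGN - 1;
      char *p = static_cast<char *>(malloc(offset + size));
      if (p == nullptr) return nullptr;
      // Round up past the stash slot to the next MALLOC_ALIGN boundary.
      void *r = reinterpret_cast<void *>(
          reinterpret_cast<size_t>(p + sizeof(void *) + MALLOC_ALIGN - 1) &
          ~(static_cast<size_t>(MALLOC_ALIGN) - 1));
      // Stash the original pointer just below the aligned block for Free.
      static_cast<void **>(r)[-1] = p;
      return r;
    }

    void Free(void *ptr) {
      if (ptr != nullptr) free(static_cast<void **>(ptr)[-1]);
    }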
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "operators/math/gemm.h"
+#include "common/log.h"
+#include "memory/t_malloc.h"
#ifndef X86
#include <arm_neon.h>
#endif
@@ -214,7 +216,7 @@ void InnerKernel_relu(int m, int n, int k, float alpha, const float *A, int lda,
  }
}

-//计算一个更小的 4 * 4 的 C 矩阵分块
+// Compute a smaller 4 * 4 block of the C matrix
#if defined(IOS)
void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
               int ldb, float beta, float *C, int ldc, int mc, int nc) {
@@ -757,6 +759,10 @@ void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
           const float *B, int ldb, float beta, float *C, int ldc) {
  int i, j, p, mc, nc, kc;
  float beta_;
+  if (m == 1) {
+    VectorKernel(1, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+    return;
+  }
  for (j = 0; j < n; j += NC) {
    nc = s_min(n - j, NC);
    for (p = 0; p < k; p += KC) {
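The m == 1 early-out above skips the blocked, packed GEMM path entirely: for a single output row the packing overhead outweighs its cache benefits, so a dedicated vector-matrix kernel handles it. In scalar form, what the VectorKernel below computes is roughly this (alpha is accepted but not applied, and beta acts as accumulate-when-1.0, overwrite otherwise):

    // Scalar reference for the NEON VectorKernel (sketch, not the shipped code).
    void VectorKernelRef(int n, int k, const float *A, const float *B, int ldb,
                         float beta, float *C) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          acc += A[p] * B[p * ldb + j];
        }
        C[j] = (beta == 1.0f) ? C[j] + acc : acc;
      }
    }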
@@ -803,6 +809,220 @@ void sgemm_relu(int m, int n, int k, float alpha, const float *A, int lda,
  }
}

+void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
+                  const float *B, int ldb, float beta, float *C, int ldc) {
+  float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
+  const float *a0, *b0, *b1, *b2, *b3;
+  float *c0, *C0;
+  // Tile the n columns: nc1 tiles of 16 (q10..q13), nc2 tiles of 4 (q10),
+  // and an nc3-wide scalar tail.
+  int volatile kc1 = k / 4;
+  int volatile kc2 = k % 4;
+  int volatile nc1 = n / 16;
+  int _nc1 = n % 16;
+  int volatile nc2 = _nc1 / 4;
+  int volatile nc3 = _nc1 % 4;
+  // Main k-loop: four rows of B per pass; i == 0 initializes bufferC,
+  // later passes accumulate into it.
+  for (int i = 0; i < kc1; i++) {
+    a0 = A + i * 4;
+    b0 = B + i * 4 * ldb;
+    b1 = b0 + ldb;
+    b2 = b1 + ldb;
+    b3 = b2 + ldb;
+    c0 = bufferC;
+    asm volatile(
+        "pld [%[a0], #16] \n\t"
+        "vld1.32 {q0}, [%[a0]] \n\t"
+        "subs %[nc1], %[nc1], #1 \n\t"
+        "blt end_nc1_%= \n\t"
+        "loop_nc1_%=: \n\t"
+        "cmp %[i], #0 \n\t"
+        "beq i_eq0_%= \n\t"
+        "bne i_ne0_%= \n\t"
+        "i_eq0_%=: \n\t"
+        "vmov.f32 q10, #0.0 \n\t"
+        "vmov.f32 q11, #0.0 \n\t"
+        "vmov.f32 q12, #0.0 \n\t"
+        "vmov.f32 q13, #0.0 \n\t"
+        "b gemm_nc1_%= \n\t"
+        "i_ne0_%=: \n\t"
+        "pld [%[c0], #64] \n\t"
+        "vld1.32 {q10, q11}, [%[c0]]! \n\t"
+        "vld1.32 {q12, q13}, [%[c0]] \n\t"
+        "sub %[c0], %[c0], #32 \n\t"
+        "gemm_nc1_%=: \n\t"
+        "pld [%[b0], #64] \n\t"
+        "vld1.32 {q2, q3}, [%[b0]]! \n\t"
+        "vld1.32 {q4, q5}, [%[b0]]! \n\t"
+        "vmla.f32 q10, q2, d0[0] \n\t"
+        "vmla.f32 q11, q3, d0[0] \n\t"
+        "vmla.f32 q12, q4, d0[0] \n\t"
+        "vmla.f32 q13, q5, d0[0] \n\t"
+        "pld [%[b1], #64] \n\t"
+        "vld1.32 {q2, q3}, [%[b1]]! \n\t"
+        "vld1.32 {q4, q5}, [%[b1]]! \n\t"
+        "vmla.f32 q10, q2, d0[1] \n\t"
+        "vmla.f32 q11, q3, d0[1] \n\t"
+        "vmla.f32 q12, q4, d0[1] \n\t"
+        "vmla.f32 q13, q5, d0[1] \n\t"
+        "pld [%[b2], #64] \n\t"
+        "vld1.32 {q2, q3}, [%[b2]]! \n\t"
+        "vld1.32 {q4, q5}, [%[b2]]! \n\t"
+        "vmla.f32 q10, q2, d1[0] \n\t"
+        "vmla.f32 q11, q3, d1[0] \n\t"
+        "vmla.f32 q12, q4, d1[0] \n\t"
+        "vmla.f32 q13, q5, d1[0] \n\t"
+        "pld [%[b3], #64] \n\t"
+        "vld1.32 {q2, q3}, [%[b3]]! \n\t"
+        "vld1.32 {q4, q5}, [%[b3]]! \n\t"
+        "vmla.f32 q10, q2, d1[1] \n\t"
+        "vmla.f32 q11, q3, d1[1] \n\t"
+        "vmla.f32 q12, q4, d1[1] \n\t"
+        "vmla.f32 q13, q5, d1[1] \n\t"
+        "vst1.32 {q10, q11}, [%[c0]]! \n\t"
+        "vst1.32 {q12, q13}, [%[c0]]! \n\t"
+        "subs %[nc1], %[nc1], #1 \n\t"
+        "bge loop_nc1_%= \n\t"
+        "end_nc1_%=: \n\t"
+        "subs %[nc2], %[nc2], #1 \n\t"
+        "blt end_nc2_%= \n\t"
+        "loop_nc2_%=: \n\t"
+        "cmp %[i], #0 \n\t"
+        "beq ii_eq0_%= \n\t"
+        "bne ii_ne0_%= \n\t"
+        "ii_eq0_%=: \n\t"
+        "vmov.f32 q10, #0.0 \n\t"
+        "b gemm_nc2_%= \n\t"
+        "ii_ne0_%=: \n\t"
+        "pld [%[c0], #16] \n\t"
+        "vld1.32 {q10}, [%[c0]] \n\t"
+        "gemm_nc2_%=: \n\t"
+        "pld [%[b0], #16] \n\t"
+        "vld1.32 {q2}, [%[b0]]! \n\t"
+        "vmla.f32 q10, q2, d0[0] \n\t"
+        "pld [%[b1], #16] \n\t"
+        "vld1.32 {q3}, [%[b1]]! \n\t"
+        "vmla.f32 q10, q3, d0[1] \n\t"
+        "pld [%[b2], #16] \n\t"
+        "vld1.32 {q4}, [%[b2]]! \n\t"
+        "vmla.f32 q10, q4, d1[0] \n\t"
+        "pld [%[b3], #16] \n\t"
+        "vld1.32 {q5}, [%[b3]]! \n\t"
+        "vmla.f32 q10, q5, d1[1] \n\t"
+        "vst1.32 {q10}, [%[c0]]! \n\t"
+        "subs %[nc2], %[nc2], #1 \n\t"
+        "bge loop_nc2_%= \n\t"
+        "end_nc2_%=: \n\t"
+        : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3),
+          [c0] "+r"(c0)
+        : [a0] "r"(a0), [i] "r"(i), [nc1] "r"(nc1), [nc2] "r"(nc2)
+        : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13");
+    // Scalar tail for the last nc3 columns of this group of four k-steps.
+    for (int j = 0; j < nc3; j++) {
+      if (i == 0) {
+        *c0 = (*a0) * (*b0++);
+      } else {
+        *c0 += (*a0) * (*b0++);
+      }
+      *c0 += (*(a0 + 1)) * (*b1++);
+      *c0 += (*(a0 + 2)) * (*b2++);
+      *c0 += (*(a0 + 3)) * (*b3++);
+      c0++;
+    }
+  }
+  // k remainder: one row of B per pass, always accumulating into bufferC.
+  for (int i = 0; i < kc2; ++i) {
+    a0 = A + 4 * kc1 + i;
+    b0 = B + (4 * kc1 + i) * ldb;
+    c0 = bufferC;
+    asm volatile(
+        "pld [%[a0], #16] \n\t"
+        "vld1.32 {d0}, [%[a0]] \n\t"
+        "subs %[nc1], %[nc1], #1 \n\t"
+        "blt end_nc1_%= \n\t"
+        "loop_nc1_%=: \n\t"
+        "pld [%[c0], #64] \n\t"
+        "vld1.32 {q10, q11}, [%[c0]]! \n\t"
+        "vld1.32 {q12, q13}, [%[c0]] \n\t"
+        "sub %[c0], %[c0], #32 \n\t"
+        "gemm_nc1_%=: \n\t"
+        "pld [%[b0], #64] \n\t"
+        "vld1.32 {q2, q3}, [%[b0]]! \n\t"
+        "vld1.32 {q4, q5}, [%[b0]]! \n\t"
+        "vmla.f32 q10, q2, d0[0] \n\t"
+        "vmla.f32 q11, q3, d0[0] \n\t"
+        "vmla.f32 q12, q4, d0[0] \n\t"
+        "vmla.f32 q13, q5, d0[0] \n\t"
+        "vst1.32 {q10, q11}, [%[c0]]! \n\t"
+        "vst1.32 {q12, q13}, [%[c0]]! \n\t"
+        "subs %[nc1], %[nc1], #1 \n\t"
+        "bge loop_nc1_%= \n\t"
+        "end_nc1_%=: \n\t"
+        "subs %[nc2], %[nc2], #1 \n\t"
+        "blt end_nc2_%= \n\t"
+        "loop_nc2_%=: \n\t"
+        "pld [%[c0], #16] \n\t"
+        "vld1.32 {q10}, [%[c0]] \n\t"
+        "gemm_nc2_%=: \n\t"
+        "vld1.32 {q2}, [%[b0]]! \n\t"
+        "vmla.f32 q10, q2, d0[0] \n\t"
+        "vst1.32 {q10}, [%[c0]]! \n\t"
+        "subs %[nc2], %[nc2], #1 \n\t"
+        "bge loop_nc2_%= \n\t"
+        "end_nc2_%=: \n\t"
+        : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3),
+          [c0] "+r"(c0)
+        : [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2)
+        : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13");
+    for (int j = 0; j < nc3; j++) {
+      *c0 += (*a0) * (*b0++);
+      c0++;
+    }
+  }
+  // Write back: beta == 1.0 accumulates into C, any other beta overwrites
+  // (note: alpha is not applied by this kernel).
+  c0 = bufferC;
+  C0 = C;
+  for (int i = 0; i < n; i++) {
+    if (beta == 1.0) {
+      *C0++ += *c0++;
+    } else {
+      *C0++ = *c0++;
+    }
+  }
+  memory::Free(bufferC);  // release the scratch buffer (assumes t_malloc pairs Alloc with Free(void *))
+}
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
@@ -53,6 +53,10 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
                 const float *B, int ldb, float beta, float *C, int ldc,
                 int first_time);

+// Vector-matrix multiplication (M = 1)
+void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
+                  const float *B, int ldb, float beta, float *C, int ldc);
+
// Compute a smaller 4 * 4 block of the C matrix
void AddDot4x4(int k, float alpha, const float *A, int lda, const float *B,
               int ldb, float beta, float *C, int ldc, int mc, int nc);
......
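With the declaration in place, callers keep invoking sgemm and get the vector path automatically whenever m == 1. A hypothetical call, with x (1 x k), Bmat (k x n), and y (1 x n) invented for illustration:

    // beta = 0.0f overwrites y; beta = 1.0f would accumulate into it.
    paddle_mobile::operators::math::sgemm(1, n, k, 1.0f, x, k, Bmat, n, 0.0f, y, n);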
@@ -30,7 +30,11 @@ int main() {
  std::vector<int64_t> dims{1, 3, 224, 224};
  GetInput<float>(g_test_image_1x3x224x224, &input, dims);
  auto time3 = time();
-  executor.Predict(input, dims);
+
+  for (int i = 0; i < 10; ++i) {
+    executor.Predict(input, dims);
+  }
+
  auto time4 = time();
  DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n";
  return 0;
......
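One caveat on the benchmark change: time_diff(time3, time4) now spans ten Predict calls while the log line still reads like a single run, so the printed figure is 10x one inference. If per-run cost is wanted, a small tweak along these lines (sketch):

    DLOG << "predict cost (avg of 10 runs): " << time_diff(time3, time4) / 10
         << "ms\n";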