提交 5af0c7ba 作者: GaoWei8 提交者: Yiqun Liu

Modify padding strategy: remove weight copy in fc padding (#21650)

test=develop
上级 10018f15
...@@ -30,8 +30,7 @@ class FCFunctor<platform::CPUDeviceContext, T> { ...@@ -30,8 +30,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
framework::Tensor Y1; framework::Tensor Y1;
T* Y1_data = nullptr; T* Y1_data = nullptr;
auto padding = N % 128 == 0 && K % 128 == 0; if (padding_weights) {
if (padding) {
const int NN = N + 4; const int NN = N + 4;
const int KK = K + 4; const int KK = K + 4;
framework::Tensor X1; framework::Tensor X1;
...@@ -43,25 +42,13 @@ class FCFunctor<platform::CPUDeviceContext, T> { ...@@ -43,25 +42,13 @@ class FCFunctor<platform::CPUDeviceContext, T> {
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
memcpy(X1_data + i * KK, X + i * K, K * sizeof(T)); memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
} }
framework::Tensor W1; blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK, W, NN,
T* W1_data = nullptr; static_cast<T>(0.0), Y1_data, NN);
if (!padding_weights) {
W1_data = W1.mutable_data<T>({(K + 4) * (N + 4)}, platform::CPUPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < K; i++) {
memcpy(W1_data + i * NN, W + i * N, N * sizeof(T));
}
}
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK,
(padding_weights ? W : W1_data), NN, static_cast<T>(0.0),
Y1_data, NN);
} else { } else {
blas.MatMul(M, N, K, X, W, Y); blas.MatMul(M, N, K, X, W, Y);
} }
if (B == NULL) { if (B == NULL) {
if (padding) { if (padding_weights) {
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
...@@ -80,7 +67,7 @@ class FCFunctor<platform::CPUDeviceContext, T> { ...@@ -80,7 +67,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
.At(N); .At(N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
T* src = (padding) ? Y1_data + i * (N + 4) : dst; T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N); compute(B, src, dst, N);
} }
} else { } else {
...@@ -92,7 +79,7 @@ class FCFunctor<platform::CPUDeviceContext, T> { ...@@ -92,7 +79,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
#endif #endif
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
T* src = (padding) ? Y1_data + i * (N + 4) : dst; T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N); compute(B, src, dst, N);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册