From 5af0c7ba89810bf78281f7602c8f29dcd515cbcd Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Wed, 11 Dec 2019 09:16:34 +0800 Subject: [PATCH] Modify padding strategy: remove weight copy in fc padding (#21650) test=develop --- paddle/fluid/operators/math/fc.cc | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc index b50b4435d7..27a75f7631 100644 --- a/paddle/fluid/operators/math/fc.cc +++ b/paddle/fluid/operators/math/fc.cc @@ -30,8 +30,7 @@ class FCFunctor { auto blas = math::GetBlas(context); framework::Tensor Y1; T* Y1_data = nullptr; - auto padding = N % 128 == 0 && K % 128 == 0; - if (padding) { + if (padding_weights) { const int NN = N + 4; const int KK = K + 4; framework::Tensor X1; @@ -43,25 +42,13 @@ class FCFunctor { for (int i = 0; i < M; i++) { memcpy(X1_data + i * KK, X + i * K, K * sizeof(T)); } - framework::Tensor W1; - T* W1_data = nullptr; - if (!padding_weights) { - W1_data = W1.mutable_data({(K + 4) * (N + 4)}, platform::CPUPlace()); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < K; i++) { - memcpy(W1_data + i * NN, W + i * N, N * sizeof(T)); - } - } - blas.GEMM(false, false, M, N, K, static_cast(1.0), X1_data, KK, - (padding_weights ? W : W1_data), NN, static_cast(0.0), - Y1_data, NN); + blas.GEMM(false, false, M, N, K, static_cast(1.0), X1_data, KK, W, NN, + static_cast(0.0), Y1_data, NN); } else { blas.MatMul(M, N, K, X, W, Y); } if (B == NULL) { - if (padding) { + if (padding_weights) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif @@ -80,7 +67,7 @@ class FCFunctor { .At(N); for (int i = 0; i < M; i++) { T* dst = Y + i * N; - T* src = (padding) ? Y1_data + i * (N + 4) : dst; + T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst; compute(B, src, dst, N); } } else { @@ -92,7 +79,7 @@ class FCFunctor { #endif for (int i = 0; i < M; i++) { T* dst = Y + i * N; - T* src = (padding) ? Y1_data + i * (N + 4) : dst; + T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst; compute(B, src, dst, N); } } -- GitLab