From 8493f20ebc4e96313301066b9ed328829e882c6d Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Wed, 27 Nov 2019 13:56:23 +0800 Subject: [PATCH] Polish the codes of fc when needs padding (#21378) test=develop --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 6 +++--- paddle/fluid/operators/math/fc.cc | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index ed8128c3307..8eccad1ee0e 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -94,20 +94,20 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { auto* weight = scope->FindVar(w->Name())->GetMutable(); auto place = weight->place(); bool use_gpu = Get("use_gpu"); - auto weight_data = weight->data(); + auto* weight_data = weight->data(); auto weight_dims = weight->dims(); int weight_num = product(weight_dims); int w_h = weight_dims[0]; int w_w = weight_dims[1]; if (!use_gpu) { if (w_h % 128 == 0 && w_w % 128 == 0) { - float* weight_data_tmp = new float[weight_num]; + auto* weight_data_tmp = new float[weight_num]; for (int i = 0; i < w_h; i++) { memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w, w_w * sizeof(float)); } weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4}); - auto weight_data_new = + auto* weight_data_new = weight->mutable_data(platform::CPUPlace()); for (int i = 0; i < w_h; i++) { memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w, diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc index 38acd7ba948..b50b4435d7b 100644 --- a/paddle/fluid/operators/math/fc.cc +++ b/paddle/fluid/operators/math/fc.cc @@ -30,28 +30,28 @@ class FCFunctor { auto blas = math::GetBlas(context); framework::Tensor Y1; T* Y1_data = nullptr; - if (N % 128 == 0 && K % 128 == 0) { + auto padding = N % 128 == 0 && K % 128 == 0; + if (padding) { const int NN = N + 4; const int KK = K + 4; framework::Tensor X1; - T* X1_data = X1.Resize({M * KK}).mutable_data(platform::CPUPlace()); - Y1_data = Y1.Resize({M * (N + 4)}).mutable_data(platform::CPUPlace()); + T* X1_data = X1.mutable_data({M * KK}, platform::CPUPlace()); + Y1_data = Y1.mutable_data({M * (N + 4)}, platform::CPUPlace()); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (int i = 0; i < M; i++) { - memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0])); + memcpy(X1_data + i * KK, X + i * K, K * sizeof(T)); } framework::Tensor W1; T* W1_data = nullptr; if (!padding_weights) { - W1_data = W1.Resize({(K + 4) * (N + 4)}) - .mutable_data(platform::CPUPlace()); + W1_data = W1.mutable_data({(K + 4) * (N + 4)}, platform::CPUPlace()); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (int i = 0; i < K; i++) { - memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0])); + memcpy(W1_data + i * NN, W + i * N, N * sizeof(T)); } } blas.GEMM(false, false, M, N, K, static_cast(1.0), X1_data, KK, @@ -61,12 +61,12 @@ class FCFunctor { blas.MatMul(M, N, K, X, W, Y); } if (B == NULL) { - if (N % 128 == 0 && K % 128 == 0) { + if (padding) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (int i = 0; i < M; i++) { - memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0])); + memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T)); } } PADDLE_ENFORCE_EQ(relu, false, @@ -80,7 +80,7 @@ class FCFunctor { .At(N); for (int i = 0; i < M; i++) { T* dst = Y + i * N; - T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst; + T* src = (padding) ? Y1_data + i * (N + 4) : dst; compute(B, src, dst, N); } } else { @@ -92,7 +92,7 @@ class FCFunctor { #endif for (int i = 0; i < M; i++) { T* dst = Y + i * N; - T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst; + T* src = (padding) ? Y1_data + i * (N + 4) : dst; compute(B, src, dst, N); } } -- GitLab