Commit 8493f20e authored by GaoWei8, committed by Yiqun Liu

Polish the fc code when it needs padding (#21378)

test=develop
Parent 5d7d5482
@@ -94,20 +94,20 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
   auto place = weight->place();
   bool use_gpu = Get<bool>("use_gpu");
-  auto weight_data = weight->data<float>();
+  auto* weight_data = weight->data<float>();
   auto weight_dims = weight->dims();
   int weight_num = product(weight_dims);
   int w_h = weight_dims[0];
   int w_w = weight_dims[1];
   if (!use_gpu) {
     if (w_h % 128 == 0 && w_w % 128 == 0) {
-      float* weight_data_tmp = new float[weight_num];
+      auto* weight_data_tmp = new float[weight_num];
       for (int i = 0; i < w_h; i++) {
         memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w,
                w_w * sizeof(float));
       }
       weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
-      auto weight_data_new =
+      auto* weight_data_new =
           weight->mutable_data<float>(platform::CPUPlace());
       for (int i = 0; i < w_h; i++) {
         memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w,
...
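The hunk above pads a weight matrix whose sides are both multiples of 128: the rows are first saved into a temporary, the tensor is resized to (w_h + 4) x (w_w + 4), and the rows are copied back at the wider stride. The temporary is needed because resizing reallocates the tensor's buffer, which invalidates weight_data. A minimal standalone sketch of the same row-wise padded copy, with std::vector standing in for the Paddle tensor types (an assumption made only for illustration):

#include <cstring>
#include <vector>

// Copy a row-major w_h x w_w matrix into a (w_h + 4) x (w_w + 4) buffer,
// so each row starts at the wider stride w_w + 4; the 4 trailing columns
// and rows stay zero. std::vector stands in for LoDTensor here.
std::vector<float> PadWeight(const std::vector<float>& weight, int w_h,
                             int w_w) {
  std::vector<float> padded((w_h + 4) * (w_w + 4), 0.0f);
  for (int i = 0; i < w_h; i++) {
    std::memcpy(padded.data() + i * (w_w + 4), weight.data() + i * w_w,
                w_w * sizeof(float));
  }
  return padded;
}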
@@ -30,28 +30,28 @@ class FCFunctor<platform::CPUDeviceContext, T> {
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     framework::Tensor Y1;
     T* Y1_data = nullptr;
-    if (N % 128 == 0 && K % 128 == 0) {
+    auto padding = N % 128 == 0 && K % 128 == 0;
+    if (padding) {
       const int NN = N + 4;
       const int KK = K + 4;
       framework::Tensor X1;
-      T* X1_data = X1.Resize({M * KK}).mutable_data<T>(platform::CPUPlace());
-      Y1_data = Y1.Resize({M * (N + 4)}).mutable_data<T>(platform::CPUPlace());
+      T* X1_data = X1.mutable_data<T>({M * KK}, platform::CPUPlace());
+      Y1_data = Y1.mutable_data<T>({M * (N + 4)}, platform::CPUPlace());
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
       for (int i = 0; i < M; i++) {
-        memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0]));
+        memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
       }
       framework::Tensor W1;
       T* W1_data = nullptr;
       if (!padding_weights) {
-        W1_data = W1.Resize({(K + 4) * (N + 4)})
-                      .mutable_data<T>(platform::CPUPlace());
+        W1_data = W1.mutable_data<T>({(K + 4) * (N + 4)}, platform::CPUPlace());
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
         for (int i = 0; i < K; i++) {
-          memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0]));
+          memcpy(W1_data + i * NN, W + i * N, N * sizeof(T));
         }
       }
       blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK,
...
@@ -61,12 +61,12 @@ class FCFunctor<platform::CPUDeviceContext, T> {
       blas.MatMul(M, N, K, X, W, Y);
     }
     if (B == NULL) {
-      if (N % 128 == 0 && K % 128 == 0) {
+      if (padding) {
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
         for (int i = 0; i < M; i++) {
-          memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0]));
+          memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
         }
       }
       PADDLE_ENFORCE_EQ(relu, false,
...
@@ -80,7 +80,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
           .At(N);
       for (int i = 0; i < M; i++) {
         T* dst = Y + i * N;
-        T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
+        T* src = (padding) ? Y1_data + i * (N + 4) : dst;
         compute(B, src, dst, N);
       }
     } else {
...
@@ -92,7 +92,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
 #endif
       for (int i = 0; i < M; i++) {
         T* dst = Y + i * N;
-        T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
+        T* src = (padding) ? Y1_data + i * (N + 4) : dst;
         compute(B, src, dst, N);
       }
     }
...
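Taken together, the FCFunctor path above computes Y = X * W (+ B) through padded buffers whenever N and K are both multiples of 128: X and W gain 4 extra columns (leading dimensions KK = K + 4 and NN = N + 4), the GEMM runs over those strides, and the M x N result is copied back out of the padded rows. Nudging the leading dimensions off a large power-of-two multiple is a common way to sidestep cache-aliasing slowdowns in GEMM. A standalone sketch of the whole flow, with a naive triple loop standing in for blas.GEMM and raw pointers for the tensors (both assumptions made only for illustration):

#include <cstring>
#include <vector>

// Padded FC: copy X (M x K) and W (K x N) into buffers with 4 extra
// columns, multiply using the padded leading dimensions KK and NN, then
// copy the M x N result back out of the NN-strided output rows.
void PaddedFC(const float* X, const float* W, float* Y, int M, int N, int K) {
  const int KK = K + 4;
  const int NN = N + 4;
  std::vector<float> X1(M * KK, 0.0f);
  std::vector<float> W1(K * NN, 0.0f);
  std::vector<float> Y1(M * NN, 0.0f);
  for (int i = 0; i < M; i++) {
    std::memcpy(X1.data() + i * KK, X + i * K, K * sizeof(float));
  }
  for (int i = 0; i < K; i++) {
    std::memcpy(W1.data() + i * NN, W + i * N, N * sizeof(float));
  }
  // Naive GEMM over the logical M x N x K shapes, reading and writing
  // through the padded strides (blas.GEMM plays this role in the diff).
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      float acc = 0.0f;
      for (int k = 0; k < K; k++) {
        acc += X1[i * KK + k] * W1[k * NN + j];
      }
      Y1[i * NN + j] = acc;
    }
  }
  // Unpad: each output row of length N sits at stride NN in Y1.
  for (int i = 0; i < M; i++) {
    std::memcpy(Y + i * N, Y1.data() + i * NN, N * sizeof(float));
  }
}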