From 4e67fe6a122636bc84b2f8df6d5f94feb5ed1a78 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Fri, 16 Nov 2018 10:09:40 +0000
Subject: [PATCH] Refine act and vxx JIT code to handle all sizes

---
 paddle/fluid/operators/math/jit_code.cc | 147 ++++++++++--------------
 1 file changed, 60 insertions(+), 87 deletions(-)

diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index 9efd4e81748..a5eef019c89 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -60,60 +60,53 @@ void VXXJitCode::generate() {
     offset += sizeof(float) * YMM_FLOAT_BLOCK;
   }
   int rest = num_ % YMM_FLOAT_BLOCK;
-  if (rest >= 4) {
-    if (scalar_index_ != 1) {
-      vmovups(xmm_src1, ptr[param1 + offset]);
-    }
-    if (scalar_index_ != 2) {
-      vmovups(xmm_src2, ptr[param2 + offset]);
-    }
-    if (type_ == operand_type::mul) {
-      vmulps(xmm_dst, xmm_src1, xmm_src2);
-    } else if (type_ == operand_type::add) {
-      vaddps(xmm_dst, xmm_src1, xmm_src2);
-    }
-    if (with_relu_) {
-      vmaxps(xmm_dst, xmm_zero, xmm_dst);
-    }
-    vmovups(ptr[param3 + offset], xmm_dst);
-    offset += sizeof(float) * 4;
-    rest -= 4;
-  }
-  if (rest >= 2) {
-    if (scalar_index_ != 1) {
-      vmovq(xmm_src1, ptr[param1 + offset]);
-    }
-    if (scalar_index_ != 2) {
-      vmovq(xmm_src2, ptr[param2 + offset]);
+  int block = XMM_FLOAT_BLOCK;
+  while (rest > 0) {
+    if (rest >= 4) {
+      if (scalar_index_ != 1) {
+        vmovups(xmm_src1, ptr[param1 + offset]);
+      }
+      if (scalar_index_ != 2) {
+        vmovups(xmm_src2, ptr[param2 + offset]);
+      }
+    } else if (rest >= 2) {
+      if (scalar_index_ != 1) {
+        vmovq(xmm_src1, ptr[param1 + offset]);
+      }
+      if (scalar_index_ != 2) {
+        vmovq(xmm_src2, ptr[param2 + offset]);
+      }
+    } else {
+      if (scalar_index_ != 1) {
+        vmovss(xmm_src1, ptr[param1 + offset]);
+      }
+      if (scalar_index_ != 2) {
+        vmovss(xmm_src2, ptr[param2 + offset]);
+      }
     }
-    if (type_ == operand_type::mul) {
-      vmulps(xmm_dst, xmm_src1, xmm_src2);
-    } else if (type_ == operand_type::add) {
-      vaddps(xmm_dst, xmm_src1, xmm_src2);
+    switch (type_) {
+      case operand_type::mul:
+        vmulps(xmm_dst, xmm_src1, xmm_src2);
+        break;
+      case operand_type::add:
+        vaddps(xmm_dst, xmm_src1, xmm_src2);
+        break;
+      default:
+        break;
     }
     if (with_relu_) {
       vmaxps(xmm_dst, xmm_zero, xmm_dst);
     }
-    vmovq(ptr[param3 + offset], xmm_dst);
-    offset += sizeof(float) * 2;
-    rest -= 2;
-  }
-  if (rest > 0) {
-    if (scalar_index_ != 1) {
-      vmovss(xmm_src1, ptr[param1 + offset]);
-    }
-    if (scalar_index_ != 2) {
-      vmovss(xmm_src2, ptr[param2 + offset]);
-    }
-    if (type_ == operand_type::mul) {
-      vmulss(xmm_dst, xmm_src1, xmm_src2);
-    } else if (type_ == operand_type::add) {
-      vaddss(xmm_dst, xmm_src1, xmm_src2);
+    if (rest >= 4) {
+      vmovups(ptr[param3 + offset], xmm_dst);
+    } else if (rest >= 2) {
+      vmovq(ptr[param3 + offset], xmm_dst);
+    } else {
+      vmovss(ptr[param3 + offset], xmm_dst);
     }
-    if (with_relu_) {
-      vmaxps(xmm_dst, xmm_zero, xmm_dst);
-    }
-    vmovss(ptr[param3 + offset], xmm_dst);
+    offset += sizeof(float) * block;
+    rest -= block;
+    block /= 2;
   }
   ret();
 }
@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
 
 bool VActJitCode::init(int d, operand_type type) {
   bool ok = MayIUse(avx);
-  if (type == operand_type::relu) {
+  if (type == operand_type::relu || type == operand_type::exp) {
+    // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
     return ok;
-  } else if (type == operand_type::exp) {
-    // exp is slower than mkl when d >= 256
-    return ok;  //&& d % 4 == 0 && d < 256;
   } else {
     // TODO(TJ): support more
     return ok && d % 8 == 0;
@@ -412,24 +403,15 @@ void VActJitCode::generate() {
     return;
   }
   int rest = num_ % YMM_FLOAT_BLOCK;
-  if (rest >= 4) {
-    vmovups(xmm_src, ptr[param1 + offset]);
-    switch (type_) {
-      case operand_type::relu:
-        relu_xmm(xmm_dst, xmm_src, xmm_zero);
-        break;
-      case operand_type::exp:
-        exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
-        break;
-      default:
-        break;
+  int block = XMM_FLOAT_BLOCK;
+  while (rest > 0) {
+    if (rest >= 4) {
+      vmovups(xmm_src, ptr[param1 + offset]);
+    } else if (rest >= 2) {
+      vmovq(xmm_src, ptr[param1 + offset]);
+    } else {
+      vmovss(xmm_src, ptr[param1 + offset]);
     }
-    vmovups(ptr[param2 + offset], xmm_dst);
-    offset += sizeof(float) * 4;
-    rest -= 4;
-  }
-  if (rest >= 2) {
-    vmovq(xmm_src, ptr[param1 + offset]);
     switch (type_) {
       case operand_type::relu:
         relu_xmm(xmm_dst, xmm_src, xmm_zero);
@@ -440,25 +422,16 @@ void VActJitCode::generate() {
       default:
         break;
     }
-    vmovq(ptr[param2 + offset], xmm_dst);
-    offset += sizeof(float) * 2;
-    rest -= 2;
-  }
-  if (rest > 0) {
-    // vmovups();
-    vmovss(xmm_src, ptr[param1 + offset]);
-
-    switch (type_) {
-      case operand_type::relu:
-        relu_xmm(xmm_dst, xmm_src, xmm_zero);
-        break;
-      case operand_type::exp:
-        exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
-        break;
-      default:
-        break;
+    if (rest >= 4) {
+      vmovups(ptr[param2 + offset], xmm_dst);
+    } else if (rest >= 2) {
+      vmovq(ptr[param2 + offset], xmm_dst);
+    } else {
+      vmovss(ptr[param2 + offset], xmm_dst);
     }
-    vmovss(ptr[param2 + offset], xmm_dst);
+    offset += sizeof(float) * block;
+    rest -= block;
+    block /= 2;
   }
   ret();
 }
-- 
GitLab