提交 a19b3225 编写于 作者: T tensor-tang

Fix jitcode remainder (tail) handling for small sizes

test=develop
上级 4dbdfa60
...@@ -59,9 +59,10 @@ void VXXJitCode::generate() { ...@@ -59,9 +59,10 @@ void VXXJitCode::generate() {
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
int block = XMM_FLOAT_BLOCK;
while (rest > 0) { while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
block = 4;
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
} }
...@@ -69,6 +70,7 @@ void VXXJitCode::generate() { ...@@ -69,6 +70,7 @@ void VXXJitCode::generate() {
vmovups(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src2, ptr[param2 + offset]);
} }
} else if (rest >= 2) { } else if (rest >= 2) {
block = 2;
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src1, ptr[param1 + offset]);
} }
...@@ -76,6 +78,7 @@ void VXXJitCode::generate() { ...@@ -76,6 +78,7 @@ void VXXJitCode::generate() {
vmovq(xmm_src2, ptr[param2 + offset]); vmovq(xmm_src2, ptr[param2 + offset]);
} }
} else { } else {
block = 1;
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src1, ptr[param1 + offset]);
} }
...@@ -105,7 +108,6 @@ void VXXJitCode::generate() { ...@@ -105,7 +108,6 @@ void VXXJitCode::generate() {
} }
offset += sizeof(float) * block; offset += sizeof(float) * block;
rest -= block; rest -= block;
block /= 2;
} }
ret(); ret();
} }
...@@ -167,13 +169,16 @@ void VActJitCode::generate() { ...@@ -167,13 +169,16 @@ void VActJitCode::generate() {
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
int block = XMM_FLOAT_BLOCK;
while (rest > 0) { while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
block = 4;
vmovups(xmm_src, ptr[param1 + offset]); vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) { } else if (rest >= 2) {
block = 2;
vmovq(xmm_src, ptr[param1 + offset]); vmovq(xmm_src, ptr[param1 + offset]);
} else { } else {
block = 1;
vmovss(xmm_src, ptr[param1 + offset]); vmovss(xmm_src, ptr[param1 + offset]);
} }
switch (type_) { switch (type_) {
...@@ -201,7 +206,6 @@ void VActJitCode::generate() { ...@@ -201,7 +206,6 @@ void VActJitCode::generate() {
} }
offset += sizeof(float) * block; offset += sizeof(float) * block;
rest -= block; rest -= block;
block /= 2;
} }
ret(); ret();
} }
......
...@@ -69,7 +69,7 @@ void vrelu_intri8(const int n, const float* x, float* y) { ...@@ -69,7 +69,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
TEST(JitKernel, vrelu) { TEST(JitKernel, vrelu) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256, 512}) { for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -10.f, 1.f); RandomVec<float>(d, x.data(), -10.f, 1.f);
...@@ -159,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) { ...@@ -159,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) { TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 12, 15, 16, 20, 30, 128, 256}) { for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f); RandomVec<float>(d, x.data(), -2.f, 2.f);
...@@ -234,7 +234,7 @@ void vsigmoid_better( ...@@ -234,7 +234,7 @@ void vsigmoid_better(
TEST(JitKernel, vsigmoid) { TEST(JitKernel, vsigmoid) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f); RandomVec<float>(d, x.data(), -2.f, 2.f);
...@@ -298,7 +298,7 @@ void vtanh_better( ...@@ -298,7 +298,7 @@ void vtanh_better(
TEST(JitKernel, vtanh) { TEST(JitKernel, vtanh) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d); std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d); std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f); RandomVec<float>(d, x.data(), -2.f, 2.f);
...@@ -389,7 +389,7 @@ void lstm_ctht_better( ...@@ -389,7 +389,7 @@ void lstm_ctht_better(
TEST(JitKernel, lstm) { TEST(JitKernel, lstm) {
namespace jit = paddle::operators::math::jitkernel; namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) { for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
int d4 = d * 4; int d4 = d * 4;
int d3 = d * 3; int d3 = d * 3;
std::vector<float> x(d4), xref(d4); std::vector<float> x(d4), xref(d4);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册