diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index cef21348e43c4342c92bd17f4c70ff106e35d99c..7d38d511723ab3e6edfd4aa853bd7f2521ec98e2 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -65,8 +65,9 @@ class VMulKernelImpl : public VMulKernel { explicit VMulKernelImpl(int d) : VMulKernel() { if (useJIT(d)) { - constexpr size_t sz = 256 * 1024; // TODO(TJ): should be related with d - jitcode_.reset(new gen::VMulJitCode(d, sz)); + // roughly estimate the size of code + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 593209d42b5ba102738e5419f2cf77c5ebbcd0da..667a95fe1a247cf9d3d63dae74f7e0fa9c2309ca 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -800,7 +800,7 @@ TEST(JitKernel, pool) { EXPECT_TRUE(std::dynamic_pointer_cast(pvmul_f) != std::dynamic_pointer_cast(pvmul_d)); - const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfany"); + const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfjit4"); EXPECT_EQ(pvmul_f, pvmul_from_key); const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit"); EXPECT_TRUE(pvmul_from_key2 == nullptr);