提交 f65ddff8 编写于 作者: T tensor-tang

unify act jitcode of relu, exp, sigmoid and tanh

上级 6a159071
...@@ -118,40 +118,6 @@ void VXXJitCode::generate() { ...@@ -118,40 +118,6 @@ void VXXJitCode::generate() {
ret(); ret();
} }
bool ReluJitCode::init(int d) { return MayIUse(avx); }
void ReluJitCode::generate() {
int offset = 0;
vxorps(ymm_zero, ymm_zero, ymm_zero);
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
vmaxps(ymm_dst, ymm_zero, ymm_src);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovq(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovss(ptr[param2 + offset], xmm_dst);
}
ret();
}
#define ALIGN32 __attribute__((aligned(32))) #define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f #define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f #define EXP_LOW -88.3762626647949f
...@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = { ...@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0}; static int g_tmp_mem[16] ALIGN32 = {0};
bool VExpJitCode::init(int d) { bool VActJitCode::init(int d, operand_type type) {
return MayIUse(avx) && d == 8; // only 8 yet bool ok = MayIUse(avx);
if (type == operand_type::relu) {
return ok;
} else {
return ok && d == 8; // only 8 yet
}
} }
void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
// use reg rax and ymm 2~5 vmaxps(ymm_dst, ymm_zero, ymm_src);
reg64_t reg_ptr_global = rax; }
ymm_t ymm_fx = ymm_t(2);
ymm_t ymm_fy = ymm_t(3); void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
ymm_t ymm_mask = ymm_t(4); int fy_idx, int mask_idx, int tmp_idx) {
ymm_t ymm_tmp = ymm_t(5);
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal
ymm_t ymm_fx = ymm_t(fx_idx);
ymm_t ymm_fy = ymm_t(fy_idx);
ymm_t ymm_mask = ymm_t(mask_idx);
ymm_t ymm_tmp = ymm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global); push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
...@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { ...@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
void VExpJitCode::generate() { void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int offset = 0; int fy_idx, int mask_idx, int tmp_idx) {
vmovups(ymm_src, ptr[param1 + offset]); // y = 1 / (1 + e^-x)
exp_ymm(ymm_src, ymm_dst); ymm_t ymm_tmp = ymm_t(tmp_idx);
vmovups(ptr[param2 + offset], ymm_dst);
ret();
}
bool VSigmoidJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
// use ymm2
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
ymm_t ymm_tmp = ymm_t(2);
push(reg_ptr_global); push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
...@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { ...@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
vmaxps(ymm_src, ymm_src, ymm_tmp); vmaxps(ymm_src, ymm_src, ymm_tmp);
vxorps(ymm_tmp, ymm_tmp, ymm_tmp); vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
vsubps(ymm_src, ymm_tmp, ymm_src); vsubps(ymm_src, ymm_tmp, ymm_src);
exp_ymm(ymm_src, ymm_dst); exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp); vaddps(ymm_dst, ymm_dst, ymm_tmp);
vdivps(ymm_dst, ymm_tmp, ymm_dst); vdivps(ymm_dst, ymm_tmp, ymm_dst);
pop(reg_ptr_global); pop(reg_ptr_global);
} }
void VSigmoidJitCode::generate() { void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int offset = 0; int fy_idx, int mask_idx, int tmp_idx) {
vmovups(ymm_src, ptr[param1 + offset]);
sigmoid_ymm(ymm_src, ymm_dst);
vmovups(ptr[param2 + offset], ymm_dst);
ret();
}
bool VTanhJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
// y = 2 / (1 + e^(-2x)) - 1 // y = 2 / (1 + e^(-2x)) - 1
// use ymm2, ymm3 ymm_t ymm_tmp = ymm_t(tmp_idx);
ymm_t ymm_zero = ymm_t(mask_idx);
reg64_t reg_ptr_global = rax; reg64_t reg_ptr_global = rax;
ymm_t ymm_tmp = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
push(reg_ptr_global); push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(ymm_zero, ymm_zero, ymm_zero); vxorps(ymm_zero, ymm_zero, ymm_zero);
vsubps(ymm_tmp, ymm_zero, ymm_tmp); vsubps(ymm_tmp, ymm_zero, ymm_tmp);
vmulps(ymm_src, ymm_src, ymm_tmp); vmulps(ymm_src, ymm_src, ymm_tmp);
exp_ymm(ymm_src, ymm_dst); exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp); vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
...@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { ...@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
void VTanhJitCode::generate() { void VActJitCode::generate() {
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_zero = ymm_t(2);
if (type_ == operand_type::relu) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
int offset = 0; int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]); for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vtanh_ymm(ymm_src, ymm_dst); vmovups(ymm_src, ptr[param1 + offset]);
vmovups(ptr[param2 + offset], ymm_dst); switch (type_) {
case operand_type::relu:
relu_ymm(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::identity:
break;
default:
break;
}
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
if (type_ != operand_type::relu) {
// TODO(TJ): remove me
ret();
return;
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovq(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src);
vmovss(ptr[param2 + offset], xmm_dst);
}
ret(); ret();
} }
......
...@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm; ...@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm; using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label; using Label = Xbyak::Label;
typedef enum { mul = 0, add } operand_type; typedef enum {
mul = 0,
add,
sub,
relu,
exp,
sigmoid,
tanh,
identity
} operand_type;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode { class VXXJitCode : public JitCode {
...@@ -85,87 +94,65 @@ class VXXJitCode : public JitCode { ...@@ -85,87 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3); ymm_t ymm_zero = ymm_t(3);
}; };
class ReluJitCode : public JitCode { class VActJitCode : public JitCode {
public: public:
DECLARE_JIT_CODE(ReluJitCode); const char* name() const override {
explicit ReluJitCode(int d, size_t code_size = 256 * 1024, std::string base = "VActJitCode";
void* code_ptr = nullptr) switch (type_) {
: JitCode(code_size, code_ptr), num_(d) {} case operand_type::relu:
static bool init(int d); base += "_Relu";
void generate() override; break;
case operand_type::exp:
private: base += "_Exp";
int num_; break;
reg64_t param1{abi_param1}; case operand_type::sigmoid:
reg64_t param2{abi_param2}; base += "_Sigmoid";
break;
xmm_t xmm_zero = xmm_t(0); case operand_type::tanh:
xmm_t xmm_src = xmm_t(1); base += "_Tanh";
xmm_t xmm_dst = xmm_t(1); break;
case operand_type::identity:
ymm_t ymm_zero = ymm_t(0); base += "_Identity";
ymm_t ymm_src = ymm_t(1); break;
ymm_t ymm_dst = ymm_t(1); default:
}; break;
}
return base.c_str();
}
class VExpJitCode : public JitCode { explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
public:
DECLARE_JIT_CODE(VExpJitCode);
explicit VExpJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {} : JitCode(code_size, code_ptr), num_(d), type_(type) {}
static bool init(int d); static bool init(int d, operand_type type);
void generate() override; void generate() override;
protected: protected:
// compute exp with ymm // compute relu with ymm
void exp_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
const Xbyak::Ymm& zero);
private: // compute exp with ymm
int num_; void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
reg64_t param1{abi_param1}; int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
reg64_t param2{abi_param2};
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
};
class VSigmoidJitCode : public VExpJitCode {
public:
DECLARE_JIT_CODE(VSigmoidJitCode);
explicit VSigmoidJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
// compute sigmoid with ymm // compute sigmoid with ymm
void sigmoid_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
private: // compute tanh with ymm
int num_; void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
reg64_t param1{abi_param1}; int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
reg64_t param2{abi_param2};
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
};
class VTanhJitCode : public VExpJitCode { protected:
public:
DECLARE_JIT_CODE(VTanhJitCode);
explicit VTanhJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
// compute sigmoid with ymm
void vtanh_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
private:
int num_; int num_;
operand_type type_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
xmm_t xmm_src = xmm_t(0);
ymm_t ymm_src = ymm_t(0); ymm_t ymm_src = ymm_t(0);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_dst = ymm_t(1); ymm_t ymm_dst = ymm_t(1);
}; };
......
...@@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel<T> {
size_t sz = 96 /* init size */ + size_t sz = 96 /* init size */ +
d / AVX_FLOAT_BLOCK * 4 /* instructions */ * d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
8 /* average bytes for each instruction */; 8 /* average bytes for each instruction */;
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return; return;
} }
...@@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VReluKernelImpl<float>::useJIT(int d) { bool VReluKernelImpl<float>::useJIT(int d) {
return gen::ReluJitCode::init(d); return gen::VActJitCode::init(d, gen::operand_type::relu);
} }
#endif #endif
......
...@@ -116,7 +116,8 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -116,7 +116,8 @@ class VExpKernelImpl : public VExpKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
jitcode_.reset(new gen::VExpJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return; return;
} }
...@@ -135,14 +136,14 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -135,14 +136,14 @@ class VExpKernelImpl : public VExpKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VExpJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VExpKernelImpl<float>::useJIT(int d) { bool VExpKernelImpl<float>::useJIT(int d) {
return gen::VExpJitCode::init(d); return gen::VActJitCode::init(d, gen::operand_type::exp);
} }
#endif #endif
...@@ -169,7 +170,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -169,7 +170,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
jitcode_.reset(new gen::VSigmoidJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return; return;
} }
...@@ -190,14 +192,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -190,14 +192,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VSigmoidJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VSigmoidKernelImpl<float>::useJIT(int d) { bool VSigmoidKernelImpl<float>::useJIT(int d) {
return gen::VSigmoidJitCode::init(d); return gen::VActJitCode::init(d, gen::operand_type::sigmoid);
} }
#endif #endif
...@@ -223,7 +225,8 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -223,7 +225,8 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
jitcode_.reset(new gen::VTanhJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return; return;
} }
...@@ -244,14 +247,14 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -244,14 +247,14 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VTanhJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VTanhKernelImpl<float>::useJIT(int d) { bool VTanhKernelImpl<float>::useJIT(int d) {
return gen::VTanhJitCode::init(d); return gen::VActJitCode::init(d, gen::operand_type::tanh);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册