未验证 提交 e8642c3c 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #14265 from tensor-tang/fea/jit/vadd

add vadd, vaddrelu jitcode
...@@ -36,7 +36,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -36,7 +36,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
.template Get<jitkernel::VAddReluKernel<T>>(N); .template Get<jitkernel::VAddReluKernel<T>>(N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
vaddrelu->Compute(B, dst, dst); vaddrelu->Compute(B, dst, dst, N);
} }
} else { } else {
const auto& vadd = jitkernel::KernelPool::Instance() const auto& vadd = jitkernel::KernelPool::Instance()
...@@ -47,7 +47,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -47,7 +47,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
#endif #endif
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
vadd->Compute(B, dst, dst); vadd->Compute(B, dst, dst, N);
} }
} }
} }
......
...@@ -24,19 +24,29 @@ namespace gen { ...@@ -24,19 +24,29 @@ namespace gen {
using namespace platform::jit; // NOLINT using namespace platform::jit; // NOLINT
bool VMulJitCode::init(int d) { bool VVVJitCode::init(int d) {
// It's not necessary to use avx512 since it would slow down the frequency // It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound. // and this kernel is not compute bound.
return MayIUse(avx); return MayIUse(avx);
} }
void VMulJitCode::generate() { void VVVJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512 // do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0; int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]); vmovups(ymm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2); vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst); vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * AVX_FLOAT_BLOCK;
} }
...@@ -44,7 +54,14 @@ void VMulJitCode::generate() { ...@@ -44,7 +54,14 @@ void VMulJitCode::generate() {
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovups(ptr[param3 + offset], xmm_dst); vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4; offset += sizeof(float) * 4;
rest -= 4; rest -= 4;
...@@ -52,7 +69,14 @@ void VMulJitCode::generate() { ...@@ -52,7 +69,14 @@ void VMulJitCode::generate() {
if (rest >= 2) { if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]); vmovq(xmm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovq(ptr[param3 + offset], xmm_dst); vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2; offset += sizeof(float) * 2;
rest -= 2; rest -= 2;
...@@ -60,12 +84,18 @@ void VMulJitCode::generate() { ...@@ -60,12 +84,18 @@ void VMulJitCode::generate() {
if (rest > 0) { if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]); vmovss(xmm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2); vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddss(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovss(ptr[param3 + offset], xmm_dst); vmovss(ptr[param3 + offset], xmm_dst);
} }
ret(); ret();
} }
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h" #include "paddle/fluid/operators/math/jit_gen.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -29,28 +29,47 @@ using ymm_t = const Xbyak::Ymm; ...@@ -29,28 +29,47 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm; using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label; using Label = Xbyak::Label;
class VMulJitCode : public JitCode { // function: vec = Operand(vec, vec) (maybe with relu)
typedef enum { mul = 0, add } operand_type;
class VVVJitCode : public JitCode {
public: public:
DECLARE_JIT_CODE(VMulJitCode); const char* name() const override {
explicit VMulJitCode(int d, size_t code_size = 256 * 1024, std::string base = "VVVJitCode";
void* code_ptr = nullptr) if (type_ == operand_type::mul) {
: JitCode(code_size, code_ptr), num_(d) {} base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
base += (with_relu_ ? "_relu" : "");
return base.c_str();
}
explicit VVVJitCode(int d, operand_type type, bool with_relu,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
with_relu_(with_relu) {}
static bool init(int d); static bool init(int d);
void generate() override; void generate() override;
private: private:
int num_; int num_;
operand_type type_;
bool with_relu_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
reg64_t param3{abi_param3}; reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1); xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2); xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1); ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2); ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
}; };
} // namespace gen } // namespace gen
......
...@@ -71,26 +71,26 @@ class VMulKernel : public Kernel { ...@@ -71,26 +71,26 @@ class VMulKernel : public Kernel {
template <typename T> template <typename T>
class VAddKernel : public Kernel { class VAddKernel : public Kernel {
public: public:
virtual void Compute(const T *x, const T *y, T *z) const = 0; void (*Compute)(const T *, const T *, T *, int);
}; };
template <typename T> template <typename T>
class VScalKernel : public Kernel { class VAddReluKernel : public Kernel {
public: public:
virtual void Compute(const T a, const T *x, T *y) const = 0; void (*Compute)(const T *, const T *, T *, int);
virtual void Compute(const T a, T *x) const = 0;
}; };
template <typename T> template <typename T>
class VAddBiasKernel : public Kernel { class VScalKernel : public Kernel {
public: public:
virtual void Compute(const T a, const T *x, T *y) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
}; };
template <typename T> template <typename T>
class VAddReluKernel : public Kernel { class VAddBiasKernel : public Kernel {
public: public:
virtual void Compute(const T *x, const T *y, T *z) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
......
...@@ -42,6 +42,21 @@ void VMulRefer(const T* x, const T* y, T* z, int n) { ...@@ -42,6 +42,21 @@ void VMulRefer(const T* x, const T* y, T* z, int n) {
} }
} }
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <typename T> template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n); void VMulMKL(const T* x, const T* y, T* z, int n);
...@@ -50,28 +65,45 @@ template <> ...@@ -50,28 +65,45 @@ template <>
void VMulMKL<float>(const float* x, const float* y, float* z, int n) { void VMulMKL<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z); platform::dynload::vsMul(n, x, y, z);
} }
template <> template <>
void VMulMKL<double>(const double* x, const double* y, double* z, int n) { void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z); platform::dynload::vdMul(n, x, y, z);
} }
template <typename T>
void VAddMKL(const T* x, const T* y, T* z, int n);
template <>
void VAddMKL<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
template <>
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
#endif #endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */ /* VMUL JitKernel */
template <typename T> template <typename T>
class VMulKernelImpl : public VMulKernel<T> { class VMulKernelImpl : public VMulKernel<T> {
public: public:
static inline std::string name(int d) { DECLARE_STATIC_FUNC;
PADDLE_THROW("DType should be either float or double");
}
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit VMulKernelImpl(int d) : VMulKernel<T>() { explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return; return;
...@@ -89,14 +121,14 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -89,14 +121,14 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VMulKernelImpl<float>::useJIT(int d) { bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VMulJitCode::init(d); return gen::VVVJitCode::init(d);
} }
#endif #endif
...@@ -112,63 +144,89 @@ bool VMulKernelImpl<double>::useMKL(int d) { ...@@ -112,63 +144,89 @@ bool VMulKernelImpl<double>::useMKL(int d) {
} }
#endif #endif
REGISTER_JITKERNEL(vmul, VMulKernel); /* VAdd JitKernel */
template <typename T>
/* VADD JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddKernelImpl : public VAddKernel<T> { class VAddKernelImpl : public VAddKernel<T> {
public: public:
explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; } DECLARE_STATIC_FUNC;
void Compute(const T* x, const T* y, T* z) const override { explicit VAddKernelImpl(int d) : VAddKernel<T>() {
for (int i = 0; i < this->num_; ++i) { #ifdef PADDLE_WITH_XBYAK
z[i] = x[i] + y[i]; if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VAddMKL<T>;
return;
} }
#endif
this->Compute = VAddRefer<T>;
} }
private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
}; };
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_XBYAK
#define MKL_FLOAT(isa, block) \ template <>
template <> \ bool VAddKernelImpl<float>::useJIT(int d) {
void VAddKernelImpl<float, isa, block>::Compute( \ return gen::VVVJitCode::init(d);
const float* x, const float* y, float* z) const { \ }
platform::dynload::vsAdd(this->num_, x, y, z); \ #endif
}
#define MKL_DOUBLE(isa, block) \ #ifdef PADDLE_WITH_MKLML
template <> \ template <>
void VAddKernelImpl<double, isa, block>::Compute( \ bool VAddKernelImpl<float>::useMKL(int d) {
const double* x, const double* y, double* z) const { \ return d > 512;
platform::dynload::vdAdd(this->num_, x, y, z); \ }
}
FOR_EACH_ISA(MKL_FLOAT, kGT16); template <>
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); bool VAddKernelImpl<double>::useMKL(int d) {
return true;
}
#endif #endif
#define INTRI8_FLOAT(isa) \ /* VAddRelu JitKernel */
template <> \ template <typename T>
void VAddKernelImpl<float, isa, kEQ8>::Compute( \ class VAddReluKernelImpl : public VAddReluKernel<T> {
const float* x, const float* y, float* z) const { \ public:
__m256 tmpx, tmpy; \ DECLARE_STATIC_FUNC;
tmpx = _mm256_loadu_ps(x); \ explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
tmpy = _mm256_loadu_ps(y); \ #ifdef PADDLE_WITH_XBYAK
tmpx = _mm256_add_ps(tmpx, tmpy); \ if (useJIT(d)) {
_mm256_storeu_ps(z, tmpx); \ size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
} }
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif #endif
#ifdef __AVX2__ this->Compute = VAddReluRefer<T>;
INTRI8_FLOAT(jit::avx2); }
#endif
#ifdef __AVX512F__ private:
INTRI8_FLOAT(jit::avx512f); std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddReluKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d);
}
#endif #endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT #undef DECLARE_STATIC_FUNC
#undef MKL_FLOAT
#undef MKL_DOUBLE REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
/* VSCAL JitKernel */ /* VSCAL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::jit::cpu_isa_t isa, jit_block>
...@@ -405,98 +463,9 @@ class VIdentityKernelImpl : public VIdentityKernel<T> { ...@@ -405,98 +463,9 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {} void Compute(const T* x, T* y) const override {}
}; };
/* VAddRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, const T* y, T* z) const override {
for (int i = 0; i < this->num_; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx = _mm256_loadu_ps(x); \
__m256 tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \
_mm256_storeu_ps(z, tmpy); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ16>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(y); \
tmp0 = _mm256_add_ps(tmp0, tmp1); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp2 = _mm256_loadu_ps(y + 8); \
tmp1 = _mm256_add_ps(tmp1, tmp2); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(z, tmp0); \
_mm256_storeu_ps(z + 8, tmp1); \
}
#define INTRI_COMMON_FLOAT(isa, block) \
template <> \
VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
: VAddReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
} \
template <> \
void VAddReluKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmpx = _mm256_loadu_ps(x + i); \
__m256 tmpy = _mm256_loadu_ps(y + i); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, zeros); \
_mm256_storeu_ps(z + i, tmpy); \
} \
for (int i = this->end_; i < this->num_; ++i) { \
z[i] = x[i] + y[i]; \
z[i] = z[i] > 0 ? z[i] : 0; \
} \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_COMMON_FLOAT(jit::avx, kGT16);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_COMMON_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL_DEPRECATED(vadd, VAddKernel);
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
} // namespace jitkernel } // namespace jitkernel
......
...@@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cand_d_->Compute(gates, gates); act_cand_d_->Compute(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_);
...@@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
/* get fgated and igated*/ /* get fgated and igated*/
vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_); act_gate_d2_->Compute(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/ /* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates); act_cand_d_->Compute(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/ /* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_);
...@@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */ /* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_);
......
...@@ -371,7 +371,7 @@ void lstm_ctht_better( ...@@ -371,7 +371,7 @@ void lstm_ctht_better(
vtanh_d->Compute(gates, gates); vtanh_d->Compute(gates, gates);
vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct); vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
vtanh_d->Compute(ct, gates + d2); vtanh_d->Compute(ct, gates + d2);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d); vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
...@@ -695,7 +695,7 @@ TEST(JitKernel, vadd) { ...@@ -695,7 +695,7 @@ TEST(JitKernel, vadd) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data); ker->Compute(x_data, y_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -723,8 +723,8 @@ void vaddrelu_better( ...@@ -723,8 +723,8 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd, const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu, const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z) { const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z); vadd->Compute(x, y, z, d);
vrelu->Compute(z, z); vrelu->Compute(z, z);
} }
...@@ -752,12 +752,12 @@ TEST(JitKernel, vaddrelu) { ...@@ -752,12 +752,12 @@ TEST(JitKernel, vaddrelu) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS(); auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data); vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d);
} }
auto tmkle = GetCurrentUS(); auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data); ker->Compute(x_data, y_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册