提交 f2adaf1c 编写于 作者: T tensor-tang

add vrelu and lstm kernel

test=develop
上级 e6d8aca3
...@@ -35,23 +35,6 @@ std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const { ...@@ -35,23 +35,6 @@ std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
return kers_.at(key); return kers_.at(key);
} }
template <>
std::shared_ptr<const LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
const std::string&>(int d, const std::string& act_gate,
const std::string& act_cand,
const std::string& act_cell) {
std::string key =
"lstmf" + std::to_string(d) + act_gate + act_cand + act_cell;
if (kers_.find(key) == kers_.end()) {
auto p =
std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand, act_cell);
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
return p;
}
return std::dynamic_pointer_cast<const LSTMKernel<float>>(kers_.at(key));
}
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -87,36 +87,45 @@ class VAddBiasKernel : public Kernel { ...@@ -87,36 +87,45 @@ class VAddBiasKernel : public Kernel {
}; };
template <typename T> template <typename T>
class VExpKernel : public Kernel { class VActKernel : public Kernel {
public: public:
virtual void Compute(const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
class VSigmoidKernel : public Kernel { class VReluKernel : public VActKernel<T> {
public: public:
virtual void Compute(const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
class VTanhKernel : public Kernel { class VIdentityKernel : public VActKernel<T> {
public: public:
virtual void Compute(const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class VExpKernel : public VActKernel<T> {
public: public:
explicit LSTMKernel(int d, const std::string &act_gate, virtual void Compute(const T *x, T *y) const = 0;
const std::string &act_cand, const std::string &act_cell); };
void (*jit_ker)(T *, const T *, T *, T *); template <typename T>
std::function<void(T *, const T *, T *, T *)> ComputeCtHt, ComputeCtHt_NoC0H0; class VSigmoidKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
};
private: template <typename T>
int d_, d2_, d3_; class VTanhKernel : public VActKernel<T> {
std::function<void(const int, const T *, T *)> act_gate_, act_cell_, public:
act_cand_; virtual void Compute(const T *x, T *y) const = 0;
};
template <typename T>
class LSTMKernel : public Kernel {
public:
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht) const = 0;
}; };
} // namespace jitkernel } // namespace jitkernel
......
...@@ -266,15 +266,124 @@ INTRI16_FLOAT(jit::avx512f); ...@@ -266,15 +266,124 @@ INTRI16_FLOAT(jit::avx512f);
#endif #endif
// TODO(TJ): eq16 test and complete avx512 // TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
/* VRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VReluKernelImpl : public VReluKernel<T> {
public:
explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa) \
template <> \
VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + this->rest_, tmp1); \
}
#define INTRI_GT16_FLOAT(isa) \
template <> \
VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + i, tmp); \
} \
__m256 tmp = _mm256_loadu_ps(x + this->rest_); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + this->rest_, tmp); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_GT8LT16_FLOAT(jit::avx);
INTRI_GT16_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_GT8LT16_FLOAT(jit::avx2);
INTRI_GT16_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_GT8LT16_FLOAT(jit::avx512f);
INTRI_GT16_FLOAT(jit::avx512f);
#endif
#undef INTRI8_FLOAT #undef INTRI8_FLOAT
#undef INTRI16_FLOAT #undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT #undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT #undef INTRI_GT16_FLOAT
/* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
void Compute(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddb, VAddBiasKernel); REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
......
...@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <functional>
#include <string> #include <string>
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -24,51 +28,85 @@ namespace jitkernel { ...@@ -24,51 +28,85 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
template <> /* LSTM JitKernel */
LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str, template <typename T, jit::cpu_isa_t isa, jit_block>
const std::string& act_cand_str, class LSTMKernelImpl : public LSTMKernel<T> {
const std::string& act_cell_str) public:
: Kernel(), d_(d) { explicit LSTMKernelImpl(int d, const std::string& act_gate,
const std::string& act_cand,
const std::string& act_cell)
: LSTMKernel<T>() {
d_ = d;
d2_ = d * 2; d2_ = d * 2;
d3_ = d * 3; d3_ = d * 3;
if (platform::jit::MayIUse(platform::jit::avx512f)) { auto GetActKernel = [&](const std::string& type,
math::VecActivations<float, platform::jit::avx512f> act_functor; int n) -> std::shared_ptr<const VActKernel<T>> {
act_gate_ = act_functor(act_gate_str); if (type == "sigmoid") {
act_cell_ = act_functor(act_cell_str); return std::dynamic_pointer_cast<const VActKernel<T>>(
act_cand_ = act_functor(act_cand_str); KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
} else if (platform::jit::MayIUse(platform::jit::avx2)) { } else if (type == "relu") {
math::VecActivations<float, platform::jit::avx2> act_functor; return std::dynamic_pointer_cast<const VActKernel<T>>(
act_gate_ = act_functor(act_gate_str); KernelPool::Instance().template Get<VReluKernel<T>>(n));
act_cell_ = act_functor(act_cell_str); } else if (type == "tanh") {
act_cand_ = act_functor(act_cand_str); return std::dynamic_pointer_cast<const VActKernel<T>>(
} else if (platform::jit::MayIUse(platform::jit::avx)) { KernelPool::Instance().template Get<VTanhKernel<T>>(n));
math::VecActivations<float, platform::jit::avx> act_functor; } else if (type == "identity" || type == "") {
act_gate_ = act_functor(act_gate_str); return std::dynamic_pointer_cast<const VActKernel<T>>(
act_cell_ = act_functor(act_cell_str); KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
act_cand_ = act_functor(act_cand_str); }
// ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) { PADDLE_THROW("Not support type: %s", type);
// // gates: W_ch, W_ih, W_fh, W_oh };
// act_gate(d3_, gates + d_, gates + d_); act_gate_3d_ = GetActKernel(act_gate, d * 3);
act_cand_d_ = GetActKernel(act_cand, d);
// /* C_t = C_t-1 * fgated + cand_gated * igated */ act_cell_d_ = GetActKernel(act_cell, d);
// act_cand(d_, gates, gates); vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
// blas.VMUL(d_, gates, gates + d_, gates + d_); vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
// blas.VMUL(d_, ct_1, gates + d2_, gates + d2_); }
// blas.VADD(d_, gates + d_, gates + d2_, ct);
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
// /* H_t = act_cell(C_t) * ogated */ // gates: W_ch, W_ih, W_fh, W_oh
// act_cell(d_, ct, gates + d2_); act_gate_3d_->Compute(gates + d_, gates + d_);
// blas.VMUL(d_, gates + d2_, gates + d3_, ht)
// GET_Ct(ct_1, gates, ct); /* C_t = C_t-1 * fgated + cand_gated * igated */
// GET_Ht(ct, gates, ht); act_cand_d_->Compute(gates, gates);
// }; vmul_d_->Compute(gates, gates + d_, gates + d_);
} else { vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
math::VecActivations<float, platform::jit::isa_any> act_functor; vadd_d_->Compute(gates + d_, gates + d2_, ct);
act_gate_ = act_functor(act_gate_str);
act_cell_ = act_functor(act_cell_str); /* H_t = act_cell(C_t) * ogated */
act_cand_ = act_functor(act_cand_str); act_cell_d_->Compute(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
} }
}
private:
int d_, d2_, d3_;
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
std::shared_ptr<const VAddKernel<T>> vadd_d_;
};
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int, const std::string&, \
const std::string&, const std::string&>( \
int d, const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell)
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand, \
act_cell))
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <sys/time.h> #include <sys/time.h>
#include <cmath> // for exp
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -48,6 +49,59 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f), ...@@ -48,6 +49,59 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
} }
} }
void vrelu_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0.f ? x[i] : 0.f;
}
}
#if defined __AVX__ || defined __AVX2__
void vrelu_intri8(const int n, const float* x, float* y) {
__m256 tmp = _mm256_loadu_ps(x);
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps());
_mm256_storeu_ps(y, tmp);
}
#endif
TEST(JitKernel, vrelu) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -10.f, 1.f);
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
const float* x_data = x.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vrelu_ref(d, x_data, zref_data);
}
auto trefe = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vrelu_intri8(d, x_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
}
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat;
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
void vaddbias_ref(const int n, const float a, const float* x, float* y) { void vaddbias_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = x[i] + a; y[i] = x[i] + a;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册