提交 e2d6eddd 编写于 作者: T tensor-tang

remove ComputeDeprecated

test=develop
上级 f65ddff8
...@@ -178,6 +178,7 @@ bool VActJitCode::init(int d, operand_type type) { ...@@ -178,6 +178,7 @@ bool VActJitCode::init(int d, operand_type type) {
if (type == operand_type::relu) { if (type == operand_type::relu) {
return ok; return ok;
} else { } else {
// TODO(TJ): support more
return ok && d == 8; // only 8 yet return ok && d == 8; // only 8 yet
} }
} }
......
...@@ -98,42 +98,23 @@ class VAddBiasKernel : public Kernel { ...@@ -98,42 +98,23 @@ class VAddBiasKernel : public Kernel {
template <typename T> template <typename T>
class VActKernel : public Kernel { class VActKernel : public Kernel {
public: public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0; void (*Compute)(const T *, T *, int);
}; };
template <typename T> template <typename T>
class VReluKernel : public VActKernel<T> { class VReluKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T> template <typename T>
class VIdentityKernel : public VActKernel<T> { class VIdentityKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VExpKernel : public VActKernel<T> { class VExpKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T> template <typename T>
class VSigmoidKernel : public VActKernel<T> { class VSigmoidKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T> template <typename T>
class VTanhKernel : public VActKernel<T> { class VTanhKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
......
...@@ -346,7 +346,6 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -346,7 +346,6 @@ class VReluKernelImpl : public VReluKernel<T> {
public: public:
JITKERNEL_DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() { explicit VReluKernelImpl(int d) : VReluKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 /* init size */ + size_t sz = 96 /* init size */ +
...@@ -361,9 +360,6 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -361,9 +360,6 @@ class VReluKernelImpl : public VReluKernel<T> {
this->Compute = VReluRefer<T>; this->Compute = VReluRefer<T>;
} }
void ComputeDeprecated(const T* x, T* y) const override {
VReluRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
...@@ -378,22 +374,26 @@ bool VReluKernelImpl<float>::useJIT(int d) { ...@@ -378,22 +374,26 @@ bool VReluKernelImpl<float>::useJIT(int d) {
} }
#endif #endif
REGISTER_JITKERNEL(vmul, VMulKernel); template <typename T>
REGISTER_JITKERNEL(vadd, VAddKernel); inline void VIdentityRefer(const T* x, T* y, int n) {}
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
/* An empty JitKernel */ /* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> { class VIdentityKernelImpl : public VIdentityKernel<T> {
public: public:
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; } JITKERNEL_DECLARE_STATIC_FUNC;
void ComputeDeprecated(const T* x, T* y) const override {} explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = VIdentityRefer<T>;
}
}; };
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -36,6 +36,7 @@ namespace jitkernel { ...@@ -36,6 +36,7 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file // TODO(TJ): move refer codes to one file
// Refer code only focus on correctness
template <typename T> template <typename T>
void VExpRefer(const T* x, T* y, int n) { void VExpRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -67,6 +68,7 @@ void VTanhRefer(const T* x, T* y, int n) { ...@@ -67,6 +68,7 @@ void VTanhRefer(const T* x, T* y, int n) {
} }
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template <typename T> template <typename T>
void VExpMKL(const T* x, T* y, int n); void VExpMKL(const T* x, T* y, int n);
...@@ -112,7 +114,6 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -112,7 +114,6 @@ class VExpKernelImpl : public VExpKernel<T> {
public: public:
JITKERNEL_DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VExpKernelImpl(int d) : VExpKernel<T>() { explicit VExpKernelImpl(int d) : VExpKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
...@@ -130,9 +131,7 @@ class VExpKernelImpl : public VExpKernel<T> { ...@@ -130,9 +131,7 @@ class VExpKernelImpl : public VExpKernel<T> {
#endif #endif
this->Compute = VExpRefer<T>; this->Compute = VExpRefer<T>;
} }
void ComputeDeprecated(const T* x, T* y) const override {
VExpRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
...@@ -166,7 +165,6 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -166,7 +165,6 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public: public:
JITKERNEL_DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() { explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
...@@ -186,9 +184,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> { ...@@ -186,9 +184,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
#endif #endif
this->Compute = VSigmoidRefer<T>; this->Compute = VSigmoidRefer<T>;
} }
void ComputeDeprecated(const T* x, T* y) const override {
VSigmoidRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
...@@ -221,7 +217,6 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -221,7 +217,6 @@ class VTanhKernelImpl : public VTanhKernel<T> {
public: public:
JITKERNEL_DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() { explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
...@@ -241,9 +236,7 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -241,9 +236,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
#endif #endif
this->Compute = VTanhRefer<T>; this->Compute = VTanhRefer<T>;
} }
void ComputeDeprecated(const T* x, T* y) const override {
VTanhRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
......
...@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override { T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh // gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
/* C_t = C_t-1 * fgated + cand_gated * igated */ /* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/ /* C_t = igated * cgated*/
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
...@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/ /* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/ /* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/ /* C_t = igated * cgated*/
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */ /* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
...@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
} }
void ComputeH1(T* gates, T* ht) const override { void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->ComputeDeprecated(gates, gates); act_gate_d_->Compute(gates, gates, d_);
act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_); act_state_d_->Compute(gates + d2_, gates + d2_, d_);
vmul_d_->Compute(gates, gates + d2_, ht, d_); vmul_d_->Compute(gates, gates + d2_, ht, d_);
} }
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
act_gate_d2_->ComputeDeprecated(gates, gates); act_gate_d2_->Compute(gates, gates, d2_);
vmul_d_->Compute(ht_1, gates + d_, ht, d_); vmul_d_->Compute(ht_1, gates + d_, ht, d_);
} }
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_; T* y = gates + d2_;
act_state_d_->ComputeDeprecated(y, y); act_state_d_->Compute(y, y, d_);
// out = zt*ht~ + (1-zt)*ht_1 // out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) { for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i]; ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
......
...@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) { ...@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
// ker->ComputeDeprecated(x_data, ztgt_data); // ker->Compute(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d); ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -345,8 +345,8 @@ void lstm_ctht_ref( ...@@ -345,8 +345,8 @@ void lstm_ctht_ref(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1, const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) { const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->ComputeDeprecated(gates, gates); vtanh_d->Compute(gates, gates, d);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX; const float max = SIGMOID_THRESHOLD_MAX;
...@@ -356,7 +356,7 @@ void lstm_ctht_ref( ...@@ -356,7 +356,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated // H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2; float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->ComputeDeprecated(&tmp, &tmp); vexp_1->Compute(&tmp, &tmp, 1);
tmp = 2.f / (1.f + tmp) - 1.f; tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k]; ht[k] = tmp * o[k];
} }
...@@ -374,13 +374,13 @@ void lstm_ctht_better( ...@@ -374,13 +374,13 @@ void lstm_ctht_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d, const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
const int d, float* gates, const float* ct_1, float* ct, float* ht) { const int d, float* gates, const float* ct_1, float* ct, float* ht) {
int d2 = d * 2; int d2 = d * 2;
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->ComputeDeprecated(gates, gates); vtanh_d->Compute(gates, gates, d);
vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct, d); vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
vtanh_d->ComputeDeprecated(ct, gates + d2); vtanh_d->Compute(ct, gates + d2, d);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d); vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
} }
...@@ -737,7 +737,7 @@ void vaddrelu_better( ...@@ -737,7 +737,7 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu, const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z, int d) { const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d); vadd->Compute(x, y, z, d);
vrelu->ComputeDeprecated(z, z); vrelu->Compute(z, z, d);
} }
TEST(JitKernel, vaddrelu) { TEST(JitKernel, vaddrelu) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册