Commit ce31deb7 authored by: T tensor-tang

refine refer code and add lstm refer code

test=develop
Parent: c2cfb03a
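In short, this commit consolidates the per-file reference helpers (VMulRefer, VAddRefer, VExpRefer, ...) into a shared header, jit_kernel_refer.h, under a new refer namespace, and adds LSTM reference kernels (LSTMCtHt, LSTMC1H1) there. The fallback binding in each kernel implementation changes as sketched below (condensed from the hunks that follow):

// Before: each .cc file carried its own local reference helper.
this->Compute = VMulRefer<T>;
// After: every kernel falls back to the shared implementation in jit_kernel_refer.h.
this->Compute = refer::VMul<T>;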
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
@@ -31,49 +32,6 @@ namespace math {
namespace jitkernel {
namespace jit = platform::jit;
template <typename T>
void VMulRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
- VScalRefer<float>(a, x, y, n);
+ refer::VScal<float>(a, x, y, n);
}
}
@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
- VScalRefer<double>(a, x, y, n);
+ refer::VScal<double>(a, x, y, n);
}
}
@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
return;
}
#endif
- this->Compute = VMulRefer<T>;
+ this->Compute = refer::VMul<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
return;
}
#endif
- this->Compute = VAddRefer<T>;
+ this->Compute = refer::VAdd<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -242,7 +200,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
return;
}
#endif
- this->Compute = VAddReluRefer<T>;
+ this->Compute = refer::VAddRelu<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -280,7 +238,7 @@ class VScalKernelImpl : public VScalKernel<T> {
return;
}
#endif
- this->Compute = VScalRefer<T>;
+ this->Compute = refer::VScal<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -324,7 +282,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
}
#endif
- this->Compute = VAddBiasRefer<T>;
+ this->Compute = refer::VAddBias<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -358,7 +316,7 @@ class VReluKernelImpl : public VReluKernel<T> {
}
#endif
- this->Compute = VReluRefer<T>;
+ this->Compute = refer::VRelu<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -374,16 +332,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
}
#endif
template <typename T>
inline void VIdentityRefer(const T* x, T* y, int n) {}
/* An empty JitKernel */
template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
- this->Compute = VIdentityRefer<T>;
+ this->Compute = refer::VIdentity<T>;
}
};
......
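For context, kernels are fetched from the KernelPool and dispatched through Compute; the refer:: functions are the portable fallback bound when neither the xbyak JIT nor the MKL path is taken. A minimal usage sketch, assuming the float VMul kernel and the same calling convention as refer::VMul (pointers x_data, y_data, z_data of length d, as exercised by the test changes further below):

namespace jit = paddle::operators::math::jitkernel;
const auto& vmul =
    jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
vmul->Compute(x_data, y_data, z_data, d);  // runs the JIT/MKL code when available, refer::VMul otherwise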
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
@@ -35,38 +35,6 @@ namespace math {
namespace jitkernel {
namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file
// Refer code only focus on correctness
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidRefer(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template <typename T>
@@ -129,7 +97,7 @@ class VExpKernelImpl : public VExpKernel<T> {
return;
}
#endif
- this->Compute = VExpRefer<T>;
+ this->Compute = refer::VExp<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -182,7 +150,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
return;
}
#endif
- this->Compute = VSigmoidRefer<T>;
+ this->Compute = refer::VSigmoid<T>;
}
#ifdef PADDLE_WITH_XBYAK
@@ -234,7 +202,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
return;
}
#endif
- this->Compute = VTanhRefer<T>;
+ this->Compute = refer::VTanh<T>;
}
#ifdef PADDLE_WITH_XBYAK
......
@@ -38,9 +38,13 @@ typedef struct {
void* checked{nullptr};
} lstm_t;
- typedef struct {
+ typedef struct lstm_attr_s {
int d;
std::string act_gate, act_cand, act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell)
: d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {}
} lstm_attr_t;
} // namespace jitkernel
......
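The named struct and the added constructor let callers build the attribute object in a single expression instead of assigning fields one by one, while the default constructor keeps the old field-by-field usage working. A small sketch (the activation strings below are example values; the updated test passes its own act_gate/act_cand/act_cell variables):

jit::lstm_attr_t attr(d, "sigmoid", "tanh", "tanh");  // d, act_gate, act_cand, act_cell
jit::lstm_attr_t legacy;                              // still valid: fill fields manually
legacy.d = d;
legacy.act_gate = "sigmoid";
legacy.act_cand = "tanh";
legacy.act_cell = "tanh";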
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace refer {
/* Refer code focuses only on correctness */
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
template <typename T>
void VAddRelu(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
template <typename T>
void VScal(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBias(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
template <typename T>
void VRelu(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T>
inline void VIdentity(const T* x, T* y, int n) {}
template <typename T>
void VExp(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanh(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoid(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
if (type == "sigmoid") {
return VSigmoid<T>;
} else if (type == "relu") {
return VRelu<T>;
} else if (type == "tanh") {
return VTanh<T>;
} else if (type == "identity" || type == "") {
return VIdentity<T>;
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
template <typename T>
void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh, i.e. [cand | i | f | o], each block of length d
act_gate(gates + d, gates + d, d3);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
template <typename T>
void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
/* C_t = igated * cgated */
act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d);
/* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
} // namespace refer
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
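A minimal usage sketch of the new LSTM reference kernel, mirroring the test changes later in this commit (the buffer names and sizes are assumptions spelled out in the comments):

jit::lstm_t step;
jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell);  // e.g. "sigmoid", "tanh", "tanh"
step.gates = gates_data;  // 4 * d floats, laid out as [cand | i | f | o]
step.ct_1 = ct_1_data;    // d floats: previous cell state
step.ct = ct_data;        // d floats: cell state output
step.ht = ht_data;        // d floats: hidden state output
refer::LSTMCtHt<float>(&step, &attr);  // writes ct and ht; gates is used as scratch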
@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
@@ -53,12 +54,6 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
}
}
void vrelu_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0.f ? x[i] : 0.f;
}
}
#if defined __AVX__ || defined __AVX2__
void vrelu_intri8(const int n, const float* x, float* y) {
__m256 tmp = _mm256_loadu_ps(x);
@@ -69,6 +64,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
TEST(JitKernel, vrelu) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
@@ -80,7 +76,7 @@ TEST(JitKernel, vrelu) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vrelu_ref(d, x_data, zref_data);
+ refer::VRelu<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__
@@ -107,14 +103,9 @@ TEST(JitKernel, vrelu) {
}
}
void vaddbias_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] + a;
}
}
TEST(JitKernel, vaddbias) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
@@ -127,7 +118,7 @@ TEST(JitKernel, vaddbias) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vaddbias_ref(d, a, x_data, zref_data);
+ refer::VAddBias<float>(&a, x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
@@ -145,12 +136,6 @@ TEST(JitKernel, vaddbias) {
}
}
void vexp_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
#ifdef PADDLE_WITH_MKLML
void vexp_mkl(const int n, const float* x, float* y) {
paddle::platform::dynload::vsExp(n, x, y);
@@ -159,6 +144,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
@@ -170,7 +156,7 @@ TEST(JitKernel, vexp) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vexp_ref(d, x_data, zref_data);
+ refer::VExp<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
@@ -203,19 +189,6 @@ TEST(JitKernel, vexp) {
}
}
inline float _sigmoid(float x) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (x < min) ? min : ((x > max) ? max : x);
return 1.f / (1.f + std::exp(-tmp));
}
void vsigmoid_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = _sigmoid(x[i]);
}
}
void vsigmoid_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp,
@@ -234,6 +207,7 @@ void vsigmoid_better(
TEST(JitKernel, vsigmoid) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
@@ -252,7 +226,7 @@ TEST(JitKernel, vsigmoid) {
auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vsigmoid_ref(d, x_data, zref_data);
+ refer::VSigmoid<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
@@ -271,14 +245,6 @@ TEST(JitKernel, vsigmoid) {
}
}
inline float _tanh(float x) { return 2.f * _sigmoid(2.f * x) - 1.f; }
void vtanh_ref(const int n, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = _tanh(x[i]);
}
}
void vtanh_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal,
@@ -298,6 +264,7 @@ void vtanh_better(
TEST(JitKernel, vtanh) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
@@ -320,7 +287,7 @@ TEST(JitKernel, vtanh) {
auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vtanh_ref(d, x_data, zref_data);
+ refer::VTanh<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
@@ -339,32 +306,6 @@ TEST(JitKernel, vtanh) {
}
}
void lstm_ctht_ref(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
vsigmoid_3d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->Compute(gates, gates, d);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
for (int k = 0; k < d; ++k) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[k] = ct_1[k] * f[k] + gates[k] * i[k];
// H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->Compute(&tmp, &tmp, 1);
tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k];
}
}
void lstm_ctht_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
@@ -389,6 +330,7 @@ void lstm_ctht_better(
TEST(JitKernel, lstm) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
int d4 = d * 4;
int d3 = d * 3;
@@ -410,8 +352,6 @@ TEST(JitKernel, lstm) {
d3);
const auto& vtanh_d =
jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
const auto& vexp_1 =
jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(1);
const auto& vmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
const auto& vadd_d =
@@ -425,8 +365,14 @@ TEST(JitKernel, lstm) {
float* ct_ref_data = ct_ref.data();
float* ht_ref_data = ht_ref.data();
// compute once to check correctness
- lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data,
-               ct_ref_data, ht_ref_data);
+ jit::lstm_t step;
+ jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell);
step.gates = xref_data;
step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
refer::LSTMCtHt<float>(&step, &attr);
ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data);
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3);
@@ -441,8 +387,7 @@ TEST(JitKernel, lstm) {
auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data,
-               ct_ref_data, ht_ref_data);
+ refer::LSTMCtHt<float>(&step, &attr);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
@@ -457,16 +402,6 @@ TEST(JitKernel, lstm) {
}
}
void vscal_ref(const int n, const float a, const float* x, float* y) {
for (int i = 0; i < n; ++i) {
y[i] = a * x[i];
}
}
void vscal_inp_ref(const int n, const float a, float* x) {
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
}
#if defined __AVX__ || defined __AVX2__
void vscal_intri8(const int n, const float a, const float* x, float* y) {
__m256 tmp;
@@ -492,6 +427,7 @@ void vscal_inp_mkl(const int n, const float a, float* x) {
TEST(JitKernel, vscal) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
@@ -506,12 +442,12 @@ TEST(JitKernel, vscal) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vscal_ref(d, a, x_data, zref_data);
+ refer::VScal<float>(&a, x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto trefs1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vscal_inp_ref(d, a, y_data);
+ refer::VScal<float>(&a, y_data, y_data, d);
}
auto trefe1 = GetCurrentUS();
@@ -567,12 +503,6 @@ TEST(JitKernel, vscal) {
}
}
void vmul_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
}
#if defined __AVX__ || defined __AVX2__
void vmul_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy;
@@ -591,6 +521,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) {
TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
@@ -604,7 +535,7 @@ TEST(JitKernel, vmul) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vmul_ref(d, x_data, y_data, zref_data);
+ refer::VMul<float>(x_data, y_data, zref_data, d);
}
auto trefe = GetCurrentUS();
@@ -647,12 +578,6 @@ TEST(JitKernel, vmul) {
}
}
void vadd_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
#if defined __AVX__ || defined __AVX2__
void vadd_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy;
@@ -671,6 +596,7 @@ void vadd_mkl(const int n, const float* x, const float* y, float* z) {
TEST(JitKernel, vadd) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
@@ -684,7 +610,7 @@ TEST(JitKernel, vadd) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vadd_ref(d, x_data, y_data, zref_data);
+ refer::VAdd<float>(x_data, y_data, zref_data, d);
}
auto trefe = GetCurrentUS();
@@ -727,12 +653,6 @@ TEST(JitKernel, vadd) {
}
}
void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
void vaddrelu_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
@@ -745,6 +665,7 @@ void vaddrelu_better(
TEST(JitKernel, vaddrelu) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
@@ -762,7 +683,7 @@ TEST(JitKernel, vaddrelu) {
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
- vaddrelu_ref(d, x_data, y_data, zref_data);
+ refer::VAddRelu<float>(x_data, y_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
......