提交 2513b2cc 编写于 作者: T tensor-tang

fix bug vtanh

上级 cf8c8e72
...@@ -29,7 +29,6 @@ namespace jitkernel { ...@@ -29,7 +29,6 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
#define AVX_FLOAT_BLOCK 8 #define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8 #define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16 #define AVX512_FLOAT_BLOCK 16
...@@ -40,8 +39,9 @@ class Kernel { ...@@ -40,8 +39,9 @@ class Kernel {
public: public:
Kernel() = default; Kernel() = default;
virtual ~Kernel() = default; virtual ~Kernel() = default;
int num_{0};
private: int end_{0};
int rest_{0};
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
...@@ -95,13 +95,13 @@ class VExpKernel : public Kernel { ...@@ -95,13 +95,13 @@ class VExpKernel : public Kernel {
template <typename T> template <typename T>
class VSigmoidKernel : public Kernel { class VSigmoidKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
class VTanhKernel : public Kernel { class VTanhKernel : public Kernel {
public: public:
virtual void Compute(const int n, const T *x, T *y) const = 0; virtual void Compute(const T *x, T *y) const = 0;
}; };
template <typename T> template <typename T>
......
...@@ -195,7 +195,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -195,7 +195,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, ztgt_data); ker->Compute(x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -227,7 +227,7 @@ void vtanh_better( ...@@ -227,7 +227,7 @@ void vtanh_better(
vaddbias, vaddbias,
const int n, const float* x, float* y) { const int n, const float* x, float* y) {
vscal->Compute(n, 2.f, x, y); vscal->Compute(n, 2.f, x, y);
vsigmoid->Compute(n, y, y); vsigmoid->Compute(y, y);
vscal->Compute(n, 2.f, y); vscal->Compute(n, 2.f, y);
vaddbias->Compute(n, -1.f, y, y); vaddbias->Compute(n, -1.f, y, y);
} }
...@@ -261,7 +261,7 @@ TEST(JitKernel, vtanh) { ...@@ -261,7 +261,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->Compute(d, x_data, ztgt_data); ker->Compute(x_data, ztgt_data);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册