Commit 720b55cb authored by tensor-tang

enable crf decoding and layer norm refer code

Parent 64a90b2f
@@ -16,7 +16,7 @@ limitations under the License. */
#include <limits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
    Tensor track;
    int* track_value =
        track.mutable_data<int>(emission_dims, platform::CPUPlace());
    const auto& ker = math::jitkernel::KernelPool::Instance()
                          .template Get<math::jitkernel::CRFDecodeKernel<T>>(
                              static_cast<int>(tag_num));
    ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
    auto ker = jit::Get<jit::crfdecoding, jit::CRFDecodingTuples<T>,
                        platform::CPUPlace>(tag_num);
    ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
    T max_score = -std::numeric_limits<T>::max();
    int max_i = 0;
    for (size_t i = 0; i < tag_num; ++i) {
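Note on the call-site change above: the old path fetched a kernel object from the `KernelPool` singleton and dispatched through a virtual `Compute`; the new `jit::Get<KernelType, KernelTuples, PlaceType>(attr)` returns a bare function pointer typed by the tuple's `func_type`. A minimal sketch of the new pattern, assuming only the headers touched by this diff (the wrapper function itself is illustrative, not part of the patch):

```cpp
#include "paddle/fluid/operators/jit/kernels.h"

namespace jit = paddle::operators::jit;
namespace platform = paddle::platform;

// Illustrative wrapper: resolve once, then call like a free function.
void run_crf_decoding(int seq_len, int tag_num, const float* x, const float* w,
                      float* alpha, int* track) {
  // Resolves to the best registered implementation for this tag_num;
  // the refer kernel added in this commit is the fallback.
  auto ker = jit::Get<jit::crfdecoding, jit::CRFDecodingTuples<float>,
                      platform::CPUPlace>(tag_num);
  ker(seq_len, x, w, alpha, track, tag_num);
}
```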
@@ -42,6 +42,8 @@ const char* to_string(KernelType kt) {
    ONE_CASE(gruh1);
    ONE_CASE(gruhtpart1);
    ONE_CASE(gruhtpart2);
    ONE_CASE(crfdecoding);
    ONE_CASE(layernorm);
    default:
      PADDLE_THROW("Not support type: %d", kt);
      return "NOT JITKernel";
@@ -64,6 +66,8 @@ KernelType to_kerneltype(const std::string& act) {
  } else if (lower == "tanh" || lower == "vtanh") {
    return vtanh;
  }
  PADDLE_THROW("Not support type: %s, or forget to add this case", act);
  return non_kernel;
}
@@ -37,7 +37,9 @@ typedef enum {
  lstmc1h1,
  gruh1,
  gruhtpart1,
  gruhtpart2
  gruhtpart2,
  crfdecoding,
  layernorm
} KernelType;
template <typename T>
@@ -109,6 +111,21 @@ struct GRUTuples {
  typedef void (*func_type)(gru_t*, const gru_attr_t*);
};

template <typename T>
struct CRFDecodingTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};

template <typename T>
struct LayerNormTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
                            const float, int);
};
// Just for adding to kernel pool without template
class Kernel {
 public:
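On the two new tuple structs: each `*Tuples<T>` bundles the element type (`data_type`), the lookup key (`attr_type`), and the raw kernel signature (`func_type`). Since `jit::Get` is templated on the tuple, the lookup result can be held in a plain, statically typed function pointer. A small sketch under that assumption (the caching helper is hypothetical):

```cpp
#include "paddle/fluid/operators/jit/kernels.h"

namespace jit = paddle::operators::jit;
namespace platform = paddle::platform;

// The tuple fixes the pointer type at compile time, so a per-width lookup
// can be stored in an ordinary function pointer (assumed API shape).
using ln_func_t = jit::LayerNormTuples<float>::func_type;

ln_func_t get_layernorm_kernel(int right) {
  return jit::Get<jit::layernorm, jit::LayerNormTuples<float>,
                  platform::CPUPlace>(right);
}
```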
@@ -23,3 +23,5 @@ USE_JITKERNEL_REFER(lstmc1h1)
USE_JITKERNEL_REFER(gruh1)
USE_JITKERNEL_REFER(gruhtpart1)
USE_JITKERNEL_REFER(gruhtpart2)
USE_JITKERNEL_REFER(crfdecoding)
USE_JITKERNEL_REFER(layernorm)
@@ -42,4 +42,7 @@ REGISTER_REFER_KERNEL(gruh1, GRUH1);
REGISTER_REFER_KERNEL(gruhtpart1, GRUHtPart1);
REGISTER_REFER_KERNEL(gruhtpart2, GRUHtPart2);
REGISTER_REFER_KERNEL(crfdecoding, CRFDecoding);
REGISTER_REFER_KERNEL(layernorm, LayerNorm);
#undef REGISTER_REFER_KERNEL
@@ -13,6 +13,9 @@
* limitations under the License. */
#pragma once
#include <cmath>
#include <limits>
#include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
@@ -242,6 +245,80 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
  }
}

// Reference (scalar) Viterbi decoding for the linear-chain CRF.
// Layout of w: row 0 holds the start scores; rows from state_trans_base_idx
// onward hold the transition scores from tag j to tag i. alpha keeps the best
// path score ending in each tag; track records the argmax predecessors used
// by the caller to backtrack the best tag sequence.
template <typename T>
void CRFDecoding(const int seq_len, const T* x, const T* w, T* alpha,
                 int* track, int right) {
  constexpr int state_trans_base_idx = 2;
  for (int i = 0; i < right; ++i) {
    alpha[i] = w[i] + x[i];
  }
  for (int k = 1; k < seq_len; ++k) {
    for (int i = 0; i < right; ++i) {
      T max_score = -std::numeric_limits<T>::max();
      int max_j = 0;
      for (int j = 0; j < right; ++j) {
        T score = alpha[(k - 1) * right + j] +
                  w[(j + state_trans_base_idx) * right + i];
        if (score > max_score) {
          max_score = score;
          max_j = j;
        }
      }
      alpha[k * right + i] = max_score + x[k * right + i];
      track[k * right + i] = max_j;
    }
  }
}
template <typename T>
void LayerNorm(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
               int height, const float epsilon, int right) {
  // get mean
  for (int i = 0; i < height; i++) {
    T sum = 0.0;
    int offset = i * right;
    for (int j = 0; j < right; j++) {
      sum += x[offset + j];
    }
    mean[i] = sum / right;
  }

  // get variance
  for (int i = 0; i < height; i++) {
    T sum = 0.0;
    int offset = i * right;
    for (int j = 0; j < right; j++) {
      sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
    }
    var[i] = sum / right;
  }

  // normalize each row, then apply the optional affine transform below
  for (int i = 0; i < height; i++) {
    int offset = i * right;
    T sqrt_var = std::sqrt(var[i] + (T)epsilon);
    for (int j = 0; j < right; j++) {
      out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
    }
  }
  if (scale) {
    for (int i = 0; i < height; i++) {
      int offset = i * right;
      for (int j = 0; j < right; j++) {
        out[offset + j] *= scale[j];
      }
    }
  }
  if (bias) {
    for (int i = 0; i < height; i++) {
      int offset = i * right;
      for (int j = 0; j < right; j++) {
        out[offset + j] += bias[j];
      }
    }
  }
}
#define DECLARE_REFER_KERNEL(name, tuples)             \
  template <typename T>                                \
  class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -275,6 +352,9 @@ DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
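To make the refer `CRFDecoding` concrete, here is a tiny hand-checkable run. This is a sketch: it assumes refer.h is includable on its own and that the functions live in a `refer` namespace under `paddle::operators::jit`. Per the layout noted above, `w` row 0 holds start scores, rows 2..3 hold transitions, and row 1 is not read by this kernel:

```cpp
#include <cstdio>
#include "paddle/fluid/operators/jit/refer/refer.h"

namespace refer = paddle::operators::jit::refer;

int main() {
  const int seq_len = 2, tag_num = 2;
  float w[] = {0.f, 0.f,   // row 0: start scores
               0.f, 0.f,   // row 1: not read by this kernel
               1.f, 0.f,   // row 2: transitions from tag 0
               0.f, 1.f};  // row 3: transitions from tag 1
  float x[] = {2.f, 1.f,   // emissions at step 0
               0.f, 0.f};  // emissions at step 1
  float alpha[seq_len * tag_num];
  int track[seq_len * tag_num];
  refer::CRFDecoding<float>(seq_len, x, w, alpha, track, tag_num);
  // Step 0: alpha = [2, 1]. Step 1, tag 0: max(2 + 1, 1 + 0) = 3 via tag 0,
  // so alpha[2] == 3 and track[2] == 0.
  std::printf("alpha=[%g %g %g %g] track[2]=%d\n", alpha[0], alpha[1],
              alpha[2], alpha[3], track[2]);
  return 0;
}
```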
@@ -515,7 +515,7 @@ TEST(JITKernel, gruhtpart2) {
  TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
}

// TODO(TJ): refine the tests template
// TODO(yihua/TJ): add crf decoding and layer norm unit tests
TEST(JITKernel, pool) {
  // TODO(TJ): add some test
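On the TODO above: a minimal shape such a unit test could take, exercising only the refer `LayerNorm`. Assumptions: gtest plus the refer header; the real suite would presumably also iterate the jitcode/MKL implementations the way `TestGRUKernel` does.

```cpp
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/refer/refer.h"

TEST(JITKernel, layernorm_refer) {
  const int height = 2, right = 4;
  std::vector<float> x(height * right), out(height * right);
  std::vector<float> mean(height), var(height);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);
  paddle::operators::jit::refer::LayerNorm<float>(
      x.data(), out.data(), mean.data(), var.data(), /*scale=*/nullptr,
      /*bias=*/nullptr, height, 1e-5f, right);
  // With no scale/bias, each normalized row should sum to (nearly) zero.
  for (int i = 0; i < height; ++i) {
    float row_sum = 0.f;
    for (int j = 0; j < right; ++j) row_sum += out[i * right + j];
    EXPECT_NEAR(row_sum, 0.f, 1e-4f);
  }
}
```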
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__)
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/jit/kernels.h"
#endif
#include "paddle/fluid/operators/math/math_function.h"
@@ -229,12 +229,12 @@ class LayerNormKernel : public framework::OpKernel<T> {
    PADDLE_ENFORCE_EQ(scale->numel(), right);
    PADDLE_ENFORCE_EQ(bias->numel(), right);
    const auto& ker = math::jitkernel::KernelPool::Instance()
                          .template Get<math::jitkernel::LayerNormKernel<T>>(
                              static_cast<int>(right));
    ker->Compute(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
    auto ker =
        jit::Get<jit::layernorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
            right);
    ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
        scale->data<T>(), bias->data<T>(), static_cast<int>(left),
        static_cast<const float>(epsilon));
        static_cast<const float>(epsilon), right);
#endif
  }
};
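One design note on the layer_norm call sites above: the old pooled kernel captured `right` once when fetched from the pool, whereas the new kernel is a stateless function pointer, so the attribute travels with every call as the trailing argument. Side by side, excerpted from the hunk above:

```cpp
// Old: `right` was baked into the kernel object at Get time.
ker->Compute(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
             scale->data<T>(), bias->data<T>(), static_cast<int>(left),
             static_cast<const float>(epsilon));
// New: the bare function pointer keeps no state, so `right` is re-passed.
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
    scale->data<T>(), bias->data<T>(), static_cast<int>(left),
    static_cast<const float>(epsilon), right);
```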