提交 6648995f 编写于 作者: T tensor-tang

fix build

上级 74292f41
......@@ -82,8 +82,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track;
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::Get<jit::crfdecoding, jit::CRFDecoding, platform::CPUPlace>(
tag_num);
auto ker = jit::Get<jit::crfdecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
......
......@@ -108,7 +108,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16;
int C = c / simd_width;
auto multiply = jit::Get<jit::nchw16cmulnc, jit::NCHW16CMulNCTuples,
auto multiply = jit::Get<jit::nchw16cmulnc, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>(0);
#pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) {
......
......@@ -197,11 +197,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \
auto ComputeH1 = \
jit::Get<jit::gruh1, jit::GRUTuples, platform::CPUPlace>(attr); \
jit::Get<jit::gruh1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart1 = \
jit::Get<jit::gruhtpart1, jit::GRUTuples, platform::CPUPlace>(attr); \
jit::Get<jit::gruhtpart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart2 = \
jit::Get<jit::gruhtpart2, jit::GRUTuples, platform::CPUPlace>(attr); \
jit::Get<jit::gruhtpart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
......
......@@ -250,19 +250,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const jit \
: lstm_attr_t attr( \
const jit::lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \
math::jitkernel::lstm_t one_step; \
jit::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \
jit::Get<jit::lstmc1h1, jit::LSTMTuples, platform::CPUPlace>(attr); \
jit::Get<jit::lstmc1h1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
auto ComputeCtHt = \
jit::Get<jit::lstmctht, jit::LSTMTuples, platform::CPUPlace>(attr)
jit::Get<jit::lstmctht, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
......@@ -434,7 +433,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ComputeC1H1(&one_step, &attr);
ComputeCtHt(&one_step, &attr);
// move one batch
cur_in_data += D4;
......
......@@ -32,7 +32,7 @@ inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
......@@ -68,7 +68,7 @@ inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
......@@ -93,8 +93,8 @@ inline typename KernelTuples::func_type GetRefer() {
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
// TODO(TJ): const & attr
typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
......
......@@ -230,7 +230,7 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker =
jit::Get<jit::layernorm, jit::LayerNormTuples, platform::CPUPlace>(
jit::Get<jit::layernorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left),
......
......@@ -31,13 +31,14 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
}
if (relu) {
auto compute =
jit::Get<jit::vaddrelu, jit::XYZNTuples, platform::CPUPlcace>(N);
jit::Get<jit::vaddrelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
}
} else {
auto compute = jit::Get<jit::vadd, jit::XYZNTuples, platform::CPUPlcace>(N);
auto compute =
jit::Get<jit::vadd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册