提交 6648995f 编写于 作者: T tensor-tang

fix build

上级 74292f41
...@@ -82,8 +82,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -82,8 +82,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::Get<jit::crfdecoding, jit::CRFDecoding, platform::CPUPlace>( auto ker = jit::Get<jit::crfdecoding, jit::CRFDecodingTuples<T>,
tag_num); platform::CPUPlace>(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
......
...@@ -108,7 +108,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -108,7 +108,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
auto multiply = jit::Get<jit::nchw16cmulnc, jit::NCHW16CMulNCTuples, auto multiply = jit::Get<jit::nchw16cmulnc, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>(0); platform::CPUPlace>(0);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
......
...@@ -197,11 +197,11 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -197,11 +197,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \ jit::gru_t one_step; \
auto ComputeH1 = \ auto ComputeH1 = \
jit::Get<jit::gruh1, jit::GRUTuples, platform::CPUPlace>(attr); \ jit::Get<jit::gruh1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart1 = \ auto ComputeHtPart1 = \
jit::Get<jit::gruhtpart1, jit::GRUTuples, platform::CPUPlace>(attr); \ jit::Get<jit::gruhtpart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart2 = \ auto ComputeHtPart2 = \
jit::Get<jit::gruhtpart2, jit::GRUTuples, platform::CPUPlace>(attr); \ jit::Get<jit::gruhtpart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
......
...@@ -250,19 +250,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -250,19 +250,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const jit \ const jit::lstm_attr_t attr( \
: lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \ jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \ jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \ use_peepholes); \
math::jitkernel::lstm_t one_step; \ jit::lstm_t one_step; \
one_step.wp = wp_data; \ one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \ one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \ auto ComputeC1H1 = \
jit::Get<jit::lstmc1h1, jit::LSTMTuples, platform::CPUPlace>(attr); \ jit::Get<jit::lstmc1h1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
auto ComputeCtHt = \ auto ComputeCtHt = \
jit::Get<jit::lstmctht, jit::LSTMTuples, platform::CPUPlace>(attr) jit::Get<jit::lstmctht, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
...@@ -434,7 +433,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -434,7 +433,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = cur_prev_c_data; one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data; one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data; one_step.ht = cur_h_out_data;
ComputeC1H1(&one_step, &attr); ComputeCtHt(&one_step, &attr);
// move one batch // move one batch
cur_in_data += D4; cur_in_data += D4;
......
...@@ -32,7 +32,7 @@ inline typename std::enable_if< ...@@ -32,7 +32,7 @@ inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value && std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value, std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) { GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type; using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
...@@ -68,7 +68,7 @@ inline typename std::enable_if< ...@@ -68,7 +68,7 @@ inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value || !std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value, !std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) { GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr; return nullptr;
} }
...@@ -93,8 +93,8 @@ inline typename KernelTuples::func_type GetRefer() { ...@@ -93,8 +93,8 @@ inline typename KernelTuples::func_type GetRefer() {
template <KernelType KT, typename KernelTuples, template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace> typename PlaceType = platform::CPUPlace>
// TODO(TJ): const & attr typename KernelTuples::func_type Get(
typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) { const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr); auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) { if (jitfunc) {
return jitfunc; return jitfunc;
......
...@@ -230,7 +230,7 @@ class LayerNormKernel : public framework::OpKernel<T> { ...@@ -230,7 +230,7 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(bias->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker = auto ker =
jit::Get<jit::layernorm, jit::LayerNormTuples, platform::CPUPlace>( jit::Get<jit::layernorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
right); right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale->data<T>(), bias->data<T>(), static_cast<int>(left),
......
...@@ -31,13 +31,14 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -31,13 +31,14 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
} }
if (relu) { if (relu) {
auto compute = auto compute =
jit::Get<jit::vaddrelu, jit::XYZNTuples, platform::CPUPlcace>(N); jit::Get<jit::vaddrelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
compute(B, dst, dst, N); compute(B, dst, dst, N);
} }
} else { } else {
auto compute = jit::Get<jit::vadd, jit::XYZNTuples, platform::CPUPlcace>(N); auto compute =
jit::Get<jit::vadd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册