提交 c75dc885 编写于 作者: P peizhilin

add the jit support

test=develop
上级 a395942c
...@@ -11,9 +11,7 @@ add_subdirectory(controlflow) ...@@ -11,9 +11,7 @@ add_subdirectory(controlflow)
add_subdirectory(csp) add_subdirectory(csp)
add_subdirectory(detection) add_subdirectory(detection)
add_subdirectory(elementwise) add_subdirectory(elementwise)
if(NOT WIN32) add_subdirectory(fused)
add_subdirectory(fused)
endif(NOT WIN32)
add_subdirectory(metrics) add_subdirectory(metrics)
add_subdirectory(optimizers) add_subdirectory(optimizers)
add_subdirectory(reduce_ops) add_subdirectory(reduce_ops)
...@@ -50,8 +48,9 @@ endif() ...@@ -50,8 +48,9 @@ endif()
set(COMMON_OP_DEPS "") set(COMMON_OP_DEPS "")
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor sequence_padding sequence_scale cos_sim_functor memory concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor sequence_padding sequence_scale cos_sim_functor memory concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} lstm_compute matrix_bit_code gru_compute activation_functions jit_kernel)
if (NOT WIN32) if (NOT WIN32)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions dynload_warpctc jit_kernel) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch dynload_warpctc)
endif() endif()
if (WITH_GPU) if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv cub) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv cub)
......
if (NOT WIN32) add_subdirectory(detail)
add_subdirectory(detail)
endif(NOT WIN32)
function(math_library TARGET) function(math_library TARGET)
# math_library is a function to create math library. # math_library is a function to create math library.
...@@ -43,10 +41,8 @@ math_library(depthwise_conv) ...@@ -43,10 +41,8 @@ math_library(depthwise_conv)
math_library(im2col) math_library(im2col)
math_library(sampler) math_library(sampler)
if (NOT WIN32) # windows do not support avx functions yet. math_library(gru_compute DEPS activation_functions math_function)
math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions)
math_library(lstm_compute DEPS activation_functions)
endif (NOT WIN32)
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
math_library(math_function DEPS blas) math_library(math_function DEPS blas)
...@@ -58,9 +54,9 @@ math_library(sequence_padding) ...@@ -58,9 +54,9 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function) math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale) math_library(sequence_scale)
math_library(softmax DEPS math_function) math_library(softmax DEPS math_function)
if (NOT WIN32)
math_library(matrix_bit_code) math_library(matrix_bit_code)
endif (NOT WIN32)
math_library(unpooling) math_library(unpooling)
math_library(vol2col) math_library(vol2col)
...@@ -76,13 +72,13 @@ if(WITH_GPU) ...@@ -76,13 +72,13 @@ if(WITH_GPU)
endif() endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
if (NOT WIN32)
set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc) set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK) if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc) list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak) list(APPEND JIT_KERNEL_DEPS xbyak)
endif() endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS}) cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
endif (NOT WIN32)
...@@ -15,6 +15,13 @@ limitations under the License. */ ...@@ -15,6 +15,13 @@ limitations under the License. */
#pragma once #pragma once
#include <math.h> #include <math.h>
#include <string> #include <string>
// On Windows, remove the compiler-provided AVX feature macros before the rest
// of this header is processed.
// NOTE(review): presumably this makes code guarded by __AVX__ / __AVX2__
// further down fall back to its non-AVX path on Windows — confirm.
// The AVX2 feature macro is spelled __AVX2__ (digit before the trailing
// underscores); "__AVX__2" is a different identifier that is never defined,
// so undefining it would leave __AVX2__ in effect.
#ifdef _WIN32
#undef __AVX__
#undef __AVX2__
#endif
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
......
...@@ -118,7 +118,12 @@ void VXXJitCode::generate() { ...@@ -118,7 +118,12 @@ void VXXJitCode::generate() {
ret(); ret();
} }
// ALIGN32: alignment annotation macro. On non-Windows builds it expands to the
// GCC/Clang attribute requesting 32-byte alignment; on Windows it expands to
// nothing.
// NOTE(review): with the empty Windows definition, anything tagged ALIGN32 gets
// no explicit alignment request there — confirm no consumer relies on 32-byte
// aligned access on Windows.
#ifdef _WIN32
#define ALIGN32
#else
#define ALIGN32 __attribute__((aligned(32))) #define ALIGN32 __attribute__((aligned(32)))
#endif
#define EXP_HIG 88.3762626647949f #define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f #define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341 #define CEPHES_LOG2EF 1.44269504088896341
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册