提交 f4e7a473 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into accelerate_adam

test=develop
...@@ -126,16 +126,12 @@ if(ANDROID OR IOS) ...@@ -126,16 +126,12 @@ if(ANDROID OR IOS)
add_definitions(-DPADDLE_MOBILE_INFERENCE) add_definitions(-DPADDLE_MOBILE_INFERENCE)
endif() endif()
if (APPLE OR WIN32) if (APPLE)
set(WITH_MKL OFF CACHE STRING set(WITH_MKL OFF CACHE STRING
"Disable MKL for building on mac and windows" FORCE) "Disable MKL for building on mac" FORCE)
endif() endif()
if (WIN32) if (WIN32)
set(WITH_DSO OFF CACHE STRING
"Disable DSO when compiling for Windows" FORCE)
set(WITH_MKL OFF CACHE STRING
"Disable MKL when compiling for Windows" FORCE)
set(WITH_DISTRIBUTE OFF CACHE STRING set(WITH_DISTRIBUTE OFF CACHE STRING
"Disable DISTRIBUTE when compiling for Windows" FORCE) "Disable DISTRIBUTE when compiling for Windows" FORCE)
set(WITH_C_API OFF CACHE STRING set(WITH_C_API OFF CACHE STRING
......
...@@ -23,15 +23,14 @@ SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn) ...@@ -23,15 +23,14 @@ SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
IF(WIN32 OR APPLE) IF(APPLE)
MESSAGE(WARNING MESSAGE(WARNING
"Windows or Mac is not supported with MKLDNN in Paddle yet." "Mac is not supported with MKLDNN in Paddle yet."
"Force WITH_MKLDNN=OFF") "Force WITH_MKLDNN=OFF")
SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in Windows and MacOS" FORCE) SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in MacOS" FORCE)
return() return()
ENDIF() ENDIF()
SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
...@@ -44,10 +43,14 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML") ...@@ -44,10 +43,14 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
ELSE() ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN") MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF() ENDIF()
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value") IF(NOT WIN32)
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
ENDIF(NOT WIN32)
ExternalProject_Add( ExternalProject_Add(
${MKLDNN_PROJECT} ${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
...@@ -58,8 +61,15 @@ ExternalProject_Add( ...@@ -58,8 +61,15 @@ ExternalProject_Add(
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
CMAKE_ARGS -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
CMAKE_ARGS -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
CMAKE_ARGS -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
CMAKE_ARGS -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
CMAKE_ARGS -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
CMAKE_ARGS -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON
CMAKE_ARGS -DMKLROOT=${MKLML_ROOT} CMAKE_ARGS -DMKLROOT=${MKLML_ROOT}
CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG} CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG} CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
...@@ -67,6 +77,11 @@ ExternalProject_Add( ...@@ -67,6 +77,11 @@ ExternalProject_Add(
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
-DMKLROOT:PATH=${MKLML_ROOT} -DMKLROOT:PATH=${MKLML_ROOT}
) )
if(WIN32)
SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/mkldnn.lib" CACHE FILEPATH "mkldnn library." FORCE)
else(WIN32)
SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
endif(WIN32)
ADD_LIBRARY(shared_mkldnn SHARED IMPORTED GLOBAL) ADD_LIBRARY(shared_mkldnn SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET shared_mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB}) SET_PROPERTY(TARGET shared_mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
...@@ -85,10 +100,14 @@ ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT}) ...@@ -85,10 +100,14 @@ ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
# copy the real so.0 lib to install dir # copy the real so.0 lib to install dir
# it can be directly contained in wheel or capi # it can be directly contained in wheel or capi
SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0) if(WIN32)
ADD_CUSTOM_COMMAND(OUTPUT ${MKLDNN_SHARED_LIB} SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/lib/mkldnn.dll)
COMMAND cp ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB} else(WIN32)
DEPENDS mkldnn) SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0)
ADD_CUSTOM_COMMAND(OUTPUT ${MKLDNN_SHARED_LIB}
COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB}
DEPENDS mkldnn)
endif(WIN32)
ADD_CUSTOM_TARGET(mkldnn_shared_lib ALL DEPENDS ${MKLDNN_SHARED_LIB}) ADD_CUSTOM_TARGET(mkldnn_shared_lib ALL DEPENDS ${MKLDNN_SHARED_LIB})
IF(WITH_C_API) IF(WITH_C_API)
......
...@@ -16,56 +16,67 @@ IF(NOT ${WITH_MKLML}) ...@@ -16,56 +16,67 @@ IF(NOT ${WITH_MKLML})
return() return()
ENDIF(NOT ${WITH_MKLML}) ENDIF(NOT ${WITH_MKLML})
IF(WIN32 OR APPLE) IF(APPLE)
MESSAGE(WARNING MESSAGE(WARNING
"Windows or Mac is not supported with MKLML in Paddle yet." "Mac is not supported with MKLML in Paddle yet."
"Force WITH_MKLML=OFF") "Force WITH_MKLML=OFF")
SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE) SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE)
return() return()
ENDIF() ENDIF()
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(MKLML_PROJECT "extern_mklml")
IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL))
MESSAGE(STATUS "use pre defined download url")
SET(MKLML_VER "mklml_lnx_2019.0.20180710" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}")
SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
SET(MKLML_DST_DIR "mklml") SET(MKLML_DST_DIR "mklml")
SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
SET(MKLML_ROOT ${MKLML_INSTALL_DIR}) SET(MKLML_ROOT ${MKLML_INSTALL_DIR})
SET(MKLML_INC_DIR ${MKLML_ROOT}/include) SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib) SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) if(WIN32)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
else()
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
endif()
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
INCLUDE_DIRECTORIES(${MKLML_INC_DIR}) IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL))
MESSAGE(STATUS "use pre defined download url")
if(WIN32)
SET(MKLML_VER "mklml_win_2019.0.20180710" CACHE STRING "" FORCE)
SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
else()
SET(MKLML_VER "mklml_lnx_2019.0.20180710" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
ENDIF()
endif()
FILE(WRITE ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt SET(MKLML_PROJECT "extern_mklml")
"PROJECT(MKLML)\n" MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}")
"cmake_minimum_required(VERSION 3.0)\n" SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
"install(DIRECTORY ${MKLML_VER}/include ${MKLML_VER}/lib \n" SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
" DESTINATION ${MKLML_DST_DIR})\n")
ExternalProject_Add( ExternalProject_Add(
${MKLML_PROJECT} ${MKLML_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${MKLML_SOURCE_DIR} PREFIX ${MKLML_SOURCE_DIR}
URL ${MKLML_URL}
DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR} DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${MKLML_URL} -c -q -O ${MKLML_VER}.tgz
&& tar zxf ${MKLML_VER}.tgz
DOWNLOAD_NO_PROGRESS 1 DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND "" CONFIGURE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT} BUILD_COMMAND ""
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLML_INSTALL_ROOT} UPDATE_COMMAND ""
INSTALL_COMMAND
${CMAKE_COMMAND} -E copy_directory ${MKLML_DOWNLOAD_DIR}/include ${MKLML_INC_DIR} &&
${CMAKE_COMMAND} -E copy_directory ${MKLML_DOWNLOAD_DIR}/lib ${MKLML_LIB_DIR}
) )
INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL) ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB}) SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB})
ADD_DEPENDENCIES(mklml ${MKLML_PROJECT}) ADD_DEPENDENCIES(mklml ${MKLML_PROJECT})
......
...@@ -267,7 +267,11 @@ function(cc_library TARGET_NAME) ...@@ -267,7 +267,11 @@ function(cc_library TARGET_NAME)
list(APPEND cc_library_DEPS dynload_mklml) list(APPEND cc_library_DEPS dynload_mklml)
endif() endif()
add_dependencies(${TARGET_NAME} mklml) add_dependencies(${TARGET_NAME} mklml)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed") if(WIN32)
target_link_libraries(${TARGET_NAME} ${MKLML_IOMP_LIB})
else(WIN32)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif(WIN32)
endif() endif()
# remove link to python, see notes at: # remove link to python, see notes at:
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually # https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
......
...@@ -115,20 +115,20 @@ if (NOT PROTOBUF_FOUND OR WIN32) ...@@ -115,20 +115,20 @@ if (NOT PROTOBUF_FOUND OR WIN32)
) )
endif () endif ()
if (NOT CBLAS_FOUND) if (WITH_MKLML)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/openblas")
copy(openblas_lib
SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include
DSTS ${dst_dir} ${dst_dir}
DEPS extern_openblas
)
elseif (WITH_MKLML)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mklml") set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mklml")
copy(mklml_lib copy(mklml_lib
SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR} SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR}
DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir} DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir}
DEPS mklml DEPS mklml
) )
elseif (NOT CBLAS_FOUND OR WIN32)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/openblas")
copy(openblas_lib
SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include
DSTS ${dst_dir} ${dst_dir}
DEPS extern_openblas
)
endif () endif ()
if (WITH_MKLDNN) if (WITH_MKLDNN)
......
...@@ -208,6 +208,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act ...@@ -208,6 +208,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
......
...@@ -355,9 +355,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -355,9 +355,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle? // TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType(); CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
out_dtype);
} }
// This assumes the backward generating code will ensure IsScaleLossOp // This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss. // is true only for the op that scale the final scalar loss.
...@@ -660,13 +658,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID( ...@@ -660,13 +658,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name, ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node, proto::VarType::Type dtype) const { ir::Node *out_var_node) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle( auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx, dtype); local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
......
...@@ -68,8 +68,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -68,8 +68,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateScaleLossGradOp(ir::Graph *result, void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name, const std::string &loss_grad_name,
ir::Node *out_var_node, ir::Node *out_var_node) const;
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
......
...@@ -22,66 +22,39 @@ namespace details { ...@@ -22,66 +22,39 @@ namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope, Scope *scope,
platform::Place place, platform::Place place,
platform::DeviceContext *dev_ctx, platform::DeviceContext *dev_ctx)
proto::VarType::Type dtype)
: OpHandleBase(node), : OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)), coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope), scope_(scope),
place_(place), place_(place) {
out_dtype_(dtype) {
this->SetDeviceContext(place_, dev_ctx); this->SetDeviceContext(place_, dev_ctx);
} }
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
struct ScaleLossGradFunctor {
float coeff_;
Tensor *out_;
platform::Place place_;
OpHandleBase *op_handle_;
proto::VarType::Type out_dtype_;
platform::DeviceContext *ctx_;
ScaleLossGradFunctor(float coeff, Tensor *out, platform::Place place,
OpHandleBase *op_handle, proto::VarType::Type dtype,
platform::DeviceContext *ctx)
: coeff_(coeff), out_(out), place_(place), out_dtype_(dtype), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto *out_data = out_->mutable_data<OutT>(place_);
if (platform::is_cpu_place(place_)) {
*out_data = static_cast<OutT>(coeff_);
} else {
#ifdef PADDLE_WITH_CUDA
OutT cast_coeff = static_cast<OutT>(coeff_);
auto stream = static_cast<platform::CUDADeviceContext *>(ctx_)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), out_data,
platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_),
stream);
VLOG(10) << place_ << "RUN Scale loss grad op";
#endif
}
}
};
void ScaleLossGradOpHandle::RunImpl() { void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event // Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>(); float *tmp = local_scope.FindVar(var_name)
tensor->Resize(make_ddim({1})); ->GetMutable<LoDTensor>()
->mutable_data<float>(make_ddim({1}), place_);
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
} else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, this->RunAndRecordEvent([&] {
this->dev_ctxes_.at(place_)); auto stream = static_cast<platform::CUDADeviceContext *>(
this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); }); this->dev_ctxes_.at(place_))
#else ->stream();
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, nullptr); memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
framework::VisitDataType(out_dtype_, func); platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(10) << place_ << "RUN Scale loss grad op";
});
#endif #endif
}
} }
std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; } std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
......
...@@ -26,8 +26,8 @@ namespace details { ...@@ -26,8 +26,8 @@ namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase { struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope, ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place, platform::DeviceContext *context, platform::Place place,
proto::VarType::Type dtype); platform::DeviceContext *context);
~ScaleLossGradOpHandle() final; ~ScaleLossGradOpHandle() final;
...@@ -40,7 +40,6 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ...@@ -40,7 +40,6 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_; float coeff_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
proto::VarType::Type out_dtype_;
}; };
} // namespace details } // namespace details
......
...@@ -157,13 +157,8 @@ bool CheckLoD(const LoD &in, int tensor_height) { ...@@ -157,13 +157,8 @@ bool CheckLoD(const LoD &in, int tensor_height) {
if (level.size() < 2) return false; if (level.size() < 2) return false;
// check: the first offset(the begin offset) of each level should be 0. // check: the first offset(the begin offset) of each level should be 0.
if (level.front() != 0) return false; if (level.front() != 0) return false;
// check: all the offsets in a level should be ascending(no same items // check: all the offsets in a level should be ascending(allow same items)
// allows). if (!std::is_sorted(level.begin(), level.end())) {
if (!std::is_sorted(level.begin(), level.begin(), [](size_t a, size_t b) {
if (a < b) return true;
return false;
})) {
LOG(INFO) << "ascending error";
return false; return false;
} }
} }
......
...@@ -217,6 +217,11 @@ TEST(LoD, CheckLoD) { ...@@ -217,6 +217,11 @@ TEST(LoD, CheckLoD) {
// check with underlying tensor storage. // check with underlying tensor storage.
ASSERT_TRUE(CheckLoD(relative_lod, 5)); ASSERT_TRUE(CheckLoD(relative_lod, 5));
ASSERT_FALSE(CheckLoD(relative_lod, 9)); ASSERT_FALSE(CheckLoD(relative_lod, 9));
// check whether lod is ascending-sorted (allow same items)
ASSERT_TRUE(CheckLoD({{0, 1, 2, 3, 4, 5}}, 5));
ASSERT_TRUE(CheckLoD({{0, 1, 3, 3, 4, 5}}, 5));
ASSERT_FALSE(CheckLoD({{0, 1, 3, 2, 5}}, 5));
} }
TEST(LoD, CheckAbsLoD) { TEST(LoD, CheckAbsLoD) {
......
...@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext {
} }
} }
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
DDim GetInputDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> GetInputsDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
bool IsRuntime() const override; bool IsRuntime() const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const override {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const override {
return GetVarTypes(Outputs(name));
}
void SetOutputDim(const std::string &name, const DDim &dim) override {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) override {
auto &names = Outputs(name);
SetDims(names, dims);
}
protected: protected:
proto::VarType::Type GetVarType(const std::string &name) const override; std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(
names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(const std::string &name) const;
DDim GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
DDim GetDim(const std::string &name) const override; std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDim(const std::string &name, const DDim &dim);
void SetDim(const std::string &name, const DDim &dim) override; void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (names[i] == framework::kEmptyVarName) {
continue;
}
SetDim(names[i], dims[i]);
}
}
std::vector<DDim> GetRepeatedDims(const std::string &name) const override; std::vector<DDim> GetRepeatedDims(const std::string &name) const override;
void SetRepeatedDims(const std::string &name, void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override; const std::vector<DDim> &dims) override;
InferShapeVarPtr GetVarPtr(const std::string &name) override;
const OpDesc &op_; const OpDesc &op_;
const BlockDesc &block_; const BlockDesc &block_;
}; };
...@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs( ...@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
return op_.Output(name); return op_.Output(name);
} }
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims( std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
const std::string &name) const { const std::string &name) const {
auto var = block_.FindVarRecursive(name); auto var = block_.FindVarRecursive(name);
...@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType( ...@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
return block_.FindVarRecursive(name)->GetType(); return block_.FindVarRecursive(name)->GetType();
} }
InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
const std::string &name) {
return block_.FindVarRecursive(name);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -123,6 +123,8 @@ class OpDesc { ...@@ -123,6 +123,8 @@ class OpDesc {
BlockDesc *Block() { return this->block_; } BlockDesc *Block() { return this->block_; }
const BlockDesc *Block() const { return this->block_; }
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, ...@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const Scope& scope) { const Scope& scope) {
for (auto& var_name_item : innames) { for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first]; std::vector<Variable*>& input_vars = inputs[var_name_item.first];
input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name)); input_vars.push_back(scope.FindVar(var_name));
} }
} }
for (auto& var_name_item : outnames) { for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first]; std::vector<Variable*>& output_vars = outputs[var_name_item.first];
output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name)); output_vars.push_back(scope.FindVar(var_name));
} }
...@@ -474,6 +476,28 @@ const Tensor* ExecutionContext::LegacyInput<Tensor>( ...@@ -474,6 +476,28 @@ const Tensor* ExecutionContext::LegacyInput<Tensor>(
template <> template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const { const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
}
const std::vector<Variable*>& vars = it->second;
std::vector<const Tensor*> res;
res.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
[&](Variable* var) -> const Tensor* {
if (var == nullptr) return nullptr;
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"should be LoDTensor, but the received type is %s",
var->Type().name());
return &(var->Get<LoDTensor>());
});
return res;
}
template <>
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
const std::string& name) const {
auto names = op().Inputs(name); auto names = op().Inputs(name);
std::vector<const Tensor*> res; std::vector<const Tensor*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -556,30 +580,28 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -556,30 +580,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
// has only one output // has only one output
const auto& outs = op_.Outputs(); const auto& outs = ctx_.outputs;
auto it = outs.find(name); auto it = outs.find(name);
if (it == outs.end()) { if (it == outs.end()) {
return false; return false;
} }
const auto& out = it->second; const auto& out = it->second;
if (out.size() == 0 || out[0] == kEmptyVarName) { if (out.size() == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(out.size(), 1UL, PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one outputs", name); "Output %s should not have more than one outputs", name);
return scope_.FindVar(out[0]) != nullptr; return out[0] != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
if (!op_.HasInputs(name)) { const auto& ins = ctx_.inputs;
return false; auto it = ins.find(name);
} if (it == ins.end() || it->second.empty()) {
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false; return false;
} }
for (auto& input : inputs) { for (auto& input : it->second) {
if (scope_.FindVar(input) == nullptr) { if (input == nullptr) {
return false; return false;
} }
} }
...@@ -587,15 +609,13 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -587,15 +609,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) { const auto& outs = ctx_.outputs;
return false; auto it = outs.find(name);
} if (it == outs.end() || it->second.empty()) {
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false; return false;
} }
for (auto& output : outputs) { for (auto& output : it->second) {
if (scope_.FindVar(output) == nullptr) { if (output == nullptr) {
return false; return false;
} }
} }
...@@ -616,16 +636,18 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -616,16 +636,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareDim(const std::string& in, const std::string& out, size_t i = 0, void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override { size_t j = 0) override {
PADDLE_ENFORCE_LT(i, Inputs(in).size()); auto in_it = ctx_.inputs.find(in);
PADDLE_ENFORCE_LT(j, Outputs(out).size()); auto out_it = ctx_.outputs.find(out);
const std::string& input_n = Inputs(in)[i]; PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
const std::string& output_n = Outputs(out)[j]; "Inputs %s should have %llu argument", in, i);
PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
"Outputs %s should have %llu argument", out, j);
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
Variable* in_var = scope_.FindVar(input_n);
Variable* out_var = scope_.FindVar(output_n);
PADDLE_ENFORCE(in_var->Type() == out_var->Type(), PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
"The type of %s and %s is not the same.", output_n, "The type of %s and %s is not the same.", in, out);
GetDim(input_n));
if (in_var->IsType<framework::SelectedRows>()) { if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>(); auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
...@@ -646,13 +668,16 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -646,13 +668,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
const std::vector<std::string>& inputs = Inputs(in); auto in_it = ctx_.inputs.find(in);
const std::vector<std::string>& outputs = Outputs(out); auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_LT(i, inputs.size()); PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
PADDLE_ENFORCE_LT(j, outputs.size()); "Inputs %s should have %llu argument", in, i);
Variable* in_var = scope_.FindVar(inputs.at(i)); PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
"Outputs %s should have %llu argument", out, j);
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<LoDTensor>()) return; if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = scope_.FindVar(outputs.at(j)); Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(), PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out); "The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>(); auto in_tensor = in_var->Get<LoDTensor>();
...@@ -687,9 +712,64 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -687,9 +712,64 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override { return true; } bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(vars.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, vars.size());
return this->GetDim(vars[0]);
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(vars.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, vars.size());
SetDim(vars[0], dim);
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
protected: protected:
DDim GetDim(const std::string& name) const override { DDim GetDim(Variable* var) const {
Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
...@@ -697,25 +777,44 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -697,25 +777,44 @@ class RuntimeInferShapeContext : public InferShapeContext {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<SelectedRows>().GetCompleteDims();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is %s.", "type_id is %s.",
name, var->Type().name()); var->Type().name());
} }
} }
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
[this](Variable* var) { return this->GetDim(var); });
return ret;
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override { std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW("Only compile time support this method"); PADDLE_THROW("Only compile time support this method");
} }
void SetDim(const std::string& name, const DDim& dim) override { void SetDim(Variable* var, const DDim& dim) {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]); var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else { } else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name()); var->Type().name());
}
}
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
} }
} }
...@@ -724,16 +823,36 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -724,16 +823,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW("Only compile time support this method"); PADDLE_THROW("Only compile time support this method");
} }
proto::VarType::Type GetVarType(const std::string& name) const override { std::vector<proto::VarType::Type> GetVarTypes(
auto* var = scope_.FindVar(name); const std::vector<Variable*>& vars) const {
return ToVarType(var->Type()); std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(), vars.end(), retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this, std::placeholders::_1));
return retv;
} }
InferShapeVarPtr GetVarPtr(const std::string& name) override { proto::VarType::Type GetVarType(Variable* var) const {
return scope_.FindVar(name); return ToVarType(var->Type());
} }
private: private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE(it != ctx_.inputs.end(),
"Operator %s does not have the input %s.", op_.Type(), name);
return it->second;
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE(it != ctx_.outputs.end(),
"Operator %s does not have the outputs %s.", op_.Type(),
name);
return it->second;
}
const OperatorBase& op_; const OperatorBase& op_;
const Scope& scope_; const Scope& scope_;
const RuntimeContext& ctx_; const RuntimeContext& ctx_;
...@@ -864,8 +983,7 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -864,8 +983,7 @@ Scope* OperatorWithKernel::PrepareData(
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name); auto* var = input_vars[i];
input_vars[i] = var;
// Only tensor can be tranfer to another device. // Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) { if (var == nullptr || !VarIsTensor(*var)) {
......
...@@ -197,8 +197,31 @@ class ExecutionContext { ...@@ -197,8 +197,31 @@ class ExecutionContext {
const std::vector<const Variable*> MultiInputVar( const std::vector<const Variable*> MultiInputVar(
const std::string& name) const { const std::string& name) const {
auto names = op_.Inputs(name); auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
}
std::vector<const Variable*> res; std::vector<const Variable*> res;
res.reserve(it->second.size());
std::transform(it->second.begin(), it->second.end(),
std::back_inserter(res),
[this](Variable* var) { return var; });
return res;
}
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
auto names = op_.Outputs(name);
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) {
return {};
}
return it->second;
}
const std::vector<Variable*> LegacyMultiInputVar(
const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<Variable*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [this](const std::string& name) {
...@@ -208,7 +231,7 @@ class ExecutionContext { ...@@ -208,7 +231,7 @@ class ExecutionContext {
return res; return res;
} }
std::vector<Variable*> MultiOutputVar(const std::string& name) const { std::vector<Variable*> LegacyMultiOutputVar(const std::string& name) const {
auto names = op_.Outputs(name); auto names = op_.Outputs(name);
std::vector<Variable*> res; std::vector<Variable*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -250,6 +273,38 @@ class ExecutionContext { ...@@ -250,6 +273,38 @@ class ExecutionContext {
template <typename T> template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const { const std::vector<const T*> MultiInput(const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
}
const std::vector<Variable*>& vars = it->second;
std::vector<const T*> res;
res.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
[&](Variable* var) -> const T* {
return var == nullptr ? nullptr : &var->Get<T>();
});
return res;
}
template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const {
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) {
return {};
}
const std::vector<Variable*>& vars = it->second;
std::vector<T*> res;
res.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
[&](Variable* var) -> T* {
return var == nullptr ? nullptr : var->GetMutable<T>();
});
return res;
}
template <typename T>
const std::vector<const T*> LegacyMultiInput(const std::string& name) const {
auto names = op_.Inputs(name); auto names = op_.Inputs(name);
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -262,7 +317,7 @@ class ExecutionContext { ...@@ -262,7 +317,7 @@ class ExecutionContext {
} }
template <typename T> template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const { std::vector<T*> LegacyMultiOutput(const std::string& name) const {
auto names = op_.Outputs(name); auto names = op_.Outputs(name);
std::vector<T*> res; std::vector<T*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -321,6 +376,10 @@ template <> ...@@ -321,6 +376,10 @@ template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const; const std::string& name) const;
template <>
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
const std::string& name) const;
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const; Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
......
...@@ -22,20 +22,6 @@ limitations under the License. */ ...@@ -22,20 +22,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DDim InferShapeContext::GetInputDim(const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> InferShapeContext::GetInputsDim(
const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
std::vector<DDim> InferShapeContext::GetReaderDims( std::vector<DDim> InferShapeContext::GetReaderDims(
const std::string &name) const { const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name); const std::vector<std::string> &arg_names = Inputs(name);
...@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims( ...@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]); return this->GetRepeatedDims(arg_names[0]);
} }
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
return this->GetDim(names[idx]);
}
void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void InferShapeContext::SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}
void InferShapeContext::SetReaderDims(const std::string &name, void InferShapeContext::SetReaderDims(const std::string &name,
const std::vector<DDim> &dims) { const std::vector<DDim> &dims) {
const std::vector<std::string> &arg_names = Outputs(name); const std::vector<std::string> &arg_names = Outputs(name);
...@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name, ...@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return this->SetRepeatedDims(arg_names[0], dims); return this->SetRepeatedDims(arg_names[0], dims);
} }
std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (names[i] == framework::kEmptyVarName) {
continue;
}
SetDim(names[i], dims[i]);
}
}
std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -25,6 +25,8 @@ limitations under the License. */ ...@@ -25,6 +25,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>; using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
class InferShapeContext { class InferShapeContext {
...@@ -33,22 +35,23 @@ class InferShapeContext { ...@@ -33,22 +35,23 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0; virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarType::Type> GetInputsVarType( virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const; const std::string &name) const = 0;
std::vector<proto::VarType::Type> GetOutputsVarType( virtual std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const; const std::string &name) const = 0;
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
DDim GetInputDim(const std::string &name) const; virtual DDim GetInputDim(const std::string &name) const = 0;
std::vector<DDim> GetInputsDim(const std::string &name) const; virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
std::vector<DDim> GetReaderDims(const std::string &name) const; virtual std::vector<DDim> GetReaderDims(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;
void SetOutputDim(const std::string &name, const DDim &dim); virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims); virtual void SetOutputsDim(const std::string &name,
void SetReaderDims(const std::string &name, const std::vector<DDim> &dims); const std::vector<DDim> &dims) = 0;
virtual void SetReaderDims(const std::string &name,
const std::vector<DDim> &dims);
virtual AttrReader Attrs() const = 0; virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs( virtual const std::vector<std::string> &Inputs(
...@@ -67,27 +70,15 @@ class InferShapeContext { ...@@ -67,27 +70,15 @@ class InferShapeContext {
virtual bool IsRuntime() const = 0; virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name); virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name); const std::string &name) = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) = 0;
// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims);
protected: protected:
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0;
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0; virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name, virtual void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) = 0; const std::vector<DDim> &dims) = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const;
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
}; };
} // namespace framework } // namespace framework
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -27,6 +28,9 @@ void Tensor::check_memory_size() const { ...@@ -27,6 +28,9 @@ void Tensor::check_memory_size() const {
"or maybe the required data-type mismatches the data already stored."); "or maybe the required data-type mismatches the data already stored.");
} }
Tensor::Tensor(std::type_index type)
: type_(framework::ToDataType(type)), offset_(0) {}
size_t Tensor::memory_size() const { size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
...@@ -101,5 +105,12 @@ const DDim& Tensor::dims() const { return dims_; } ...@@ -101,5 +105,12 @@ const DDim& Tensor::dims() const { return dims_; }
int64_t Tensor::numel() const { return product(dims_); } int64_t Tensor::numel() const { return product(dims_); }
void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
if (holder_) {
PADDLE_ENFORCE_EQ(numel() * SizeOfType(type()), holder->size());
}
holder_ = holder;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -69,6 +69,8 @@ class Tensor { ...@@ -69,6 +69,8 @@ class Tensor {
public: public:
Tensor() : type_(proto::VarType::FP32), offset_(0) {} Tensor() : type_(proto::VarType::FP32), offset_(0) {}
explicit Tensor(std::type_index type);
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
T* data(); T* data();
...@@ -162,6 +164,8 @@ class Tensor { ...@@ -162,6 +164,8 @@ class Tensor {
return std::move(holder_); return std::move(holder_);
} }
void ResetHolder(std::shared_ptr<memory::Allocation> holder);
private: private:
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_; std::shared_ptr<memory::Allocation> holder_;
......
...@@ -89,12 +89,21 @@ endif() ...@@ -89,12 +89,21 @@ endif()
if(WITH_MKL) if(WITH_MKL)
include_directories("${PADDLE_LIB}/third_party/install/mklml/include") include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} if(NOT WIN32)
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
else(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml${CMAKE_SHARED_LIBRARY_SUFFIX}
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md${CMAKE_SHARED_LIBRARY_SUFFIX})
endif(WIN32)
set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
if(EXISTS ${MKLDNN_PATH}) if(EXISTS ${MKLDNN_PATH})
include_directories("${MKLDNN_PATH}/include") include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) if(WIN32)
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
else(WIN32)
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
endif(WIN32)
endif() endif()
else() else()
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
......
...@@ -254,5 +254,16 @@ TEST(Analyzer_dam, compare) { compare(); } ...@@ -254,5 +254,16 @@ TEST(Analyzer_dam, compare) { compare(); }
TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif #endif
// Compare Deterministic result
TEST(Analyzer_dam, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -180,6 +180,17 @@ TEST(Analyzer_LAC, compare) { ...@@ -180,6 +180,17 @@ TEST(Analyzer_LAC, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_LAC, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -179,5 +179,16 @@ TEST(Analyzer_Chinese_ner, compare) { ...@@ -179,5 +179,16 @@ TEST(Analyzer_Chinese_ner, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_Chinese_ner, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -85,6 +85,17 @@ TEST(Analyzer_resnet50, compare) { compare(); } ...@@ -85,6 +85,17 @@ TEST(Analyzer_resnet50, compare) { compare(); }
TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif #endif
// Compare Deterministic result
TEST(Analyzer_resnet50, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -265,6 +265,17 @@ TEST(Analyzer_rnn1, compare) { ...@@ -265,6 +265,17 @@ TEST(Analyzer_rnn1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_rnn1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
// Test Multi-Thread. // Test Multi-Thread.
TEST(Analyzer_rnn1, multi_thread) { TEST(Analyzer_rnn1, multi_thread) {
contrib::AnalysisConfig cfg; contrib::AnalysisConfig cfg;
......
...@@ -158,5 +158,16 @@ TEST(Analyzer_rnn2, compare) { ...@@ -158,5 +158,16 @@ TEST(Analyzer_rnn2, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_rnn2, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -204,5 +204,16 @@ TEST(Analyzer_seq_conv1, compare) { ...@@ -204,5 +204,16 @@ TEST(Analyzer_seq_conv1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_seq_conv1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -106,6 +106,17 @@ TEST(Analyzer_Text_Classification, compare) { ...@@ -106,6 +106,17 @@ TEST(Analyzer_Text_Classification, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_Text_Classification, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) { TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
......
...@@ -145,6 +145,17 @@ TEST(Analyzer_vis, compare) { compare(); } ...@@ -145,6 +145,17 @@ TEST(Analyzer_vis, compare) { compare(); }
TEST(Analyzer_vis, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_vis, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif #endif
// Compare Deterministic result
TEST(Analyzer_vis, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -45,6 +45,7 @@ DEFINE_bool(use_analysis, true, ...@@ -45,6 +45,7 @@ DEFINE_bool(use_analysis, true,
"Running the inference program in analysis mode."); "Running the inference program in analysis mode.");
DEFINE_bool(record_benchmark, false, DEFINE_bool(record_benchmark, false,
"Record benchmark after profiling the model"); "Record benchmark after profiling the model");
DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DECLARE_bool(profile); DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads); DECLARE_int32(paddle_num_threads);
...@@ -85,7 +86,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -85,7 +86,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
float *pdata = static_cast<float *>(out.data.data()); float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data()); float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) { for (size_t j = 0; j < size; ++j) {
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3); EXPECT_NEAR(pdata_ref[j], pdata[j], FLAGS_accuracy);
} }
break; break;
} }
...@@ -283,6 +284,26 @@ void TestPrediction(const PaddlePredictor::Config *config, ...@@ -283,6 +284,26 @@ void TestPrediction(const PaddlePredictor::Config *config,
} }
} }
void CompareDeterministic(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat;
auto predictor = CreateTestPredictor(config, FLAGS_use_analysis);
// warmup run
std::vector<PaddleTensor> warmup_outputs, outputs;
predictor->Run(inputs[0], &warmup_outputs, batch_size);
// run num_times to Compare Deterministic Result.
for (int i = 0; i < num_times; i++) {
for (size_t j = 0; j < inputs.size(); j++) {
predictor->Run(inputs[j], &outputs, batch_size);
CompareResult(outputs, warmup_outputs);
}
}
}
void CompareNativeAndAnalysis( void CompareNativeAndAnalysis(
const PaddlePredictor::Config *config, const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) { const std::vector<std::vector<PaddleTensor>> &inputs) {
......
...@@ -16,6 +16,7 @@ add_subdirectory(metrics) ...@@ -16,6 +16,7 @@ add_subdirectory(metrics)
add_subdirectory(optimizers) add_subdirectory(optimizers)
add_subdirectory(reduce_ops) add_subdirectory(reduce_ops)
add_subdirectory(sequence_ops) add_subdirectory(sequence_ops)
add_subdirectory(jit)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(distributed) add_subdirectory(distributed)
...@@ -42,8 +43,7 @@ if (WITH_DISTRIBUTE) ...@@ -42,8 +43,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif() endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above # warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32) if (WITH_GPU AND NOT WIN32)
...@@ -65,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) ...@@ -65,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
if (WITH_GPU) if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
...@@ -92,4 +92,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) ...@@ -92,4 +92,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
...@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ...@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(kOutputs); ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs)); ctx->HasInputs(framework::GradVarName(kOutputs));
auto p_names = ctx->Inputs(kX);
auto pg_ig_names = ctx->Outputs(kXGRAD); auto pg_ig_names = ctx->Outputs(kXGRAD);
auto var_types = ctx->GetInputsVarType(kX); std::vector<framework::InferShapeVarPtr> in_var_ptrs =
std::vector<std::string> names_to_set; ctx->GetInputVarPtrs(kX);
std::vector<framework::DDim> dims_to_set; std::vector<framework::InferShapeVarPtr> out_var_ptrs =
for (size_t i = 0; i < p_names.size(); ++i) { ctx->GetOutputVarPtrs(kXGRAD);
PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size());
for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
if (pg_ig_names[i] == framework::kEmptyVarName) { if (pg_ig_names[i] == framework::kEmptyVarName) {
continue; continue;
} }
auto dims = ctx->GetInputsElementDim(kX, i); if (ctx->IsRuntime()) {
if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { framework::Variable *in_var =
names_to_set.push_back(pg_ig_names[i]); boost::get<framework::Variable *>(in_var_ptrs[i]);
dims_to_set.push_back(dims); framework::Variable *out_var =
} else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) { boost::get<framework::Variable *>(out_var_ptrs[i]);
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_ig_names[i]); auto type = framework::ToVarType(in_var->Type());
dims_to_set.push_back(dims); if (type == framework::proto::VarType::LOD_TENSOR) {
out_var->GetMutable<LoDTensor>()->Resize(
in_var->Get<framework::LoDTensor>().dims());
} else if (type == framework::proto::VarType::SELECTED_ROWS) {
out_var->GetMutable<framework::SelectedRows>()->set_height(
in_var->Get<framework::SelectedRows>().GetCompleteDims()[0]);
} else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
PADDLE_THROW("WhileGradOp doesn't support type %d",
static_cast<int>(type));
}
} else {
framework::VarDesc *in_var =
boost::get<framework::VarDesc *>(in_var_ptrs[i]);
boost::get<framework::VarDesc *>(out_var_ptrs[i])
->SetShape(in_var->GetShape());
} }
} }
ctx->SetDims(names_to_set, dims_to_set);
} }
}; };
......
...@@ -155,11 +155,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -155,11 +155,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
if (is_conv3d) { weights_format = mkldnn::memory::format::any;
chosen_memory_format = // Check the format for user's special output
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); if (chosen_memory_format != mkldnn::memory::format::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
} }
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
...@@ -435,11 +438,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -435,11 +438,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
if (is_conv3d) { weights_format = mkldnn::memory::format::any;
chosen_memory_format = // Check the format for user's special output
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); if (chosen_memory_format != mkldnn::memory::format::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
} }
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h" #include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -123,6 +124,8 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -123,6 +124,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
auto& dev_ctx = context.template device_context<DeviceContext>();
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
...@@ -155,13 +158,19 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -155,13 +158,19 @@ class GemmConvKernel : public framework::OpKernel<T> {
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix; Tensor col_matrix;
if (is_expand) { if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace()); auto tmp_allocation_ptr =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
framework::product(col_shape) * sizeof(T));
Tensor tep_tensor =
platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col.ShareDataWith(tep_tensor);
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
framework::DDim input_shape = framework::slice_ddim( framework::DDim input_shape =
input->dims(), 1, static_cast<int>(input->dims().size())); framework::slice_ddim(input->dims(), 1, input->dims().size());
framework::DDim filter_matrix_shape = {filter.dims()[0], framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
...@@ -178,7 +187,6 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -178,7 +187,6 @@ class GemmConvKernel : public framework::OpKernel<T> {
math::Vol2ColFunctor<DeviceContext, T> vol2col; math::Vol2ColFunctor<DeviceContext, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
...@@ -237,6 +245,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -237,6 +245,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w} // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
...@@ -262,8 +272,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -262,8 +272,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
framework::DDim col_matrix_shape = framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1); framework::flatten_to_2d(col_shape, data_dim + 1);
framework::DDim input_shape = framework::slice_ddim( framework::DDim input_shape =
input->dims(), 1, static_cast<int>(input->dims().size())); framework::slice_ddim(input->dims(), 1, input->dims().size());
framework::DDim filter_matrix_shape = {filter.dims()[0], framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
...@@ -286,13 +296,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -286,13 +296,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix; Tensor col_matrix;
if (is_expand) { if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace()); auto tmp_allocation_ptr =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
framework::product(col_shape) * sizeof(T));
Tensor tep_tensor =
platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col.ShareDataWith(tep_tensor);
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
math::SetConstant<DeviceContext, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (input_grad) { if (input_grad) {
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
...@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
const auto& ker = math::jitkernel::KernelPool::Instance() auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
.template Get<math::jitkernel::CRFDecodeKernel<T>>( platform::CPUPlace>(tag_num);
static_cast<int>(tag_num)); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
......
...@@ -12,23 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,23 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad, elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
......
...@@ -16,11 +16,14 @@ limitations under the License. */ ...@@ -16,11 +16,14 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/math/jit_kernel.h" #ifdef PADDLE_WITH_XBYAK
#include "xbyak/xbyak.h" #include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h" #include "xbyak/xbyak_util.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -81,8 +84,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -81,8 +84,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format"); UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format"); UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");
Xbyak::util::Cpu cpu; const bool is_avx512_enabled = platform::MayIUse(platform::avx512f);
const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
const bool are_dims_divisable = !(x_int_dims[1] % 16); const bool are_dims_divisable = !(x_int_dims[1] % 16);
const bool is_x_format_correct = x->format() == memory::format::nChw16c; const bool is_x_format_correct = x->format() == memory::format::nChw16c;
const bool is_y_format_correct = y->format() == memory::format::nc; const bool is_y_format_correct = y->format() == memory::format::nc;
...@@ -108,10 +110,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -108,10 +110,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
const auto& multiply = auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
math::jitkernel::KernelPool::Instance() platform::CPUPlace>(0);
.template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {
...@@ -122,7 +122,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -122,7 +122,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto ptr_z = auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); multiply(ptr_x, ptr_y, ptr_z, h, w);
} }
} }
} }
......
...@@ -12,21 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>, elementwise_mul,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>); ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>); int64_t>);
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_zeros_like_op.h" #include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
...@@ -23,6 +22,4 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -23,6 +22,4 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>); ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
...@@ -15,9 +15,9 @@ limitations under the License. */ ...@@ -15,9 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_gru_op.h" #include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle { namespace paddle {
...@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \ const int total_T = x_dims[0]; \
const int D3 = wh_dims[1] const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const math::jitkernel::gru_attr_t attr( \ const jit::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
ctx.Attr<std::string>("activation")); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
math::jitkernel::gru_t one_step; \ jit::gru_t one_step; \
const auto& ker = \ auto ComputeH1 = \
math::jitkernel::KernelPool::Instance() \ jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
.template Get<math::jitkernel::GRUKernel<T>, \ auto ComputeHtPart1 = \
const math::jitkernel::gru_attr_t&>(attr); \ jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \ auto ComputeHtPart2 = \
const T* wx_data = wx->data<T>(); \ jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* wh_data = wh->data<T>(); \ const T* x_data = x->data<T>(); \
auto place = ctx.GetPlace(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place) T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
...@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} else { } else {
one_step.gates = xx_data; one_step.gates = xx_data;
one_step.ht = hidden_out_data; one_step.ht = hidden_out_data;
ker->ComputeH1(&one_step, &attr); ComputeH1(&one_step, &attr);
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = xx_data; one_step.gates = xx_data;
one_step.ht_1 = prev_hidden_data; one_step.ht_1 = prev_hidden_data;
one_step.ht = hidden_out_data; one_step.ht = hidden_out_data;
ker->ComputeHtPart1(&one_step, &attr); ComputeHtPart1(&one_step, &attr);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
ker->ComputeHtPart2(&one_step, &attr); ComputeHtPart2(&one_step, &attr);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
one_step.gates = cur_in_data; one_step.gates = cur_in_data;
one_step.ht = cur_out_data; one_step.ht = cur_out_data;
ker->ComputeH1(&one_step, &attr); ComputeH1(&one_step, &attr);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = cur_batched_data; one_step.gates = cur_batched_data;
one_step.ht_1 = cur_prev_hidden_data; one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data; one_step.ht = cur_out_data;
ker->ComputeHtPart1(&one_step, &attr); ComputeHtPart1(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
...@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = cur_batched_data; one_step.gates = cur_batched_data;
one_step.ht_1 = cur_prev_hidden_data; one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data; one_step.ht = cur_out_data;
ker->ComputeHtPart2(&one_step, &attr); ComputeHtPart2(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_lstm_op.h" #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle { namespace paddle {
...@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D4 = wh_dims[1] const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \ const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
T* checked_cell_data = nullptr; \ T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
if (use_peepholes) { \ if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const math::jitkernel::lstm_attr_t attr( \ const jit::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
ctx.Attr<std::string>("candidate_activation"), \ jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \ jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
math::jitkernel::lstm_t one_step; \ use_peepholes); \
one_step.wp = wp_data; \ jit::lstm_t one_step; \
one_step.checked = checked_cell_data; \ one_step.wp = wp_data; \
const auto& ker = \ one_step.checked = checked_cell_data; \
math::jitkernel::KernelPool::Instance() \ auto ComputeC1H1 = \
.template Get<math::jitkernel::LSTMKernel<T>, \ jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
const math::jitkernel::lstm_attr_t&>(attr) auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
...@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.gates = xx_data; one_step.gates = xx_data;
one_step.ct = c_out_data; one_step.ct = c_out_data;
one_step.ht = h_out_data; one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr); ComputeC1H1(&one_step, &attr);
tstart = 1; tstart = 1;
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
...@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = prev_c_data; one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data; one_step.ct = c_out_data;
one_step.ht = h_out_data; one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr); ComputeCtHt(&one_step, &attr);
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
prev_c_data = c_out_data; prev_c_data = c_out_data;
...@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.gates = cur_in_data; one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data; one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data; one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr); ComputeC1H1(&one_step, &attr);
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
...@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = cur_prev_c_data; one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data; one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data; one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr); ComputeCtHt(&one_step, &attr);
// move one batch // move one batch
cur_in_data += D4; cur_in_data += D4;
......
set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
# refer must go first
add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
add_subdirectory(gen)
endif()
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
if(NOT WIN32)
cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper device_tracer)
endif()
# JIT Kernel
JIT(Just In Time) Kernel contains actually generated code and some other implemenations with the same logic.
Each implementations has its own condition to use, defined in `UseMe`.
They are combined together to get the best performance of one single independent function.
They could be some very simple functions like vector multiply, or some complicated functions like LSTM.
And they can be composed with some other exited jit kernels to build up a complex function.
Currently it's only supported on CPU yet.
## Contents
```txt
PaddlePaddle/Paddle/paddle/fluid/
├── ...
└── operators/
├── .../
└── jit/
├── ...
├── gen/
│ └── ...
|── more/
│ ├── ...
│ ├── mkl/
│ │ └── ...
│ ├── mkldnn/
│ │ └── ...
│ ├── mix/
│ │ └── ...
│ ├── intrinsic/
│ │ └── ...
│ └── openblas/
│ └── ...
└── refer/
└── ...
```
All basical definations of jit kernels are addressed in `paddle/fluid/operators/jit` including these three key folders `refer`, `gen`, `more`. There is only one unique name for each kernel while may have seraval implementations with same functionality.
- `refer`: Each kernel must have one reference implementation on CPU, and it should only focus on the correctness and should not depends on any third-party libraries.
- `gen`: The code generated should be kept here. They should be designed focusing on the best performance, which depends on Xbyak.
- `more`: All other implementations should be kept in this folder with one directory corresponding to one library kind or method kind, such as mkl, mkldnn, openblas or intrinsic code. Each implementation should have it advantage.
## How to use
One simple function `jit::Get`, which is very easy to use, is supported to get the kernel.
It can automatically return the expected function with best performance under the given attributes.
All kernels are inlcuded in `paddle/fluid/operators/jit/kernels.h`, you can only include this one header to get all the registered kernels.
## Solid Test
- Unit Test
All functions should be compared with the corresponding reference functions, including data tyep `float` and `double`.
- Benchmark
All functions should be tested, and make sure the `jit::Get` function obtain the best performance with all attributes.
# How to add new kernel
## Required
1. Add `your_key` at `KernelType`.
2. Add reference function of `your_key`.
Note:
- this should be run on CPU and do not depend on any third-party.
- Add `USE_JITKERNEL_REFER(your_key)` in `refer/CmakeLists.txt` to make sure this code can be used.
3. Add unit test in `test.cc`, and verfiy at least `float` and `double`.
Test more data type for some special functions if necessary, for example `int8`.
4. Add functions in `benchmark.cc` to test all function of same `KernelType`. Make sure `jit::Get` always get the best one.
## Optional
Add more implementations of `your_kery` for performance enhancement.
1. Add functions based on generated code in `gen`. It should be derived from `JitCode` and should have corepsonding creator from `JitCodeCreator` which will be registered on the `your_key`.
Note: Add new `KernelTuples` if necessary,your can refer to `XYZNTuples`.
Specialie method `JitCodeKey` when add new attribute type。
2. Add more functions in `more`,you can use any third party you wish, like mkl, mkldnn or intrinsic code to reach the best performance.
# JIT Kernel
结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的`UseMe`函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector MUL, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。
## 目录结构
```txt
PaddlePaddle/Paddle/paddle/fluid/
├── ...
└── operators/
├── .../
└── jit/
├── ...
├── gen/
│ └── ...
|── more/
│ ├── ...
│ ├── mkl/
│ │ └── ...
│ ├── mkldnn/
│ │ └── ...
│ ├── mix/
│ │ └── ...
│ ├── intrinsic/
│ │ └── ...
│ └── openblas/
│ └── ...
└── refer/
└── ...
```
基本类的定义都放在根目录下,根目录下包括gen,more和refer三个目录。每个目录下都是一种或者多种实现,每种kernel算子都需要有reference的实现,用作单元测试的基准,其他的实现都是可选的。
- gen: 代表使用jit生成的code,需要依赖xbyak库。该实现最关心的就是性能。
- refer: 代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑的正确性。
- more: 下面可以放入跟多实现,可以包括mkl,mkldnn,intrinsic,openblas等,也可以是自身已有的kernel组合。
## 动态获取
提供一个`jit::Get`方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。
## 测试
- 逻辑测试
所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型
- 性能测试
所有实现的性能对比,并且与最终的`jit::Get`方法对比,该方法拿到的性能需要在各种条件下都是最好的。
# 如何添加新的算子
-`KernelType` 中添加 `your_key` .
- 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`来使用该kernel.
- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。
- (optional) 实现基于Xbyak的生成code,在`gen`目下。 jitcode需要实现自己的`JitCodeCreator`,并注册在与refer相同的`KernelType`上。
- 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。
-`test.cc`中添加unit test,至少需要测试`float``double`两种数据类型,如有必要需要支持额外的数据类型,比如`int8`的相关函数。
-`benchmark.cc`中添加相应的性能对比,同一种kernel需要对比所有实现,并且确保`jit::Get`得到的实现一直是速度最快的。
# 优点
- 统一的Get方法,接口简单。
- 同一套逻辑可以有多套实现,可以依赖多套第三方库,互不影响。
- 目录结构清晰,不会在某个文件中有多个宏定义,导致的可读性差问题。
- 优化方便,可以直接针对某种属性针对性优化,并不影响其他属性下的性能。
- 可以支持多种平台,包括Linux,Mac 和 Windows,至少可以保证每种平台都可以正常work。后期也可以针对不同平台有针对的优化。框架层面可以使用统一接口,不必关心底层实现。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32(repeat, 3000, "Repeat times.");
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
std::mt19937 rng(seed);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i <= FLAGS_max_size; ++i) {
s.push_back(i);
}
return s;
}
template <typename KernelTuples, typename... Args>
struct BenchFunc {
// return this function avg time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
auto start = paddle::platform::PosixInNsec() / 1e-3;
for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(args...);
}
auto end = paddle::platform::PosixInNsec() / 1e-3;
return static_cast<double>(end - start) / FLAGS_repeat;
}
};
namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
BenchFunc<KernelTuples, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KT, KernelTuples>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
infos.push_back(std::make_pair("Target", benchmark(tgt, args...)));
// print
std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
LOG(INFO) << loginfos.str();
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d), z(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(),
z.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(),
d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
use_peephole);
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
std::vector<T> x(3 * d), ht_1(d), ht(d);
RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
const T* ht_1_data = ht_1.data();
T* x_data = x.data();
T* ht_data = ht.data();
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
<< " times.";
using T = float;
using PlaceType = paddle::platform::CPUPlace;
// xyzn
BenchXYZNKernel<jit::kVMul, T, PlaceType>();
BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
BenchXYZNKernel<jit::kVSub, T, PlaceType>();
// axyn
BenchAXYNKernel<jit::kVScal, T, PlaceType>();
BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
// xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
// lstm and peephole
BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
// gru functions
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
}
file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
function(USE_JITKERNEL_GEN TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
USE_JITKERNEL_GEN(kVTanh)
USE_JITKERNEL_GEN(kLSTMCtHt)
USE_JITKERNEL_GEN(kLSTMC1H1)
USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)};
int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0};
void VActJitCode::genCode() {
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
act<ymm_t>(ymm_dst, ymm_src, type_);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) {
block = 2;
vmovq(xmm_src, ptr[param1 + offset]);
} else {
block = 1;
vmovss(xmm_src, ptr[param1 + offset]);
}
act<xmm_t>(xmm_dst, xmm_src, type_);
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
size_t VReluCreator::CodeSize(const int& d) const {
return 96 /* init size */ +
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
8 /* average bytes for each instruction */;
}
size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
size_t VExpCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8;
}
size_t VSigmoidCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8;
}
size_t VTanhCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8;
}
#undef DECLARE_ACT_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
You may obtain a copy of the License at * You may obtain a copy of the License at
*
http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
*
Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_gen.h" #include "glog/logging.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace jit {
namespace jitkernel {
namespace gen { namespace gen {
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
mul = 0,
add,
sub,
relu,
exp,
sigmoid,
tanh,
identity
} operand_type;
extern const float exp_float_consts[]; extern const float exp_float_consts[];
extern const int exp_int_0x7f[]; extern const int exp_int_0x7f[];
extern int g_tmp_mem[]; extern int g_tmp_mem[];
...@@ -79,94 +59,15 @@ extern int g_tmp_mem[]; ...@@ -79,94 +59,15 @@ extern int g_tmp_mem[];
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VActFunc : public JitCode {
class VXXJitCode : public JitCode {
public:
const char* name() const override {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {}
static bool init(int d, int scalar_index = 0);
void generate() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
class VActJitCode : public JitCode {
public: public:
const char* name() const override { explicit VActFunc(size_t code_size, void* code_ptr)
std::string base = "VActJitCode"; : JitCode(code_size, code_ptr) {}
switch (type_) { virtual const char* name() const = 0;
case operand_type::relu: virtual void genCode() = 0;
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
return base.c_str();
}
explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d), type_(type) {}
static bool init(int d, operand_type type);
void generate() override;
protected: protected:
// compute relu with ymm, xmm // compute RELU with ymm, xmm
template <typename JMM> template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx); JMM zero = JMM(zero_idx);
...@@ -174,7 +75,7 @@ class VActJitCode : public JitCode { ...@@ -174,7 +75,7 @@ class VActJitCode : public JitCode {
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute exp with ymm, xmm // compute EXP with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
...@@ -258,7 +159,7 @@ class VActJitCode : public JitCode { ...@@ -258,7 +159,7 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute sigmoid with ymm, xmm // compute SIGMOID with ymm, xmm
template <typename JMM> template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
...@@ -283,7 +184,7 @@ class VActJitCode : public JitCode { ...@@ -283,7 +184,7 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute tanh with ymm, xmm // compute TANH with ymm, xmm
template <typename JMM> template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
...@@ -310,223 +211,109 @@ class VActJitCode : public JitCode { ...@@ -310,223 +211,109 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute IDENTITY with ymm, xmm
template <typename JMM>
void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT
JMM zero = JMM(zero_idx);
vxorps(zero, zero, zero);
vaddps(dst, src, zero);
// TODO(TJ): use below
// dst.setIdx(src.getIdx());
}
template <typename JMM> template <typename JMM>
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 11~15 // use 11~15
switch (type) { switch (type) {
case operand_type::relu: case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15); relu_jmm<JMM>(dst, src, 15);
break; break;
case operand_type::exp: case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::sigmoid: case operand_type::SIGMOID:
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::tanh: case operand_type::TANH:
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::identity: case operand_type::IDENTITY:
identity_jmm<JMM>(dst, src, 15);
break; break;
default: default:
// throw error LOG(FATAL) << "Do not support this operand type: " << type;
break; break;
} }
} }
protected:
int num_;
operand_type type_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
xmm_t xmm_src = xmm_t(0);
ymm_t ymm_src = ymm_t(0);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_dst = ymm_t(1);
}; };
class LSTMJitCode : public VActJitCode { class VActJitCode : public VActFunc {
public: public:
const char* name() const override { explicit VActJitCode(int d, operand_type type, size_t code_size,
std::string base = "LSTMJitCode"; void* code_ptr = nullptr)
if (use_peephole_) { : VActFunc(code_size, code_ptr), num_(d), type_(type) {
base += "_Peephole"; if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
} type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
if (compute_c1h1_) { type_ == operand_type::IDENTITY)) {
base += "_C1H1"; LOG(FATAL) << "Do not support this operand type: " << type_;
} }
auto AddTypeStr = [&](operand_type type) { this->genCode();
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
compute_c1h1_(compute_c1h1) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
} }
static bool init(int d);
void generate() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
class GRUJitCode : public VActJitCode {
public:
const char* name() const override { const char* name() const override {
std::string base = "GRUJitCode"; std::string base = "VActJitCode";
if (id_ == 0) { switch (type_) {
base += "_H1"; case operand_type::RELU:
} else if (id_ == 1) { base += "_Relu";
base += "_HtPart1"; break;
} else if (id_ == 2) { case operand_type::EXP:
base += "_HtPart2"; base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
} }
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str(); return base.c_str();
} }
void genCode() override;
explicit GRUJitCode(int id, const gru_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
}
static bool init(int d);
void generate() override;
protected: protected:
int id_;
int num_; int num_;
operand_type act_gate_; operand_type type_;
operand_type act_cand_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
}; reg64_t param2{abi_param2};
#ifdef PADDLE_WITH_MKLDNN xmm_t xmm_src = xmm_t(0);
struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { ymm_t ymm_src = ymm_t(0);
explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
: Xbyak::CodeGenerator(code_size) {
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push(rbx); xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_dst = ymm_t(1);
};
xor_(rax, rax); #define DECLARE_ACT_JITCODE(name, op_type) \
xor_(r10, r10); class name##JitCode : public VActJitCode { \
vmovups(zmm3, ptr[rsi]); public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VActJitCode(d, op_type, code_size, code_ptr) {} \
};
L("h_loop"); DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
xor_(rbx, rbx); DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
L("w_loop"); DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
vmovups(zmm2, ptr[rdi + rax]); DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
vmulps(zmm1, zmm2, zmm3); DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx); #undef DECLARE_ACT_JITCODE
ret();
}
};
#endif
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jit
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void VXXJitCode::genCode() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::MUL) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
} else if (rest >= 2) {
block = 2;
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
}
} else {
block = 1;
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
}
switch (type_) {
case operand_type::MUL:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
void NCHW16CMulNCJitCode::genCode() {
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push(rbx);
xor_(rax, rax);
xor_(r10, r10);
vmovups(zmm3, ptr[rsi]);
L("h_loop");
xor_(rbx, rbx);
L("w_loop");
vmovups(zmm2, ptr[rdi + rax]);
vmulps(zmm1, zmm2, zmm3);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx);
ret();
}
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr));
}
};
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_BLAS_CREATOR(VMul);
DECLARE_BLAS_CREATOR(VAdd);
DECLARE_BLAS_CREATOR(VSub);
DECLARE_BLAS_CREATOR(VAddRelu);
DECLARE_BLAS_CREATOR(VScal);
DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
virtual const char* name() const {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::MUL) {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
void genCode() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
#undef DECLARE_BLAS_JITCODE
// nChw16c = nChw16c .* NC
class NCHW16CMulNCJitCode : public JitCode {
public:
DECLARE_JIT_CODE(NCHW16CMulNCJitCode);
explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}
void genCode() override;
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void GRUJitCode::genCode() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
#define DECLARE_GRU_CREATOR(name) \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const gru_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_GRU_CREATOR(GRUH1);
DECLARE_GRU_CREATOR(GRUHtPart1);
DECLARE_GRU_CREATOR(GRUHtPart2);
#undef DECLARE_GRU_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class GRUJitCode : public VActFunc {
public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
this->genCode();
}
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
void genCode() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#define DECLARE_GRU_JITCODE(name, id) \
class name##JitCode : public GRUJitCode { \
public: \
explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: GRUJitCode(id, attr, code_size, code_ptr) {} \
};
DECLARE_GRU_JITCODE(GRUH1, 0);
DECLARE_GRU_JITCODE(GRUHtPart1, 1);
DECLARE_GRU_JITCODE(GRUHtPart2, 2);
#undef DECLARE_GRU_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX);
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
MUL = 0,
ADD,
SUB,
RELU,
EXP,
SIGMOID,
TANH,
IDENTITY
} operand_type;
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
virtual void preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}
virtual void postCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
}
ret();
}
void L(const char* label) { Xbyak::CodeGenerator::L(label); }
void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false) {
int scale = 0;
// Learn from https://github.com/intel/mkl-dnn
if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
offt = offt - 2 * EVEX_max_8b_offt;
scale = 1;
} else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
offt = offt - 4 * EVEX_max_8b_offt;
scale = 2;
}
auto re = Xbyak::RegExp() + base + offt;
if (scale) {
re = re + reg_EVEX_max_8b_offt * scale;
}
if (bcast) {
return zword_b[re];
} else {
return zword[re];
}
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void LSTMJitCode::genCode() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
#define DECLARE_LSTM_CREATOR(name) \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const lstm_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_LSTM_CREATOR(LSTMCtHt);
DECLARE_LSTM_CREATOR(LSTMC1H1);
#undef DECLARE_LSTM_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class LSTMJitCode : public VActFunc {
public:
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size, void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr),
num_(attr.d),
compute_c1h1_(compute_c1h1),
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
this->genCode();
}
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
void genCode() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
#define DECLARE_LSTM_JITCODE(name, compute_c1h1) \
class name##JitCode : public LSTMJitCode { \
public: \
explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {} \
};
DECLARE_LSTM_JITCODE(LSTMCtHt, false);
DECLARE_LSTM_JITCODE(LSTMC1H1, true);
#undef DECLARE_LSTM_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle {
namespace operators {
namespace jit {
// refer do not need useme, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;
std::ostringstream filename;
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
counter++;
std::ofstream fout(filename.str(), std::ios::out);
if (fout.is_open()) {
fout.write(reinterpret_cast<const char*>(code), this->getSize());
fout.close();
}
}
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
DECLARE_bool(dump_jitcode);
namespace paddle {
namespace operators {
namespace jit {
class GenBase : public Kernel {
public:
virtual ~GenBase() = default;
virtual const char* name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename Func>
Func getCode() {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
}
// Note: failed to cast with reinterpret_cast<const Func> on Mac clang,
// then workaround with const_cast. Any better idea is appreciated.
return reinterpret_cast<Func>(const_cast<unsigned char*>(code));
}
protected:
void dumpCode(const unsigned char* code) const;
};
// Creator is used to creat the jitcode and save in pool.
// Every JitCode should have one creator.
class GenCreator {
public:
virtual ~GenCreator() = default;
};
template <typename Attr>
class JitCodeCreator : public GenCreator {
public:
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
// create this code
virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
#define ONE_CASE(key) \
case key: \
return #key
const char* to_string(KernelType kt) {
switch (kt) {
ONE_CASE(kVMul);
ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu);
ONE_CASE(kVSub);
ONE_CASE(kVScal);
ONE_CASE(kVAddBias);
ONE_CASE(kVRelu);
ONE_CASE(kVIdentity);
ONE_CASE(kVExp);
ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt);
ONE_CASE(kLSTMC1H1);
ONE_CASE(kGRUH1);
ONE_CASE(kGRUHtPart1);
ONE_CASE(kGRUHtPart2);
ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
}
return nullptr;
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
std::string lower = act;
std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
if (lower == "relu" || lower == "vrelu") {
return kVRelu;
} else if (lower == "identity" || lower == "videntity" || lower == "") {
return kVIdentity;
} else if (lower == "exp" || lower == "vexp") {
return kVExp;
} else if (lower == "sigmoid" || lower == "vsigmoid") {
return kVSigmoid;
} else if (lower == "tanh" || lower == "vtanh") {
return kVTanh;
}
PADDLE_THROW("Not support type: %s, or forget to add this case", act);
return kNone;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, KernelTuples>();
}
const char* to_string(KernelType kt);
KernelType to_kerneltype(const std::string& act);
inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "],act_cell["
<< to_string(attr.act_cell) << "],use_peephole["
<< (attr.use_peephole ? "True" : "False") << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "]";
return os;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
namespace jit {
typedef enum {
kNone = 0,
kVMul = 1,
kVAdd = 2,
kVAddRelu,
kVSub,
kVScal,
kVAddBias,
kVRelu,
kVIdentity,
kVExp,
kVSigmoid,
kVTanh,
kLSTMCtHt,
kLSTMC1H1,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
kCRFDecoding,
kLayerNorm,
kNCHW16CMulNC,
} KernelType;
template <typename T>
struct XYZNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
};
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};
template <typename T>
struct XYNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int);
};
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr}; // W_ic, W_fc, W_oc
void* checked{nullptr}; // size: 2 * d
} lstm_t;
typedef struct {
void* gates; // gates: {x_update, x_reset; x_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
KernelType act_gate, act_cand;
rnn_attr_s() = default;
explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
KernelType act_cell;
lstm_attr_s() = default;
explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
template <typename T>
struct LSTMTuples {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
template <typename T>
struct GRUTuples {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};
template <typename T>
struct LayerNormTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
const float, int);
};
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
};
// Just for adding to kernel pool without template
class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename KernelTuples>
class KernelMore : public Kernel {
public:
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
protected:
Func func{nullptr};
};
template <typename KernelTuples>
class ReferKernel : public KernelMore<KernelTuples> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
return d;
}
constexpr int act_type_shift = 3; // suppot 2^3 act types
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift);
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
struct KernelKey {
struct Hash {
size_t operator()(const KernelKey& key) const {
int place = key.place_.which(); // less than 2^8
int type = static_cast<int>(key.type_) << 8; // less than 2^(32-8)
std::hash<int> hasher;
return hasher(place + type);
}
};
KernelType type_;
platform::Place place_;
KernelKey(KernelType type, platform::Place place)
: type_(type), place_(place) {}
size_t hash_key() const { return Hash()(*this); }
bool operator==(const KernelKey& o) const {
return platform::places_are_same_class(place_, o.place_) &&
type_ == o.type_;
}
bool operator!=(const KernelKey& o) const { return !(*this == o); }
};
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace jit {
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;
}
KernelPool& KernelPool::Instance() {
static KernelPool g_kernel_pool;
return g_kernel_pool;
}
ReferKernelPool& ReferKernelPool::Instance() {
static ReferKernelPool g_refer_kernel_pool;
return g_refer_kernel_pool;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
public:
JitCodePool() = default;
static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes;
return g_jit_codes;
}
const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value));
}
private:
JitCodeMap codes_;
DISABLE_COPY_AND_ASSIGN(JitCodePool);
};
class JitCodeCreatorPool {
typedef std::unique_ptr<const GenCreator> GenCreatorPtr;
typedef std::unordered_map<KernelKey, std::vector<GenCreatorPtr>,
KernelKey::Hash>
GenCreatorPtrMap;
public:
JitCodeCreatorPool() = default;
static JitCodeCreatorPool& Instance();
GenCreatorPtrMap& AllCreators() { return creators_; }
void Insert(const KernelKey& key, GenCreatorPtr value) {
if (creators_.find(key) == creators_.end()) {
creators_.emplace(key, std::vector<GenCreatorPtr>());
}
creators_.at(key).emplace_back(std::move(value));
}
private:
GenCreatorPtrMap creators_;
DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
};
typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap;
class KernelPool {
public:
static KernelPool& Instance();
KernelPool() = default;
KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>());
}
pool_.at(key).emplace_back(std::move(value));
}
private:
KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(KernelPool);
};
// Every kernel should have refer code and it should be used in unit tests,
// so refer kernels should have it's independent kernel pool
class ReferKernelPool {
public:
static ReferKernelPool& Instance();
ReferKernelPool() = default;
KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>());
}
pool_.at(key).emplace_back(std::move(value));
}
private:
KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
namespace paddle {
namespace operators {
namespace jit {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
} // namespace jit
} // namespace operators
} // namespace paddle
function(USE_JITKERNEL_MORE TARGET TYPE)
file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n")
endfunction()
if(WITH_MKLML)
add_subdirectory(mkl)
endif()
if(WITH_AVX)
add_subdirectory(intrinsic)
endif()
# mix should be last
add_subdirectory(mix)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} PARENT_SCOPE)
file(GLOB jit_kernel_cc_intrinsic RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kCRFDecoding, intrinsic)
USE_JITKERNEL_MORE(kLayerNorm, intrinsic)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
// Note: intrinsic code is not runtime build.
// For example, if you build code on AVX, and run on AVX512 it can only use AVX
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num) {
#ifdef __AVX512F__
const int step_size = ZMM_FLOAT_BLOCK;
#else
const int step_size = YMM_FLOAT_BLOCK;
#endif
const int end = tag_num / step_size;
const int rest = tag_num % step_size;
/* Setup the alpha initial value.*/
int i_offset = 0;
int last_offset = rest - step_size;
for (int i = 0; i <= end; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps(w + i_offset);
x_content = _mm512_loadu_ps(x + i_offset);
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
_mm512_storeu_ps(alpha_value + i_offset, alpha_content);
#else
// AVX or AVX2
// weights, input and alpha values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps(w + i_offset);
x_content = _mm256_loadu_ps(x + i_offset);
alpha_content = _mm256_add_ps(w_content, x_content);
_mm256_storeu_ps(alpha + i_offset, alpha_content);
#endif
i_offset += step_size;
if (i == end - 1) {
if (rest > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
// Use the column-major strategy to get the location of maximum score.
int seq_offset = 0;
constexpr int state_trans_base_idx = 2;
for (int k = 1; k < seq_len; ++k) {
int j_offset = 0;
for (int j = 0; j <= end; ++j) {
/* Initialize the variables of maximum score and location.*/
#ifdef __AVX512F__
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max());
__m512i max_j = _mm512_setzero_si512();
#else
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
/* Calculate the offset of transition_weights.*/
int trans_offset = state_trans_base_idx * tag_num + j_offset;
for (int i = 0; i < tag_num; ++i) {
/* Initalize the content of alpha variable with related offset.*/
#ifdef __AVX512F__
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i));
/* Obtain the content of weights from un-aligned address.*/
__m512 w_content = _mm512_loadu_ps(w + trans_offset);
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
/* AVX512 instructions.*/
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
/* Update the max_score value.*/
max_score = _mm512_max_ps(max_score, score_v);
#else
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i);
/* Obtain the content of weights from un-aligned address.*/
__m256 w_content = _mm256_loadu_ps(w + trans_offset);
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
/* According to the mask value, update the index of the max_score.*/
#ifdef __AVX2__
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 0); // NOLINT
__m128i hi_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 1); // NOLINT
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
/* Update the max_score value.*/
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
/* Update the alpha and track values. */
#ifdef __AVX512F__
__m512 x_content =
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset);
max_score = _mm512_add_ps(max_score, x_content);
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset +
this->num_ + j_offset),
max_j);
#else
__m256 x_content = _mm256_loadu_ps(x + seq_offset + tag_num + j_offset);
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track + seq_offset + tag_num + j_offset),
max_j);
#endif
/* Calculate the offset of next step*/
j_offset += step_size;
if (j == end - 1) {
if (rest > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
}
bool CRFDecodingKernel::UseMe(const int& d) const {
#ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK;
#else
constexpr int block = YMM_FLOAT_BLOCK;
#endif
return platform::MayIUse(platform::avx) && d >= block;
}
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(
const typename CRFDecodingTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
此差异已折叠。
此差异已折叠。
file(GLOB jit_kernel_mix_cc RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE)
USE_JITKERNEL_MORE(kVSigmoid, mix)
USE_JITKERNEL_MORE(kVTanh, mix)
USE_JITKERNEL_MORE(kLSTMCtHt, mix)
USE_JITKERNEL_MORE(kLSTMC1H1, mix)
USE_JITKERNEL_MORE(kGRUH1, mix)
USE_JITKERNEL_MORE(kGRUHtPart1, mix)
USE_JITKERNEL_MORE(kGRUHtPart2, mix)
此差异已折叠。
此差异已折叠。
cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -73,12 +73,3 @@ if(WITH_GPU) ...@@ -73,12 +73,3 @@ if(WITH_GPU)
endif() endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak)
endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册