Commit 46894436 authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_embedding_to_phi

@@ -20,6 +20,7 @@ build/
 build_doc/
 *.user
 *.tmp
+*.pyc
 .vscode
 .idea
......
@@ -330,6 +330,7 @@ if(WITH_BRPC_RDMA)
   endif()
 endif()
 if(WITH_GPU)
   include(cuda)
   # lite subgraph compilation depends on CUDNN_ROOT,
......
@@ -99,7 +99,7 @@ endfunction()
 function(mlir_add_rewriter td_base)
   set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
-  mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
+  mlir_tablegen(${td_base}.cpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
   add_public_tablegen_target(${td_base}_IncGen)
   add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
 endfunction()
......
@@ -116,19 +116,19 @@ function(find_fluid_modules TARGET_NAME)
   endif()
 endfunction(find_fluid_modules)
 
-set_property(GLOBAL PROPERTY PTEN_MODULES "")
-# find all pten modules is used for paddle static library
+set_property(GLOBAL PROPERTY PHI_MODULES "")
+# find all phi modules is used for paddle static library
 # for building inference libs
-function(find_pten_modules TARGET_NAME)
+function(find_phi_modules TARGET_NAME)
   get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
   string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path})
   string(FIND "${__target_path}" "phi" pos)
   if(pos GREATER 1)
-    get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES)
-    set(pten_modules ${pten_modules} ${TARGET_NAME})
-    set_property(GLOBAL PROPERTY PTEN_MODULES "${pten_modules}")
+    get_property(phi_modules GLOBAL PROPERTY PHI_MODULES)
+    set(phi_modules ${phi_modules} ${TARGET_NAME})
+    set_property(GLOBAL PROPERTY PHI_MODULES "${phi_modules}")
   endif()
-endfunction(find_pten_modules)
+endfunction(find_phi_modules)
 
 function(common_link TARGET_NAME)
   if (WITH_PROFILER)
@@ -324,7 +324,7 @@ function(cc_library TARGET_NAME)
     else()
       add_library(${TARGET_NAME} STATIC ${cc_library_SRCS})
       find_fluid_modules(${TARGET_NAME})
-      find_pten_modules(${TARGET_NAME})
+      find_phi_modules(${TARGET_NAME})
     endif()
     if(cc_library_DEPS)
       # Don't need link libwarpctc.so
@@ -497,7 +497,7 @@ function(nv_library TARGET_NAME)
     else()
       add_library(${TARGET_NAME} STATIC ${nv_library_SRCS})
       find_fluid_modules(${TARGET_NAME})
-      find_pten_modules(${TARGET_NAME})
+      find_phi_modules(${TARGET_NAME})
     endif()
     if (nv_library_DEPS)
       add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
@@ -588,7 +588,7 @@ function(hip_library TARGET_NAME)
     else()
       hip_add_library(${TARGET_NAME} STATIC ${hip_library_SRCS})
       find_fluid_modules(${TARGET_NAME})
-      find_pten_modules(${TARGET_NAME})
+      find_phi_modules(${TARGET_NAME})
     endif()
     if (hip_library_DEPS)
       add_dependencies(${TARGET_NAME} ${hip_library_DEPS})
......
@@ -224,7 +224,7 @@ copy(inference_lib_dist
         DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/crypto/)
 include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
 
-# copy api headers for pten & custom op
+# copy api headers for phi & custom op
 copy(inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/api/ext/*.h
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api/ext/)
@@ -244,11 +244,11 @@ copy(inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/extension.h
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
 
-# the header file of pten is copied to the experimental directory,
-# the include path of pten needs to be changed to adapt to inference api path
+# the header file of phi is copied to the experimental directory,
+# the include path of phi needs to be changed to adapt to inference api path
 add_custom_command(TARGET inference_lib_dist POST_BUILD
-  COMMAND ${CMAKE_COMMAND} -P "${PADDLE_SOURCE_DIR}/cmake/pten_header.cmake"
-  COMMENT "Change pten header include path to adapt to inference api path")
+  COMMAND ${CMAKE_COMMAND} -P "${PADDLE_SOURCE_DIR}/cmake/phi_header.cmake"
+  COMMENT "Change phi header include path to adapt to inference api path")
 
 # CAPI inference library for only inference
 set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
......
@@ -73,6 +73,12 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
       list(APPEND cu_srcs ${TARGET}.cu)
     endif()
+    # rename in KP: .kps -> .cu
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps)
+      file(COPY ${TARGET}.kps DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+      file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.kps ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+      list(APPEND cu_srcs ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+    endif()
     if (WITH_NV_JETSON)
       list(REMOVE_ITEM cu_srcs "decode_jpeg_op.cu")
     endif()
@@ -96,6 +102,12 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
       list(APPEND hip_srcs ${TARGET}.cu)
     endif()
+    # rename in KP: .kps -> .cu
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps)
+      file(COPY ${TARGET}.kps DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+      file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.kps ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+      list(APPEND hip_srcs ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+    endif()
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
       set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
           ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
......
@@ -51,7 +51,7 @@ function(generate_unify_header DIR_NAME)
   endforeach()
   # append header into extension.h
   string(REPLACE "${PADDLE_SOURCE_DIR}\/" "" header_file "${header_file}")
-  file(APPEND ${pten_extension_header_file} "#include \"${header_file}\"\n")
+  file(APPEND ${phi_extension_header_file} "#include \"${header_file}\"\n")
 endfunction()
 
 # callers of kernel_declare need to make sure the input target exists
@@ -81,6 +81,8 @@ function(kernel_declare TARGET_LIST)
     file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n")
   elseif (${kernel_path} MATCHES "./xpu\/")
     file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
+  elseif (${kernel_path} MATCHES "./gpudnn\/")
+    file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n")
   else ()
     # deal with device independent kernel, now we use CPU temporarily
     file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
@@ -94,6 +96,7 @@ function(kernel_library TARGET)
   set(cpu_srcs)
   set(gpu_srcs)
   set(xpu_srcs)
+  set(gpudnn_srcs)
   set(selected_rows_srcs)
   # parse and save the deps kernel targets
   set(all_srcs)
@@ -101,6 +104,8 @@ function(kernel_library TARGET)
   set(oneValueArgs SUB_DIR)
   set(multiValueArgs SRCS DEPS)
 
+  set(target_build_flag 1)
+
   cmake_parse_arguments(kernel_library "${options}" "${oneValueArgs}"
                         "${multiValueArgs}" ${ARGN})
@@ -123,6 +128,9 @@ function(kernel_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc)
       list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc)
     endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu)
+      list(APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu)
+    endif()
   endif()
   if (WITH_XPU)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc)
@@ -141,6 +149,7 @@ function(kernel_library TARGET)
   list(APPEND all_srcs ${cpu_srcs})
   list(APPEND all_srcs ${gpu_srcs})
   list(APPEND all_srcs ${xpu_srcs})
+  list(APPEND all_srcs ${gpudnn_srcs})
   foreach(src ${all_srcs})
     file(READ ${src} target_content)
     string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content})
@@ -166,21 +175,22 @@ function(kernel_library TARGET)
   list(LENGTH cpu_srcs cpu_srcs_len)
   list(LENGTH gpu_srcs gpu_srcs_len)
   list(LENGTH xpu_srcs xpu_srcs_len)
+  list(LENGTH gpudnn_srcs gpudnn_srcs_len)
   list(LENGTH selected_rows_srcs selected_rows_srcs_len)
 
   # Build Target according to different src organization
   if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
-      ${xpu_srcs_len} GREATER 0) AND (${common_srcs_len} GREATER 0 OR
-      ${selected_rows_srcs_len} GREATER 0))
+      ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) AND
+      (${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0))
     # If the common_srcs/selected_rows_srcs depend on specific device srcs, build target using this rule.
     if (WITH_GPU)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
         nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
       endif()
     elseif (WITH_ROCM)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
         hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
       endif()
     else()
@@ -190,14 +200,14 @@ function(kernel_library TARGET)
       endif()
     endif()
   # If there are only specific device srcs, build target using this rule.
-  elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
+  elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
     if (WITH_GPU)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
      endif()
    elseif (WITH_ROCM)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
      endif()
    else()
      if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
@@ -234,35 +244,40 @@ function(kernel_library TARGET)
       cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
     endif()
   else()
-    message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
+    set(target_build_flag 0)
   endif()
 
-  if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
-      ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
-      ${selected_rows_srcs_len} GREATER 0)
-    # append target into PTEN_KERNELS property
-    get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
-    set(pten_kernels ${pten_kernels} ${TARGET})
-    set_property(GLOBAL PROPERTY PTEN_KERNELS ${pten_kernels})
-  endif()
+  if (${target_build_flag} EQUAL 1)
+    if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
+        ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
+        ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)
+      # append target into PHI_KERNELS property
+      get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
+      set(phi_kernels ${phi_kernels} ${TARGET})
+      set_property(GLOBAL PROPERTY PHI_KERNELS ${phi_kernels})
+    endif()
 
   # parse kernel name and auto generate kernel declaration
   # here, we don't need to check WITH_XXX, because if not WITH_XXX, the
   # xxx_srcs_len will be equal to 0
   if (${common_srcs_len} GREATER 0)
     kernel_declare(${common_srcs})
   endif()
   if (${cpu_srcs_len} GREATER 0)
     kernel_declare(${cpu_srcs})
   endif()
   if (${gpu_srcs_len} GREATER 0)
     kernel_declare(${gpu_srcs})
   endif()
   if (${xpu_srcs_len} GREATER 0)
     kernel_declare(${xpu_srcs})
   endif()
+  if (${gpudnn_srcs_len} GREATER 0)
+    kernel_declare(${gpudnn_srcs})
+  endif()
   if (${selected_rows_srcs_len} GREATER 0)
     kernel_declare(${selected_rows_srcs})
   endif()
+  endif()
 endfunction()
......
@@ -14,8 +14,8 @@
 set(PADDLE_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_install_dir")
 
-function(pten_header_path_compat TARGET_PATH)
-  message(STATUS "pten header path compat processing: ${TARGET_PATH}")
+function(phi_header_path_compat TARGET_PATH)
+  message(STATUS "phi header path compat processing: ${TARGET_PATH}")
   string(FIND ${TARGET_PATH} "experimental" pos)
   if (pos GREATER 1)
     file(GLOB HEADERS "${TARGET_PATH}/*" "*.h")
@@ -25,17 +25,17 @@ if (pos GREATER 1)
       string(REPLACE "paddle/phi/" "paddle/include/experimental/phi/" HEADER_CONTENT "${HEADER_CONTENT}")
       string(REPLACE "paddle/utils/" "paddle/include/experimental/utils/" HEADER_CONTENT "${HEADER_CONTENT}")
       file(WRITE ${header} "${HEADER_CONTENT}")
-      message(STATUS "pten header path compat processing complete: ${header}")
+      message(STATUS "phi header path compat processing complete: ${header}")
     endif()
   endforeach()
 endif()
 endfunction()
 
-pten_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental)
-pten_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api)
-pten_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api/ext)
-pten_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api/include)
-pten_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/common)
+phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental)
+phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api)
+phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api/ext)
+phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/api/include)
+phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/common)
 
 # In order to be compatible with the original behavior, the header file name needs to be changed
 file(RENAME ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/extension.h
......
-cc_library(processgroup SRCS ProcessGroup.cc DEPS pten pten_api eager_api)
+cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
 if(WITH_NCCL)
-  cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context pten pten_api eager_api)
+  cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
 endif()
@@ -238,7 +238,7 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
   void* tensor_data = tensor->mutable_data(
       place,
-      framework::TransToPtenDataType(VarMessageToVarType(msg.data_type())));
+      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
 
   // IO Buffer
   if (platform::is_cpu_place(place)) {
@@ -281,7 +281,7 @@ void DeserializeSelectedRows(
   tensor->Resize(phi::make_ddim(vec_dim));
   void* tensor_data = tensor->mutable_data(
       place,
-      framework::TransToPtenDataType(VarMessageToVarType(msg.data_type())));
+      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
   // IO Buffer
   if (platform::is_cpu_place(place)) {
     unsigned long data_len;  // NOLINT
......
-set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node)
+set(eager_deps phi phi_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node)
 set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
 set(generated_deps dygraph_function dygraph_node)
@@ -10,11 +10,11 @@ endif()
 add_subdirectory(api)
 add_subdirectory(accumulation)
 
-cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
+cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi phi_api)
 cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
-cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
-cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
+cc_library(autograd_meta SRCS autograd_meta.cc DEPS phi phi_api)
+cc_library(utils SRCS utils.cc DEPS phi phi_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
 cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)
 
 add_subdirectory(tests)
-cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulator pten pten_api grad_node_info)
+cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulator phi phi_api grad_node_info)
@@ -76,13 +76,13 @@ operator()(
 }
 
 void GradNodeAccumulation::RegisterReduceHook(
-    const std::function<void(void)>& hook) {
-  reduce_hooks_.emplace_back(hook);
+    std::shared_ptr<TensorVoidHook>&& hook) {
+  reduce_hooks_.emplace_back(std::move(hook));
 }
 
 void GradNodeAccumulation::ApplyReduceHooks() {
   for (auto& hook : reduce_hooks_) {
-    hook();
+    (*hook)();
   }
 }
 }  // namespace egr
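For callers the migration is mechanical: wrap the old std::function<void(void)> in a CppTensorVoidHook (from paddle/fluid/eager/hooks.h, added by this commit) and move it in, so the node takes ownership of the wrapper. A minimal sketch, assuming a constructed GradNodeAccumulation node:

    // Sketch only: the node now stores the hook as shared_ptr<TensorVoidHook>.
    std::function<void(void)> reduce_hook = []() {
      VLOG(6) << "Running Reduce Hook";
    };
    node->RegisterReduceHook(
        std::make_shared<egr::CppTensorVoidHook>(std::move(reduce_hook)));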
@@ -16,6 +16,7 @@
 #include "paddle/fluid/eager/autograd_meta.h"
 #include "paddle/fluid/eager/grad_node_info.h"
+#include "paddle/fluid/eager/hooks.h"
 
 namespace egr {
@@ -39,7 +40,7 @@ class GradNodeAccumulation : public GradNodeBase {
   /**
    * Register ReduceHook
    * **/
-  void RegisterReduceHook(const std::function<void(void)>& hook);
+  void RegisterReduceHook(std::shared_ptr<TensorVoidHook>&& hook);
 
   /**
    * Apply ReduceHook here
@@ -54,7 +55,7 @@ class GradNodeAccumulation : public GradNodeBase {
       const paddle::experimental::Tensor&)>
       retain_grad_hook_;
-  std::vector<std::function<void(void)>> reduce_hooks_;
+  std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
 };
 }  // namespace egr
-cc_library(scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info)
+cc_library(scale_node SRCS scale_node.cc DEPS global_utils phi phi_api grad_node_info)
 if(NOT ON_INFER)
   cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps})
......
@@ -33,36 +33,36 @@ static void ScaleDeviceDispatch(const phi::DenseTensor& dense_tensor,
                                 phi::DenseTensor* dense_out) {
   switch (dense_tensor.dtype()) {
     case phi::DataType::FLOAT64: {
-      phi::ScaleKernel<double, typename paddle::framework::ConvertToPtenContext<
+      phi::ScaleKernel<double, typename paddle::framework::ConvertToPhiContext<
          DeviceContext>::TYPE>(
-          static_cast<const typename paddle::framework::ConvertToPtenContext<
+          static_cast<const typename paddle::framework::ConvertToPhiContext<
              DeviceContext>::TYPE&>(dev_ctx),
          dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
          bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
      break;
    }
    case phi::DataType::FLOAT32: {
-      phi::ScaleKernel<float, typename paddle::framework::ConvertToPtenContext<
+      phi::ScaleKernel<float, typename paddle::framework::ConvertToPhiContext<
          DeviceContext>::TYPE>(
-          static_cast<const typename paddle::framework::ConvertToPtenContext<
+          static_cast<const typename paddle::framework::ConvertToPhiContext<
              DeviceContext>::TYPE&>(dev_ctx),
          dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
          bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
      break;
    }
    case phi::DataType::INT64: {
-      phi::ScaleKernel<int64_t, typename paddle::framework::
-                                    ConvertToPtenContext<DeviceContext>::TYPE>(
+      phi::ScaleKernel<int64_t, typename paddle::framework::ConvertToPhiContext<
+          DeviceContext>::TYPE>(
-          static_cast<const typename paddle::framework::ConvertToPtenContext<
+          static_cast<const typename paddle::framework::ConvertToPhiContext<
              DeviceContext>::TYPE&>(dev_ctx),
          dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
          bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
      break;
    }
    case phi::DataType::INT32: {
-      phi::ScaleKernel<int32_t, typename paddle::framework::
-                                    ConvertToPtenContext<DeviceContext>::TYPE>(
+      phi::ScaleKernel<int32_t, typename paddle::framework::ConvertToPhiContext<
+          DeviceContext>::TYPE>(
-          static_cast<const typename paddle::framework::ConvertToPtenContext<
+          static_cast<const typename paddle::framework::ConvertToPhiContext<
              DeviceContext>::TYPE&>(dev_ctx),
          dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
          bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
......
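The repeated pattern in the hunk above reduces to one idea: the ConvertToPhiContext trait maps a fluid device context type to the phi context type the kernel expects, so a single static_cast suffices at each call site. A condensed sketch under that assumption (illustrative helper, not part of the commit; the ScaleKernel argument list is taken from the calls above):

    // Hypothetical helper showing the post-rename dispatch for one dtype.
    template <typename DeviceContext>
    void CallPhiScale(const DeviceContext& dev_ctx,
                      const phi::DenseTensor& in, float scale, float bias,
                      bool bias_after_scale, phi::DenseTensor* out) {
      using PhiContext =
          typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE;
      phi::ScaleKernel<float, PhiContext>(
          static_cast<const PhiContext&>(dev_ctx), in, scale, bias,
          bias_after_scale, out);
    }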
-cc_library(eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node)
+cc_library(eager_scale SRCS scale.cc DEPS phi_api phi autograd_meta scale_node)
 if(NOT ON_INFER)
   cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps})
......
-cc_library(tensor_utils SRCS tensor_utils.cc DEPS pten pten_api autograd_meta grad_node_info accumulation_node)
-cc_library(hook_utils SRCS hook_utils.cc DEPS pten tensor_utils autograd_meta grad_node_info utils accumulation_node)
+cc_library(tensor_utils SRCS tensor_utils.cc DEPS phi phi_api autograd_meta grad_node_info accumulation_node)
+cc_library(hook_utils SRCS hook_utils.cc DEPS phi tensor_utils autograd_meta grad_node_info utils accumulation_node)
 cc_library(global_utils SRCS global_utils.cc DEPS place tracer)
@@ -22,19 +22,19 @@
 namespace egr {
 namespace egr_utils_api {
 
-void RegisterGradientHookForTensor(
+int64_t RegisterGradientHookForTensor(
     const paddle::experimental::Tensor& tensor,
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>& hook) {
+    std::shared_ptr<egr::TensorHook>&& hook) {
   // Find grad_node and out_rank from AutogradMeta
   std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
   auto rank_info = EagerUtils::unsafe_autograd_meta(tensor)->OutRankInfo();
 
-  grad_node->RegisterGradientHook(rank_info.first, rank_info.second, hook);
+  return grad_node->RegisterGradientHook(rank_info.first, rank_info.second,
+                                         std::move(hook));
 }
 
 void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
-                                 const std::function<void(void)>& hook) {
+                                 std::shared_ptr<egr::TensorVoidHook>&& hook) {
   if (IsLeafTensor(tensor)) {
     VLOG(6) << "Register ReduceHook for leaf tensor";
     std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
@@ -45,7 +45,7 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
             "with type: GradNodeAccumulation"));
     auto accumulation_grad_node =
         std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node);
-    accumulation_grad_node->RegisterReduceHook(hook);
+    accumulation_grad_node->RegisterReduceHook(std::move(hook));
   } else {
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Only can register reduce hook for leaf Tensor."));
@@ -65,28 +65,27 @@ static void RetainGradForRegularNode(
       meta->WeakGrad();
 
   // Define Hook
-  std::function<paddle::experimental::Tensor(
-      const paddle::experimental::Tensor&)>
-      hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
-        if (!weak_grad_tensor.expired()) {
-          auto grad_tensor = weak_grad_tensor.lock();
-          if (t.defined()) {
-            VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
-            // Simply Copy impl() to grad_tensor
-            grad_tensor->set_impl(t.impl());
-            return *grad_tensor.get();
-          } else {
-            VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
-            return paddle::experimental::Tensor();
-          }
-        } else {
-          VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
-          return paddle::experimental::Tensor();
-        }
-      };
+  auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
+    if (!weak_grad_tensor.expired()) {
+      auto grad_tensor = weak_grad_tensor.lock();
+      if (t.defined()) {
+        VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
+        // Simply Copy impl() to grad_tensor
+        grad_tensor->set_impl(t.impl());
+        return *grad_tensor.get();
+      } else {
+        VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
+        return paddle::experimental::Tensor();
+      }
+    } else {
+      VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
+      return paddle::experimental::Tensor();
+    }
+  };
 
   // Append to GradientHooks
-  RegisterGradientHookForTensor(tensor, hook);
+  RegisterGradientHookForTensor(tensor,
+                                std::make_shared<egr::CppTensorHook>(hook));
 }
 
 void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
......
@@ -16,17 +16,17 @@
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/eager/grad_node_info.h"
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/phi/api/all.h"
 
 namespace egr {
 namespace egr_utils_api {
 
-void RegisterGradientHookForTensor(
+int64_t RegisterGradientHookForTensor(
     const paddle::experimental::Tensor& tensor,
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>& hook);
+    std::shared_ptr<egr::TensorHook>&& hook);
 
 void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
-                                 const std::function<void(void)>& hook);
+                                 std::shared_ptr<egr::TensorVoidHook>&& hook);
 
 void RetainGradForTensor(const paddle::experimental::Tensor& tensor);
 
 }  // namespace egr_utils_api
......
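Call sites now hand ownership of the hook wrapper to the registry and get back an id. A minimal usage sketch under assumptions (the tensor setup is not shown in this hunk, and the helper name is hypothetical):

    // CppTensorHook adapts a plain callable to the TensorHook interface.
    int64_t AttachIdentityHook(const paddle::experimental::Tensor& t) {
      auto hook = std::make_shared<egr::CppTensorHook>(
          [](const paddle::experimental::Tensor& grad) { return grad; });
      // The returned id can later be handed to GradNodeBase::RemoveGradientHook.
      return egr::egr_utils_api::RegisterGradientHookForTensor(t,
                                                               std::move(hook));
    }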
@@ -22,7 +22,7 @@
 #include "paddle/phi/api/all.h"
 
 #include "paddle/fluid/framework/data_layout.h"
-#include "paddle/fluid/framework/pten_utils.h"
+#include "paddle/fluid/framework/phi_utils.h"
 #include "paddle/fluid/framework/variable.h"
 
 namespace egr {
@@ -43,7 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
     bool is_leaf) {
   paddle::experimental::Tensor out = paddle::experimental::full(
       phi::vectorize(ddim), paddle::experimental::Scalar(value), dtype,
-      phi::TransToPtenBackend(place));
+      phi::TransToPhiBackend(place));
 
   auto meta = EagerUtils::autograd_meta(&out);
   if (is_leaf) {
......
@@ -27,7 +27,7 @@
 #include "paddle/fluid/pybind/pybind.h"
 #include "paddle/fluid/string/string_helper.h"
 
-// pten
+// phi
 #include "paddle/phi/kernels/declarations.h"
 
 #define NUM_CREATED_DUP_INPUTS 4
@@ -544,7 +544,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
   // since only OperatorWithKernel can run in dygraph mode.
   auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
   if (!all_kernels.count(op_type) &&
-      !phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
+      !phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
     return false;
   }
@@ -2040,12 +2040,13 @@ static std::string GenerateGradNodeCCContents(
   const char* BWD_RETURN_TEMPLATE =
       "  std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads = "
-      "egr::GradNodeBase::ApplyGradientHooks(grads);\n"
+      "GradNode%s::ApplyGradientHooks(grads);\n"
       "  std::vector<std::vector<paddle::experimental::Tensor>> outputs(%d);\n"
       "  %s\n"
       "  return outputs;\n";
-  generated_grad_function_body = paddle::string::Sprintf(
-      BWD_RETURN_TEMPLATE, in_vars.size(), generated_grad_function_body);
+  generated_grad_function_body =
+      paddle::string::Sprintf(BWD_RETURN_TEMPLATE, fwd_op_type, in_vars.size(),
+                              generated_grad_function_body);
 
   // [Generation] Get Full Grad Function
   const char* GRAD_FUNCTION_TEMPLATE =
......
@@ -143,7 +143,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {
         fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
         GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)
 
-    python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}},\n"
+    python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n"
 
     return python_c_function_str, python_c_function_reg_str
@@ -197,7 +197,7 @@ static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
 """
 
     core_ops_infos_registry = """
-    {\"get_final_state_core_ops_args_info\",
+    ,{\"get_final_state_core_ops_args_info\",
     (PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
     \"C++ interface function for eager_get_final_state_core_ops_args_info.\"},
     {\"get_final_state_core_ops_args_type_info\",
......
@@ -14,10 +14,10 @@
 #pragma once
 
 // framework deps
-#include "paddle/fluid/framework/pten_utils.h"
+#include "paddle/fluid/framework/phi_utils.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable.h"
-// pten deps
+// Phi deps
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/api/lib/api_declare.h"
 #include "paddle/phi/api/lib/utils/tensor_utils.h"
@@ -31,7 +31,7 @@
  * provide variable in
  * paddle::framework::ExecutionContext to support it. We should remove this as
  * soon as we finish our latest
- * Pten Lib, and use paddle::experimental::Tensor instead.
+ * Phi Lib, and use paddle::experimental::Tensor instead.
  *
  * Note: Keep this class as clean as possible.
 * This class should only support method declared in
......
@@ -210,22 +210,22 @@ const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
   return adj_edges_;
 }
 
-void GradNodeBase::RegisterGradientHook(
-    size_t slot_id, size_t rank,
-    const std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>& hook) {
-  gradient_hooks_.emplace_back(std::make_tuple(slot_id, rank, hook));
+int64_t GradNodeBase::RegisterGradientHook(
+    size_t slot_id, size_t rank, std::shared_ptr<egr::TensorHook>&& hook) {
+  gradient_hooks_.emplace(next_hook_id_,
+                          std::make_tuple(slot_id, rank, std::move(hook)));
+  return next_hook_id_++;
 }
 
 std::vector<std::vector<paddle::experimental::Tensor>>
 GradNodeBase::ApplyGradientHooks(
     const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
   std::vector<std::vector<paddle::experimental::Tensor>> outs(tensors.size());
-  for (auto& tuple : gradient_hooks_) {
-    size_t slot_id = std::get<0>(tuple);
-    size_t rank = std::get<1>(tuple);
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>& hook = std::get<2>(tuple);
+  for (auto& hook_pair : gradient_hooks_) {
+    size_t slot_id = std::get<0>(hook_pair.second);
+    size_t rank = std::get<1>(hook_pair.second);
+    auto hook = std::get<2>(hook_pair.second);
 
     PADDLE_ENFORCE(slot_id < tensors.size(),
                    paddle::platform::errors::Fatal(
@@ -242,12 +242,11 @@ GradNodeBase::ApplyGradientHooks(
       slot_out.resize(tensors[slot_id].size());
       paddle::experimental::Tensor& out = slot_out[rank];
       if (!out.defined() || !out.initialized()) {
-        VLOG(8) << "Run Hook for tensor: " << tensors[slot_id][rank].name();
-        out = hook(tensors[slot_id][rank]);
+        out = (*hook)(tensors[slot_id][rank]);
       } else {
         // If more than one hook is registered, the input to the next hook func
         // should be the output of the previous hook
-        out = hook(out);
+        out = (*hook)(out);
       }
     }
......
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/fluid/eager/eager_tensor.h"
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/phi/api/all.h"
 
 namespace egr {
@@ -135,14 +136,24 @@ class GradNodeBase {
   /**
    * Register GradientHook
   * **/
-  void RegisterGradientHook(size_t slot_id, size_t rank,
-                            const std::function<paddle::experimental::Tensor(
-                                const paddle::experimental::Tensor&)>& hook);
+  int64_t RegisterGradientHook(size_t slot_id, size_t rank,
+                               std::shared_ptr<egr::TensorHook>&& hook);
+
+  /**
+   * Remove GradientHook
+   * **/
+  bool RemoveGradientHook(const int64_t& hook_id) {
+    auto remove_cnt = gradient_hooks_.erase(hook_id);
+    if (remove_cnt == 0) {
+      return false;
+    }
+    return true;
+  }
 
   /**
    * Apply GradientHook
   * **/
-  inline bool GradientHooksRegistered() { return gradient_hooks_.size() != 0; }
+  inline bool GradientHooksRegistered() { return !gradient_hooks_.empty(); }
 
   std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks(
       const std::vector<std::vector<paddle::experimental::Tensor>>& tensors);
@@ -166,12 +177,14 @@ class GradNodeBase {
   // Gradient Hooks
   // Customer may register a list of hooks which will be called in order during
   // backward
-  // Each entry consists one pair of <out_rank, std::function>
-  std::vector<std::tuple<
-      /* slot id */ size_t, /* rank */ size_t,
-      /* hook */ std::function<paddle::experimental::Tensor(
-          const paddle::experimental::Tensor&)>>>
-      gradient_hooks_;
+  // Each entry consists one pair of
+  // <hook_id, <out_rank, std::shared_ptr<TensorHook>>>
+  std::map<int64_t, std::tuple<
+      /* slot id */ size_t, /* rank */ size_t,
+      /* hook */ std::shared_ptr<TensorHook>>>
+      gradient_hooks_;
+
+  int64_t next_hook_id_{0};
 };
 
 class Edge {
......
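Keying gradient_hooks_ by a monotonically increasing id is what makes removal cheap: std::map::erase(hook_id) detaches one hook without disturbing the others. A hedged sketch of the resulting lifecycle (node construction omitted; the lambda is illustrative):

    int64_t id = node->RegisterGradientHook(
        /* slot_id */ 0, /* rank */ 0,
        std::make_shared<egr::CppTensorHook>(
            [](const paddle::experimental::Tensor& g) { return g; }));
    // ... later; returns false if the id was already erased.
    bool removed = node->RemoveGradientHook(id);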
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/phi/api/include/tensor.h"
namespace egr {
class TensorHook {
public:
virtual ~TensorHook() = default;
virtual paddle::experimental::Tensor operator()(
const paddle::experimental::Tensor& var) = 0;
};
class TensorVoidHook {
public:
virtual ~TensorVoidHook() = default;
virtual void operator()() = 0;
};
class CppTensorHook : public TensorHook {
public:
explicit CppTensorHook(std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>&& fn)
: fn_(std::move(fn)) {}
paddle::experimental::Tensor operator()(
const paddle::experimental::Tensor& var) override {
return fn_(var);
}
private:
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
fn_;
};
class CppTensorVoidHook : public TensorVoidHook {
public:
explicit CppTensorVoidHook(std::function<void()>&& fn) : fn_(std::move(fn)) {}
void operator()() override { return fn_(); }
private:
std::function<void()> fn_;
};
} // namespace egr
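hooks.h is the indirection the rest of this diff pivots on: a virtual TensorHook interface replaces raw std::function, so hooks are stored uniformly as shared_ptr, invoked via (*hook)(...), and can be subclassed by frontends that cannot hand over a C++ std::function. An illustrative subclass, not part of the commit:

    // Pass-through hook that only logs; a real hook would transform grad.
    class LoggingTensorHook : public egr::TensorHook {
     public:
      paddle::experimental::Tensor operator()(
          const paddle::experimental::Tensor& grad) override {
        VLOG(6) << "gradient hook saw tensor: " << grad.name();
        return grad;
      }
    };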
@@ -23,6 +23,7 @@
 #include "paddle/fluid/eager/grad_tensor_holder.h"
 #include "paddle/fluid/eager/utils.h"
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -116,7 +117,8 @@ TEST(AccumulationNode, Tensor) {
     VLOG(6) << "Running Reduce Hook";
   };
 
-  node->RegisterReduceHook(reduce_hook_1);
+  node->RegisterReduceHook(
+      std::make_shared<egr::CppTensorVoidHook>(reduce_hook_1));
 
   // operator()
   paddle::experimental::Tensor _ret = node->operator()({{et0}})[0][0];
@@ -141,7 +143,8 @@ TEST(AccumulationNode, Tensor) {
     ret_et0_ptr[0] = 100.0;  // set to 100.0
     VLOG(6) << "Running Reduce Hook";
   };
-  node->RegisterReduceHook(reduce_hook_2);
+  node->RegisterReduceHook(
+      std::make_shared<egr::CppTensorVoidHook>(reduce_hook_2));
   node->ApplyReduceHooks();
 
   // Check ApplyReduceHooks result
......
@@ -17,6 +17,7 @@
 #include "paddle/fluid/eager/autograd_meta.h"
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/eager/grad_node_info.h"
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
@@ -32,7 +33,7 @@ TEST(GradNodeInfo, GradSlotMeta) {
   CHECK_EQ(grad_slot.Size(), 2);
 }
 
-TEST(GradNodeInfo, GradNodeBase) {
+void TestGradNodeBase(bool is_remove_gradient_hook) {
   VLOG(6) << "Construct Grad Node";
   auto grad_test_node0 = std::make_shared<eager_test::GradTestNode>(
       /* val */ 5.0, /* in_num */ 2, /* out_num */ 2);
@@ -112,13 +113,25 @@ TEST(GradNodeInfo, GradNodeBase) {
     VLOG(6) << "Running Gradient Hook";
     return res;
   };
-  grad_test_node0->RegisterGradientHook(0, 0, gradient_hook);
-  // 5 + 6
+  int64_t hook_id = grad_test_node0->RegisterGradientHook(
+      0, 0, std::make_shared<egr::CppTensorHook>(gradient_hook));
+
+  if (is_remove_gradient_hook) {
+    // Remove GradientHook
+    grad_test_node0->RemoveGradientHook(hook_id);
+  }
+
+  // Check results
   auto grad_hook_res = grad_test_node0->ApplyGradientHooks(grads);
   CHECK_EQ(
       std::dynamic_pointer_cast<phi::DenseTensor>(grad_hook_res[0][0].impl())
          ->data<float>()[0],
-      11.0);
+      is_remove_gradient_hook ? 5.0 : 11.0);
+}
+
+TEST(GradNodeInfo, GradNodeBase) {
+  TestGradNodeBase(true);
+  TestGradNodeBase(false);
 }
 
 TEST(GradNodeInfo, Edge) {
......
@@ -27,6 +27,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_meta.h"
 
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/fluid/eager/tests/test_utils.h"
 
 namespace egr {
@@ -221,10 +222,6 @@ TEST(FwdBwdJoint, GradientHook) {
       phi::DataLayout::NCHW, 5.0 /*value*/, true /*is_leaf*/);
   egr_utils_api::RetainGradForTensor(tensor);
 
-  std::function<paddle::experimental::Tensor(
-      const paddle::experimental::Tensor&)>
-      hook = &hook_function;
-
   // 3. Run Forward
   // Run Forward Node 0
   float scale0 = 2.0;
@@ -232,24 +229,27 @@ TEST(FwdBwdJoint, GradientHook) {
   paddle::experimental::Tensor out0 =
       egr::scale(tensor, scale0, bias0, true /*bias_after_scale*/,
                  true /*trace_backward*/);
   egr_utils_api::RetainGradForTensor(out0);  // hook: +5
-  egr_utils_api::RegisterGradientHookForTensor(out0, hook);  // hook: +5
+  egr_utils_api::RegisterGradientHookForTensor(
+      out0, std::make_shared<egr::CppTensorHook>(hook_function));  // hook: +5
 
   // Run Forward Node 1
   float scale1 = 5.0;
   float bias1 = 10.0;
   paddle::experimental::Tensor out1 = egr::scale(
       out0, scale1, bias1, true /*bias_after_scale*/, true /*trace_backward*/);
   egr_utils_api::RetainGradForTensor(out1);  // hook: +5
-  egr_utils_api::RegisterGradientHookForTensor(out1, hook);  // hook: +5
+  egr_utils_api::RegisterGradientHookForTensor(
+      out1, std::make_shared<egr::CppTensorHook>(hook_function));  // hook: +5
 
   // Run Forward Node 2
   float scale2 = 10.0;
   float bias2 = 20.0;
   paddle::experimental::Tensor out2 = egr::scale(
       out0, scale2, bias2, true /*bias_after_scale*/, true /*trace_backward*/);
   egr_utils_api::RetainGradForTensor(out2);  // hook: +5
-  egr_utils_api::RegisterGradientHookForTensor(out2, hook);  // hook: +5
+  egr_utils_api::RegisterGradientHookForTensor(
+      out2, std::make_shared<egr::CppTensorHook>(hook_function));  // hook: +5
 
   // 4. Run Backward
   std::vector<paddle::experimental::Tensor> outs = {out1, out2};
......
@@ -28,6 +28,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_meta.h"
 
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/fluid/eager/tests/test_utils.h"
 
 namespace egr {
@@ -83,9 +84,6 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
   // Apply RetainGrad
   {
     // ScaleNode Hook: +3
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>
-        hook = &hook_function;
 
     auto auto_grad_meta = std::make_shared<AutogradMeta>();
     auto_grad_meta->SetGradNode(
@@ -96,7 +94,8 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
         std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
             auto_grad_meta));
 
-    egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook);
+    egr_utils_api::RegisterGradientHookForTensor(
+        target_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
     egr_utils_api::RetainGradForTensor(
         target_tensor);  // result: 1.0 + 3.0 = 4.0
     egr_utils_api::RetainGradForTensor(
@@ -107,9 +106,6 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
   paddle::experimental::Tensor leaf_tensor = paddle::experimental::Tensor();
   {
     // AccumulationNode Hook: +3
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>
-        hook = &hook_function;
 
     auto auto_grad_meta = std::make_shared<AutogradMeta>();
@@ -126,7 +122,8 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
        std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
            auto_grad_meta));
 
-    egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook);
+    egr_utils_api::RegisterGradientHookForTensor(
+        leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
    egr_utils_api::RetainGradForTensor(
        leaf_tensor);  // result: 4.0*5.0 + 3.0 = 23.0
  }
@@ -161,9 +158,6 @@ TEST(RetainGrad, HookAfterRetainGrad) {
   // Apply RetainGrad
   {
     // ScaleNode Hook: +3
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>
-        hook = &hook_function;
 
     auto auto_grad_meta = std::make_shared<AutogradMeta>();
     auto_grad_meta->SetGradNode(
@@ -175,16 +169,14 @@ TEST(RetainGrad, HookAfterRetainGrad) {
            auto_grad_meta));
     egr_utils_api::RetainGradForTensor(target_tensor);  // result: 1.0
-    egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook);
+    egr_utils_api::RegisterGradientHookForTensor(
+        target_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
   }
 
   // Retain Grad for leaf tensor1
   paddle::experimental::Tensor leaf_tensor = paddle::experimental::Tensor();
   {
     // AccumulationNode Hook: +3
-    std::function<paddle::experimental::Tensor(
-        const paddle::experimental::Tensor&)>
-        hook = &hook_function;
 
     auto auto_grad_meta = std::make_shared<AutogradMeta>();
     auto acc_node_ptr =
@@ -199,7 +191,8 @@ TEST(RetainGrad, HookAfterRetainGrad) {
        std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
            auto_grad_meta));
-    egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook);
+    egr_utils_api::RegisterGradientHookForTensor(
+        leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
   }
 
   RunBackward(target_tensors, {});
......
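Both RetainGrad tests now hand hook_function straight to CppTensorHook instead of first binding it to a std::function. A sketch of the wrapper this implies; the base class and member names here are assumptions, not the actual contents of paddle/fluid/eager/hooks.h:

  // Hypothetical shape of egr::CppTensorHook, inferred from its call sites.
  class CppTensorHook : public TensorHook {  // TensorHook base is assumed
   public:
    explicit CppTensorHook(std::function<paddle::experimental::Tensor(
                               const paddle::experimental::Tensor&)> fn)
        : fn_(std::move(fn)) {}
    paddle::experimental::Tensor operator()(
        const paddle::experimental::Tensor& grad) override {
      return fn_(grad);  // e.g. hook_function adds 3.0 element-wise
    }
   private:
    std::function<paddle::experimental::Tensor(
        const paddle::experimental::Tensor&)>
        fn_;
  };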
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h" #include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
namespace egr { namespace egr {
...@@ -54,7 +55,7 @@ paddle::experimental::Tensor hook_function( ...@@ -54,7 +55,7 @@ paddle::experimental::Tensor hook_function(
return ret; return ret;
} }
TEST(Hook_intermidiate, Sigmoid) { void test_sigmoid(bool is_remove_gradient_hook) {
// Prepare Device Contexts // Prepare Device Contexts
VLOG(6) << "Init Env"; VLOG(6) << "Init Env";
eager_test::InitEnv(paddle::platform::CPUPlace()); eager_test::InitEnv(paddle::platform::CPUPlace());
...@@ -67,11 +68,6 @@ TEST(Hook_intermidiate, Sigmoid) { ...@@ -67,11 +68,6 @@ TEST(Hook_intermidiate, Sigmoid) {
ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 0.0, true); phi::DataLayout::NCHW, 0.0, true);
VLOG(6) << "Make Hook function";
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
hook = &hook_function;
VLOG(6) << "Make ReduceHook function"; VLOG(6) << "Make ReduceHook function";
auto reduce_hook = [&](void) -> void { auto reduce_hook = [&](void) -> void {
auto* t_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()) auto* t_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl())
...@@ -85,10 +81,12 @@ TEST(Hook_intermidiate, Sigmoid) { ...@@ -85,10 +81,12 @@ TEST(Hook_intermidiate, Sigmoid) {
egr_utils_api::RetainGradForTensor(tensor); egr_utils_api::RetainGradForTensor(tensor);
VLOG(6) << "Register GradientHook for Tensor"; VLOG(6) << "Register GradientHook for Tensor";
egr_utils_api::RegisterGradientHookForTensor(tensor, hook); int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
tensor, std::make_shared<CppTensorHook>(hook_function));
VLOG(6) << "Register ReduceHook for Tensor"; VLOG(6) << "Register ReduceHook for Tensor";
egr_utils_api::RegisterReduceHookForTensor(tensor, reduce_hook); egr_utils_api::RegisterReduceHookForTensor(
tensor, std::make_shared<CppTensorVoidHook>(reduce_hook));
VLOG(6) << "Runing Forward"; VLOG(6) << "Runing Forward";
auto output_tensor = sigmoid_dygraph_function(tensor, {}); auto output_tensor = sigmoid_dygraph_function(tensor, {});
...@@ -98,11 +96,17 @@ TEST(Hook_intermidiate, Sigmoid) { ...@@ -98,11 +96,17 @@ TEST(Hook_intermidiate, Sigmoid) {
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
if (is_remove_gradient_hook) {
std::shared_ptr<GradNodeBase> grad_node_tmp = EagerUtils::grad_node(tensor);
grad_node_tmp->RemoveGradientHook(hook_id);
}
VLOG(6) << "Runing Backward"; VLOG(6) << "Runing Backward";
RunBackward(target_tensors, {}); RunBackward(target_tensors, {});
VLOG(6) << "Finish Backward"; VLOG(6) << "Finish Backward";
eager_test::CompareGradTensorWithValue<float>(tensor, 0.25 + 3); eager_test::CompareGradTensorWithValue<float>(
tensor, is_remove_gradient_hook ? 0.25 : 0.25 + 3.0);
VLOG(6) << "Checking ReduceHook results"; VLOG(6) << "Checking ReduceHook results";
for (int i = 0; i < tensor.numel(); i++) { for (int i = 0; i < tensor.numel(); i++) {
...@@ -113,7 +117,7 @@ TEST(Hook_intermidiate, Sigmoid) { ...@@ -113,7 +117,7 @@ TEST(Hook_intermidiate, Sigmoid) {
VLOG(6) << "After Tests"; VLOG(6) << "After Tests";
} }
TEST(Hook_intermidiate, ElementwiseAdd) { void test_elementwiseAdd(bool is_remove_gradient_hook) {
// Prepare Device Contexts // Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace()); eager_test::InitEnv(paddle::platform::CPUPlace());
...@@ -132,11 +136,7 @@ TEST(Hook_intermidiate, ElementwiseAdd) { ...@@ -132,11 +136,7 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
ddimY, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, ddimY, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 2.0, true); phi::DataLayout::NCHW, 2.0, true);
std::function<paddle::experimental::Tensor( auto reduce_hook = [&]() -> void {
const paddle::experimental::Tensor&)>
hook = &hook_function;
auto reduce_hook = [&](void) -> void {
auto* t_ptr = auto* t_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(Y.impl())->data<float>(); std::dynamic_pointer_cast<phi::DenseTensor>(Y.impl())->data<float>();
for (int i = 0; i < Y.numel(); i++) { for (int i = 0; i < Y.numel(); i++) {
...@@ -145,18 +145,26 @@ TEST(Hook_intermidiate, ElementwiseAdd) { ...@@ -145,18 +145,26 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
}; };
egr_utils_api::RetainGradForTensor(Y); egr_utils_api::RetainGradForTensor(Y);
egr_utils_api::RegisterGradientHookForTensor(Y, hook); int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook); Y, std::make_shared<CppTensorHook>(hook_function));
egr_utils_api::RegisterReduceHookForTensor(
Y, std::make_shared<CppTensorVoidHook>(reduce_hook));
auto output_tensor = elementwise_add_dygraph_function(X, Y, {}); auto output_tensor = elementwise_add_dygraph_function(X, Y, {});
eager_test::CompareTensorWithValue<float>(output_tensor, 5); eager_test::CompareTensorWithValue<float>(output_tensor, 5);
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
if (is_remove_gradient_hook) {
std::shared_ptr<GradNodeBase> grad_node_tmp = EagerUtils::grad_node(Y);
grad_node_tmp->RemoveGradientHook(hook_id);
}
RunBackward(target_tensors, {}); RunBackward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(X, 1.0); eager_test::CompareGradTensorWithValue<float>(X, 1.0);
eager_test::CompareGradTensorWithValue<float>(Y, 4.0); eager_test::CompareGradTensorWithValue<float>(
Y, is_remove_gradient_hook ? 1.0 : 1.0 + 3.0);
// Checking ReduceHook results // Checking ReduceHook results
for (int i = 0; i < Y.numel(); i++) { for (int i = 0; i < Y.numel(); i++) {
...@@ -166,7 +174,7 @@ TEST(Hook_intermidiate, ElementwiseAdd) { ...@@ -166,7 +174,7 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
} }
} }
TEST(Hook_intermidiate, Matmul_v2) { void test_matmul(bool is_remove_gradient_hook) {
// Prepare Device Contexts // Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace()); eager_test::InitEnv(paddle::platform::CPUPlace());
...@@ -185,10 +193,6 @@ TEST(Hook_intermidiate, Matmul_v2) { ...@@ -185,10 +193,6 @@ TEST(Hook_intermidiate, Matmul_v2) {
ddimY, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, ddimY, paddle::platform::CPUPlace(), phi::DataType::FLOAT32,
phi::DataLayout::NCHW, 2.0, true); phi::DataLayout::NCHW, 2.0, true);
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
hook = &hook_function;
auto reduce_hook = [&](void) -> void { auto reduce_hook = [&](void) -> void {
auto* t_ptr = auto* t_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(Y.impl())->data<float>(); std::dynamic_pointer_cast<phi::DenseTensor>(Y.impl())->data<float>();
...@@ -198,19 +202,27 @@ TEST(Hook_intermidiate, Matmul_v2) { ...@@ -198,19 +202,27 @@ TEST(Hook_intermidiate, Matmul_v2) {
}; };
egr_utils_api::RetainGradForTensor(Y); egr_utils_api::RetainGradForTensor(Y);
egr_utils_api::RegisterGradientHookForTensor(Y, hook); int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook); Y, std::make_shared<CppTensorHook>(hook_function));
egr_utils_api::RegisterReduceHookForTensor(
Y, std::make_shared<CppTensorVoidHook>(reduce_hook));
auto output_tensor = matmul_v2_dygraph_function( auto output_tensor = matmul_v2_dygraph_function(
X, Y, {{"trans_x", false}, {"trans_y", false}}); X, Y, {{"trans_x", false}, {"trans_y", false}});
eager_test::CompareTensorWithValue<float>(output_tensor, 96); eager_test::CompareTensorWithValue<float>(output_tensor, 96);
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor}; std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
if (is_remove_gradient_hook) {
std::shared_ptr<GradNodeBase> grad_node_tmp = EagerUtils::grad_node(Y);
grad_node_tmp->RemoveGradientHook(hook_id);
}
RunBackward(target_tensors, {}); RunBackward(target_tensors, {});
eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20); eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20);
eager_test::CompareGradTensorWithValue<float>(Y, 3.0 * 4 + 3); eager_test::CompareGradTensorWithValue<float>(
Y, is_remove_gradient_hook ? 3.0 * 4 : 3.0 * 4 + 3);
// Checking ReduceHook results // Checking ReduceHook results
for (int i = 0; i < Y.numel(); i++) { for (int i = 0; i < Y.numel(); i++) {
...@@ -219,6 +231,22 @@ TEST(Hook_intermidiate, Matmul_v2) { ...@@ -219,6 +231,22 @@ TEST(Hook_intermidiate, Matmul_v2) {
static_cast<float>(100.0f)); static_cast<float>(100.0f));
} }
} }
TEST(Hook_intermidiate, Sigmoid) {
// The flag controls whether RemoveGradientHook is called before backward
test_sigmoid(true);
test_sigmoid(false);
}
TEST(Hook_intermidiate, ElementwiseAdd) {
test_elementwiseAdd(true);
test_elementwiseAdd(false);
}
TEST(Hook_intermidiate, Matmul_v2) {
test_matmul(true);
test_matmul(false);
}
} // namespace egr } // namespace egr
USE_OP(sigmoid); USE_OP(sigmoid);
......
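RegisterGradientHookForTensor now returns an int64_t hook id, and every test above runs twice: once keeping the hook, once removing it through the tensor's grad node. A condensed sketch of that round trip, using only calls that appear in the hunk:

  // Register a hook, optionally remove it, then run backward.
  int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
      Y, std::make_shared<CppTensorHook>(hook_function));
  if (is_remove_gradient_hook) {
    std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(Y);
    grad_node->RemoveGradientHook(hook_id);
  }
  RunBackward(target_tensors, {});
  // With the hook removed, Y's grad stays at 1.0; otherwise it is 1.0 + 3.0.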
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true, PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true,
......
...@@ -193,19 +193,19 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va ...@@ -193,19 +193,19 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
IF(WITH_XPU) IF(WITH_XPU)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info xpu_op_list) cc_library(phi_utils SRCS phi_utils.cc DEPS lod_tensor selected_rows_utils place phi var_type_traits phi_api_utils op_info xpu_op_list)
ELSE() ELSE()
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) cc_library(phi_utils SRCS phi_utils.cc DEPS lod_tensor selected_rows_utils place phi var_type_traits phi_api_utils op_info)
ENDIF() ENDIF()
IF(WITH_XPU) IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils op_utils) phi phi_utils kernel_factory infershape_utils op_utils)
ELSE() ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils op_utils) phi phi_utils kernel_factory infershape_utils op_utils)
ENDIF() ENDIF()
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
...@@ -412,7 +412,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer) ...@@ -412,7 +412,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place phi var_type_traits phi phi_api_utils op_info shape_inference)
cc_test(infershape_utils_test SRCS infershape_utils_test.cc DEPS infershape_utils infermeta_utils meta_tensor) cc_test(infershape_utils_test SRCS infershape_utils_test.cc DEPS infershape_utils infermeta_utils meta_tensor)
# Get the current working branch # Get the current working branch
...@@ -436,8 +436,8 @@ message(STATUS "branch: ${PADDLE_BRANCH}") ...@@ -436,8 +436,8 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h) configure_file(commit.h.in commit.h)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_meta_info pten_api) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper phi_tensor op_meta_info phi_api)
cc_library(custom_kernel SRCS custom_kernel.cc DEPS op_registry pten_custom_kernel pten_tensor_raw) cc_library(custom_kernel SRCS custom_kernel.cc DEPS op_registry phi_custom_kernel phi_tensor_raw)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) #cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
...@@ -450,7 +450,7 @@ if(WITH_TESTING AND TEST selected_rows_utils_test) ...@@ -450,7 +450,7 @@ if(WITH_TESTING AND TEST selected_rows_utils_test)
endif() endif()
cc_test(scope_guard_test SRCS scope_guard_test.cc) cc_test(scope_guard_test SRCS scope_guard_test.cc)
cc_test(pten_utils_test SRCS pten_utils_test.cc DEPS pten_utils) cc_test(phi_utils_test SRCS phi_utils_test.cc DEPS phi_utils)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
......
...@@ -33,7 +33,7 @@ limitations under the License. */ ...@@ -33,7 +33,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
// pten // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
namespace paddle { namespace paddle {
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
paddle::experimental::DataType TransToPtenDataType( paddle::experimental::DataType TransToPhiDataType(
const paddle::framework::proto::VarType::Type& dtype) { const paddle::framework::proto::VarType::Type& dtype) {
// Set the order of case branches according to the frequency with // Set the order of case branches according to the frequency with
// which the data type is used // which the data type is used
......
...@@ -32,7 +32,7 @@ namespace framework { ...@@ -32,7 +32,7 @@ namespace framework {
using DataType = paddle::experimental::DataType; using DataType = paddle::experimental::DataType;
using DataLayout = paddle::experimental::DataLayout; using DataLayout = paddle::experimental::DataLayout;
DataType TransToPtenDataType( DataType TransToPhiDataType(
const paddle::framework::proto::VarType::Type& dtype); const paddle::framework::proto::VarType::Type& dtype);
paddle::framework::proto::VarType::Type TransToProtoVarType( paddle::framework::proto::VarType::Type TransToProtoVarType(
......
...@@ -43,35 +43,35 @@ TEST(ConvertUtils, DataType) { ...@@ -43,35 +43,35 @@ TEST(ConvertUtils, DataType) {
CHECK(paddle::framework::TransToProtoVarType(paddle::DataType::FLOAT16) == CHECK(paddle::framework::TransToProtoVarType(paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16); paddle::framework::proto::VarType::FP16);
// proto -> enum // proto -> enum
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::FP64) == paddle::framework::proto::VarType::FP64) ==
paddle::DataType::FLOAT64); paddle::DataType::FLOAT64);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::FP32) == paddle::framework::proto::VarType::FP32) ==
paddle::DataType::FLOAT32); paddle::DataType::FLOAT32);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT64) == paddle::framework::proto::VarType::INT64) ==
paddle::DataType::INT64); paddle::DataType::INT64);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT32) == paddle::framework::proto::VarType::INT32) ==
paddle::DataType::INT32); paddle::DataType::INT32);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT8) == paddle::DataType::INT8); paddle::framework::proto::VarType::INT8) == paddle::DataType::INT8);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::UINT8) == paddle::framework::proto::VarType::UINT8) ==
paddle::DataType::UINT8); paddle::DataType::UINT8);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT16) == paddle::framework::proto::VarType::INT16) ==
paddle::DataType::INT16); paddle::DataType::INT16);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL); paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::COMPLEX64) == paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64); paddle::DataType::COMPLEX64);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::COMPLEX128) == paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128); paddle::DataType::COMPLEX128);
CHECK(paddle::framework::TransToPtenDataType( CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::FP16) == paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16); paddle::DataType::FLOAT16);
} }
......
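The renamed conversion keeps its old contract: TransToPhiDataType maps framework::proto::VarType::Type onto paddle::experimental::DataType, and TransToProtoVarType is its inverse. A round-trip sketch distilled from the assertions above:

  // proto -> phi enum -> proto round trip.
  auto phi_dtype = paddle::framework::TransToPhiDataType(
      paddle::framework::proto::VarType::FP32);
  CHECK(phi_dtype == paddle::DataType::FLOAT32);
  CHECK(paddle::framework::TransToProtoVarType(phi_dtype) ==
        paddle::framework::proto::VarType::FP32);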
...@@ -30,7 +30,7 @@ limitations under the License. */ ...@@ -30,7 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_meta_info_helper.h" #include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
...@@ -779,13 +779,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -779,13 +779,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
for (size_t i = 0; i < ctx->InputSize(in_name); ++i) { for (size_t i = 0; i < ctx->InputSize(in_name); ++i) {
auto dtype = ctx->GetInputDataType(in_name, i); auto dtype = ctx->GetInputDataType(in_name, i);
vec_custom_dtype.emplace_back( vec_custom_dtype.emplace_back(
paddle::framework::TransToPtenDataType(dtype)); paddle::framework::TransToPhiDataType(dtype));
} }
vec_input_dtypes.emplace_back(vec_custom_dtype); vec_input_dtypes.emplace_back(vec_custom_dtype);
} else { } else {
auto dtype = ctx->GetInputDataType(in_name); auto dtype = ctx->GetInputDataType(in_name);
input_dtypes.emplace_back( input_dtypes.emplace_back(
paddle::framework::TransToPtenDataType(dtype)); paddle::framework::TransToPhiDataType(dtype));
} }
} }
......
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -28,7 +28,7 @@ TEST(DataType, float16) { ...@@ -28,7 +28,7 @@ TEST(DataType, float16) {
Tensor tensor; Tensor tensor;
CPUPlace cpu; CPUPlace cpu;
tensor.mutable_data(cpu, f::TransToPtenDataType(dtype)); tensor.mutable_data(cpu, f::TransToPhiDataType(dtype));
// test fp16 tensor // test fp16 tensor
EXPECT_EQ(f::TransToProtoVarType(tensor.dtype()), EXPECT_EQ(f::TransToProtoVarType(tensor.dtype()),
...@@ -51,7 +51,7 @@ TEST(DataType, bfloat16) { ...@@ -51,7 +51,7 @@ TEST(DataType, bfloat16) {
Tensor tensor; Tensor tensor;
CPUPlace cpu; CPUPlace cpu;
tensor.mutable_data(cpu, f::TransToPtenDataType(dtype)); tensor.mutable_data(cpu, f::TransToPhiDataType(dtype));
// test bf16 tensor // test bf16 tensor
EXPECT_EQ(f::TransToProtoVarType(tensor.dtype()), EXPECT_EQ(f::TransToProtoVarType(tensor.dtype()),
......
...@@ -231,6 +231,8 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -231,6 +231,8 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
OpHandleBase *op, OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) { const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
++remaining_; ++remaining_;
platform::RecordEvent("WorkQueue::AddTask",
platform::TracerEventType::UserDefined, 10 /*level*/);
this->pool_->enqueue([=] { this->pool_->enqueue([=] {
std::deque<OpHandleBase *> op_queue; std::deque<OpHandleBase *> op_queue;
op_queue.push_front(op); op_queue.push_front(op);
......
...@@ -34,7 +34,7 @@ limitations under the License. */ ...@@ -34,7 +34,7 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
// pten // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -161,7 +161,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope, ...@@ -161,7 +161,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope,
tensor->set_lod(lod); tensor->set_lod(lod);
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
place, framework::TransToPtenDataType(ToVarType(req_var.data_type()))); place, framework::TransToPhiDataType(ToVarType(req_var.data_type())));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
memory::Copy(place, tensor_data, platform::CPUPlace(), req_var.data().data(), memory::Copy(place, tensor_data, platform::CPUPlace(), req_var.data().data(),
...@@ -202,7 +202,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope, ...@@ -202,7 +202,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope,
tensor->set_lod(lod); tensor->set_lod(lod);
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
place, framework::TransToPtenDataType(ToVarType(req_var.data_type()))); place, framework::TransToPhiDataType(ToVarType(req_var.data_type())));
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
memory::Copy(place, tensor_data, platform::CPUPlace(), req_var.data().data(), memory::Copy(place, tensor_data, platform::CPUPlace(), req_var.data().data(),
......
...@@ -38,7 +38,7 @@ void SetMicroId(paddle::framework::Scope* scope, ...@@ -38,7 +38,7 @@ void SetMicroId(paddle::framework::Scope* scope,
std::vector<int> dims{1}; std::vector<int> dims{1};
tensor->Resize(phi::make_ddim(dims)); tensor->Resize(phi::make_ddim(dims));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
place, framework::TransToPtenDataType(framework::proto::VarType::FP32)); place, framework::TransToPhiDataType(framework::proto::VarType::FP32));
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::vector<char> temp; std::vector<char> temp;
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
...@@ -144,7 +144,7 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -144,7 +144,7 @@ class CompatMetaTensor : public phi::MetaTensor {
} }
} else { } else {
auto* var = BOOST_GET_CONST(VarDesc*, var_); auto* var = BOOST_GET_CONST(VarDesc*, var_);
return paddle::framework::TransToPtenDataType(var->GetDataType()); return paddle::framework::TransToPhiDataType(var->GetDataType());
} }
} }
...@@ -341,24 +341,37 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -341,24 +341,37 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
if (infershape_inputs.size() != 1) { if (infershape_inputs.size() != 1) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVarList(vars))); std::move(experimental::MakePhiScalarArrayFromVarList(vars)));
} else { } else {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVar(*vars[0]))); std::move(experimental::MakePhiScalarArrayFromVar(*vars[0])));
} }
} else { } else {
// If not in runtime, we will set a default value (-1) for the ScalarArray // If not in runtime, we will set a default value (-1) for the ScalarArray
int64_t num_ele = 1; int64_t num_ele = 0;
std::vector<VarDesc*> vars; std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size()); vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) { for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i])); vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
} }
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape(); if (vars.size() == 1) {
num_ele = 1;
const auto& tensor_dims = vars[0]->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) { for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i]; num_ele *= tensor_dims[i];
} }
} else {
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
PADDLE_ENFORCE_EQ(tensor_dims.size(), 1,
platform::errors::InvalidArgument(
"The shape is constructed by multi-tensor, "
"every tensor's dims should be 1. But your "
"shape has tensor that dims is %s.",
tensor_dims.size()));
num_ele += tensor_dims[0];
}
} }
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1)); phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true); tensor_attr.SetFromTensor(true);
...@@ -406,7 +419,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -406,7 +419,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]); Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarFromVar(*var))); std::move(experimental::MakePhiScalarFromVar(*var)));
} else { } else {
phi::Scalar tensor_scalar(-1); phi::Scalar tensor_scalar(-1);
tensor_scalar.SetFromTensor(true); tensor_scalar.SetFromTensor(true);
...@@ -468,7 +481,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -468,7 +481,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
BOOST_GET_CONST(std::vector<std::string>, attr)); BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) { std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPtenDataType( auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
infer_meta_context.EmplaceBackAttr(data_type); infer_meta_context.EmplaceBackAttr(data_type);
......
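The ScalarArray hunk above changes how the compile-time placeholder is sized: a single shape tensor contributes the product of its dims, while a list of shape tensors must all be 1-D and contributes the sum of their lengths. A condensed restatement with made-up shapes:

  // Sketch of the revised sizing logic (shapes are illustrative).
  int64_t num_ele = 0;
  if (vars.size() == 1) {
    num_ele = 1;
    for (int64_t d : vars[0]->GetShape()) num_ele *= d;  // {2, 3} -> 6
  } else {
    for (auto* var : vars) {
      num_ele += var->GetShape()[0];  // {2}, {1}, {4} -> 2 + 1 + 4 = 7
    }
  }
  phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
  tensor_attr.SetFromTensor(true);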
...@@ -276,13 +276,13 @@ bool FuseOptimizerOpPass::OpWithKernelSupportCPUAndGPU( ...@@ -276,13 +276,13 @@ bool FuseOptimizerOpPass::OpWithKernelSupportCPUAndGPU(
bool support_gpu = false; bool support_gpu = false;
auto &kernel_factory = phi::KernelFactory::Instance(); auto &kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map = auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPtenKernelName(op_type)); kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
bool has_op_kernel = kernel_key_map.size() > 0 ? true : false; bool has_op_kernel = kernel_key_map.size() > 0 ? true : false;
for (auto &kernel : kernel_key_map) { for (auto &kernel : kernel_key_map) {
if (platform::is_gpu_place(phi::TransToPtenPlace(kernel.first.backend()))) { if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
support_gpu = true; support_gpu = true;
} else if (platform::is_cpu_place( } else if (platform::is_cpu_place(
phi::TransToPtenPlace(kernel.first.backend()))) { phi::TransToPhiPlace(kernel.first.backend()))) {
support_cpu = true; support_cpu = true;
} }
} }
......
...@@ -96,7 +96,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, ...@@ -96,7 +96,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
auto x = scope->Var(var_name); auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>(); auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(proto::VarType::FP32), 1); framework::TransToPhiDataType(proto::VarType::FP32), 1);
} }
void MainTest(bool convWithExistingBias) { void MainTest(bool convWithExistingBias) {
......
...@@ -126,7 +126,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, ...@@ -126,7 +126,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
auto x = scope->Var(var_name); auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>(); auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(proto::VarType::FP32), 1); framework::TransToPhiDataType(proto::VarType::FP32), 1);
} }
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog, void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
......
...@@ -526,7 +526,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, ...@@ -526,7 +526,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
auto x = scope->Var(var_name); auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>(); auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(proto::VarType::FP32), 1); framework::TransToPhiDataType(proto::VarType::FP32), 1);
} }
void PrepareGraph(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog) { void PrepareGraph(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog) {
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
USE_OP(softmax); USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
......
...@@ -447,7 +447,7 @@ void MergeLoDTensor(LoDTensor *target, ...@@ -447,7 +447,7 @@ void MergeLoDTensor(LoDTensor *target,
target->set_layout(new_layout); target->set_layout(new_layout);
target->set_lod(new_lod); target->set_lod(new_lod);
target->mutable_data(dst_place, target->mutable_data(dst_place,
paddle::framework::TransToPtenDataType(new_type)); paddle::framework::TransToPhiDataType(new_type));
int begin = 0; int begin = 0;
for (auto *src : lod_tensors) { for (auto *src : lod_tensors) {
......
...@@ -389,7 +389,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -389,7 +389,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op); auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{ {
platform::RecordEvent infershape_event( platform::RecordEvent infershape_event(
"InferShape", platform::TracerEventType::OperatorInner, 1, "infer_shape", platform::TracerEventType::OperatorInner, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
// If it is OperatorBase, InferShape do nothing. // If it is OperatorBase, InferShape do nothing.
if (op_with_kernel != nullptr) if (op_with_kernel != nullptr)
...@@ -411,23 +411,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -411,23 +411,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
} }
{ {
platform::RecordEvent compute_event( platform::RecordEvent compute_event(
"Compute", platform::TracerEventType::OperatorInner, 1, "compute", platform::TracerEventType::OperatorInner, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (op_with_kernel == nullptr) { if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_); instr_node.OpBase()->Run(*local_scope, place_);
} else { } else {
// fit for pten // fit for phi
if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) { if (instr_node.PhiKernel() && instr_node.PhiKernel()->IsValid()) {
VLOG(4) << "Run pten kernel: " << op->Type(); VLOG(4) << "Run phi kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " " VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext(); << &instr_node.DeviceContext();
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext( op_with_kernel->BuildPhiKernelContext(
*instr_node.InnerRuntimeContext().get(), *instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()), const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
&pt_kernel_context); &pt_kernel_context);
(*instr_node.PtenKernel())(&pt_kernel_context); (*instr_node.PhiKernel())(&pt_kernel_context);
} else { } else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
...@@ -561,7 +561,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -561,7 +561,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
<< " runs on " << platform::GetCurrentThreadName(); << " runs on " << platform::GetCurrentThreadName();
auto* op = instr_node.OpBase(); auto* op = instr_node.OpBase();
platform::RecordEvent instruction_event(op->Type().c_str()); platform::RecordEvent instruction_event(
op->Type(), platform::TracerEventType::Operator, 1);
interpreter::WaitEvent(instr_node, place_); interpreter::WaitEvent(instr_node, place_);
try { try {
......
...@@ -407,14 +407,14 @@ void build_op_func_list(const platform::Place& place, ...@@ -407,14 +407,14 @@ void build_op_func_list(const platform::Place& place,
auto exec_ctx = auto exec_ctx =
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto run_pten_kernel = false; auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel( if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_with_kernel->Type())) { op_with_kernel->Type())) {
auto pt_kernel_key = op_with_kernel->ChoosePtenKernel(exec_ctx); auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
auto pt_kernel_name = op_with_kernel->PtenKernelSignature()->name; auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;
if (op_with_kernel->PtenKernel()->IsValid()) { if (op_with_kernel->PhiKernel()->IsValid()) {
run_pten_kernel = true; run_phi_kernel = true;
} else { } else {
auto kernels_iter = all_op_kernels.find(op_with_kernel->Type()); auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
if (kernels_iter == all_op_kernels.end() || if (kernels_iter == all_op_kernels.end() ||
...@@ -422,26 +422,26 @@ void build_op_func_list(const platform::Place& place, ...@@ -422,26 +422,26 @@ void build_op_func_list(const platform::Place& place,
kernels_iter->second.end()) { kernels_iter->second.end()) {
auto pt_cpu_kernel_key = FallBackToCpu( auto pt_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, pt_kernel_key, *op_with_kernel); expected_kernel_key, pt_kernel_key, *op_with_kernel);
op_with_kernel->ResetPtenKernel( op_with_kernel->ResetPhiKernel(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key))); pt_kernel_name, pt_cpu_kernel_key)));
if (op_with_kernel->PtenKernel()->IsValid()) { if (op_with_kernel->PhiKernel()->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: " VLOG(6) << "Static mode PrepareImpl - kernel name: "
<< pt_kernel_name << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *(op_with_kernel->PtenKernel()); << " | kernel: " << *(op_with_kernel->PhiKernel());
run_pten_kernel = true; run_phi_kernel = true;
} }
} }
} }
} }
VLOG(3) << op_with_kernel->Type() VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key; << " : expected_kernel_key : " << expected_kernel_key;
if (run_pten_kernel) { if (run_phi_kernel) {
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx, op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
&pt_kernel_context); &pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel(); op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
(*op_func_node.pt_kernel_)(&pt_kernel_context); (*op_func_node.pt_kernel_)(&pt_kernel_context);
} else { } else {
......
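build_op_func_list now follows the same kernel-selection order as operator.cc: prefer a registered phi kernel, fall back to a CPU phi kernel only when no fluid kernel exists either, and otherwise keep the fluid path. A condensed sketch of that flow, using the calls shown above (op stands for op_with_kernel; fluid_kernel_missing is a placeholder for the kernel-map lookup):

  bool run_phi_kernel = false;
  if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op->Type())) {
    auto pt_kernel_key = op->ChoosePhiKernel(exec_ctx);
    if (op->PhiKernel()->IsValid()) {
      run_phi_kernel = true;
    } else if (fluid_kernel_missing) {
      // No matching fluid kernel either: retry with a CPU phi kernel key.
      auto cpu_key = FallBackToCpu(expected_kernel_key, pt_kernel_key, *op);
      op->ResetPhiKernel(new phi::Kernel(
          phi::KernelFactory::Instance().SelectKernel(
              op->PhiKernelSignature()->name, cpu_key)));
      run_phi_kernel = op->PhiKernel()->IsValid();
    }
  }
  // run_phi_kernel ? BuildPhiKernelContext + (*PhiKernel())(...)
  //                : the usual fluid OpKernel path.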
...@@ -688,9 +688,7 @@ OpKernelComputeFunc Instruction::KernelFunc() const { ...@@ -688,9 +688,7 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
return op_func_node_.kernel_func_; return op_func_node_.kernel_func_;
} }
phi::Kernel* Instruction::PtenKernel() const { phi::Kernel* Instruction::PhiKernel() const { return op_func_node_.pt_kernel_; }
return op_func_node_.pt_kernel_;
}
OpFuncType Instruction::KernelType() const { return op_func_node_.type_; } OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
......
...@@ -300,7 +300,7 @@ struct OpFuncNode { ...@@ -300,7 +300,7 @@ struct OpFuncNode {
OpKernelComputeFunc kernel_func_; OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned platform::DeviceContext* dev_ctx_; // not owned
// fit for pten kernel // fit for phi kernel
phi::Kernel* pt_kernel_{nullptr}; // not owned phi::Kernel* pt_kernel_{nullptr}; // not owned
OpFuncType type_; OpFuncType type_;
...@@ -321,7 +321,7 @@ class Instruction { ...@@ -321,7 +321,7 @@ class Instruction {
OpKernelComputeFunc KernelFunc() const; OpKernelComputeFunc KernelFunc() const;
phi::Kernel* PtenKernel() const; phi::Kernel* PhiKernel() const;
OpFuncType KernelType() const; OpFuncType KernelType() const;
......
...@@ -44,7 +44,6 @@ class ThreadDataRegistry { ...@@ -44,7 +44,6 @@ class ThreadDataRegistry {
template <typename Alias = T, template <typename Alias = T,
typename = std::enable_if_t<std::is_copy_assignable<Alias>::value>> typename = std::enable_if_t<std::is_copy_assignable<Alias>::value>>
void SetCurrentThreadData(const T& val) { void SetCurrentThreadData(const T& val) {
std::lock_guard<std::mutex> lock(lock_);
CurrentThreadData() = val; CurrentThreadData() = val;
} }
......
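Dropping the lock_guard from SetCurrentThreadData is only sound if CurrentThreadData() hands back storage owned by the calling thread. A minimal sketch of that assumption (the real ThreadDataRegistry may keep extra bookkeeping for cross-thread reads):

  // Hypothetical thread-confined slot: each thread assigns to its own
  // thread_local object, so the write itself needs no mutex.
  template <typename T>
  T& CurrentThreadData() {
    static thread_local T data;
    return data;
  }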
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h" #include "paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -61,6 +62,8 @@ class WorkQueueImpl : public WorkQueue { ...@@ -61,6 +62,8 @@ class WorkQueueImpl : public WorkQueue {
} }
void AddTask(std::function<void()> fn) override { void AddTask(std::function<void()> fn) override {
platform::RecordEvent("WorkQueue::AddTask",
platform::TracerEventType::UserDefined, 10 /*level*/);
if (tracker_ != nullptr) { if (tracker_ != nullptr) {
fn = [ fn = [
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_) task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
...@@ -156,6 +159,8 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() { ...@@ -156,6 +159,8 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
} }
void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) { void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
platform::RecordEvent("WorkQueue::AddTask",
platform::TracerEventType::UserDefined, 10 /*level*/);
assert(queue_idx < queues_.size()); assert(queue_idx < queues_.size());
if (queues_options_.at(queue_idx).track_task) { if (queues_options_.at(queue_idx).track_task) {
fn = [ fn = [
......
...@@ -24,7 +24,7 @@ limitations under the License. */ ...@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/unused_var_check.h"
...@@ -263,11 +263,11 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -263,11 +263,11 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
// In order to record the time cost of different op types // In order to record the time cost of different op types
// and different op names, we set two events. // and different op names, we set two events.
platform::RecordEvent op_type_record_event( platform::RecordEvent op_type_record_event(
Type().c_str(), platform::TracerEventType::Operator, 1); Type(), platform::TracerEventType::Operator, 1);
auto op_name = platform::OpName(outputs_, Type()); // auto op_name = platform::OpName(outputs_, Type());
platform::RecordEvent op_name_record_event( // platform::RecordEvent op_name_record_event(
op_name, platform::TracerEventType::Operator, 1, // op_name, platform::TracerEventType::Operator, 1,
platform::EventRole::kUniqueOp); // platform::EventRole::kUniqueOp);
RunImpl(scope, place); RunImpl(scope, place);
} }
...@@ -616,9 +616,9 @@ bool OpSupportGPU(const std::string& op_type) { ...@@ -616,9 +616,9 @@ bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first // check in new Function kernel first
auto& kernel_factory = phi::KernelFactory::Instance(); auto& kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map = auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPtenKernelName(op_type)); kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
for (auto& kernel : kernel_key_map) { for (auto& kernel : kernel_key_map) {
if (platform::is_gpu_place(phi::TransToPtenPlace(kernel.first.backend()))) { if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
return true; return true;
} }
} }
...@@ -1186,10 +1186,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1186,10 +1186,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// phase // phase
phi::KernelKey pt_kernel_key; phi::KernelKey pt_kernel_key;
std::string pt_kernel_name; std::string pt_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
pt_kernel_signature_.reset( pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx)))); new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *pt_kernel_signature_.get(); VLOG(6) << *pt_kernel_signature_.get();
kernel_type_.reset( kernel_type_.reset(
...@@ -1197,17 +1197,17 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1197,17 +1197,17 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(kernel_type_->place_); dev_ctx = pool.Get(kernel_type_->place_);
pt_kernel_name = pt_kernel_signature_->name; pt_kernel_name = pt_kernel_signature_->name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get()); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset( pt_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key))); pt_kernel_name, pt_kernel_key)));
if (pt_kernel_->IsValid()) { if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode ChoosePtenKernel - kernel name: " VLOG(6) << "Static mode ChoosePhiKernel - kernel name: "
<< pt_kernel_name << " | kernel key: " << pt_kernel_key << pt_kernel_name << " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_; << " | kernel: " << *pt_kernel_;
} else { } else {
VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name VLOG(6) << "Static mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found."; << "` not found.";
} }
} }
...@@ -1222,7 +1222,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1222,7 +1222,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
&& !is_xpu_unsupport && !is_xpu_unsupport
#endif #endif
) { ) {
run_pten_kernel_ = true; run_phi_kernel_ = true;
} else { } else {
auto& all_op_kernels = AllOpKernels(); auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_); auto kernels_iter = all_op_kernels.find(type_);
...@@ -1244,12 +1244,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1244,12 +1244,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *pt_kernel_; << " | kernel: " << *pt_kernel_;
run_pten_kernel_ = true; run_phi_kernel_ = true;
} }
} }
} }
} }
if (!run_pten_kernel_) { if (!run_phi_kernel_) {
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) { if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(exe_ctx); ChooseKernel(exe_ctx);
dev_ctx = pool.Get(kernel_type_->place_); dev_ctx = pool.Get(kernel_type_->place_);
...@@ -1290,13 +1290,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1290,13 +1290,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute", platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
if (run_pten_kernel_) { if (run_phi_kernel_) {
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
// Do data transform before building KernelContext // Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack // TODO(zhiqiu): support TransferInplaceVarsBack
PreparePtenData(exec_scope, *pt_kernel_, *pt_kernel_signature_, PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
runtime_ctx); runtime_ctx);
BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
(*pt_kernel_)(&pt_kernel_context); (*pt_kernel_)(&pt_kernel_context);
} else { } else {
(*kernel_func_)( (*kernel_func_)(
...@@ -1388,26 +1388,26 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -1388,26 +1388,26 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
return expected_kernel_key; return expected_kernel_key;
} }
phi::KernelKey OperatorWithKernel::ChoosePtenKernel( phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
pt_kernel_signature_.reset( pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx)))); new KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
VLOG(6) << *pt_kernel_signature_.get(); VLOG(6) << *pt_kernel_signature_.get();
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto pt_kernel_name = pt_kernel_signature_->name; auto pt_kernel_name = pt_kernel_signature_->name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get()); auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key))); pt_kernel_name, pt_kernel_key)));
if (pt_kernel_->IsValid()) { if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode ChoosePtenKernel - kernel name: " << pt_kernel_name VLOG(6) << "Static mode ChoosePhiKernel - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key << " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_; << " | kernel: " << *pt_kernel_;
} else { } else {
VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name VLOG(6) << "Static mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found."; << "` not found.";
} }
return pt_kernel_key; return pt_kernel_key;
...@@ -1918,7 +1918,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( ...@@ -1918,7 +1918,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout()); tensor.layout());
} }
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
InitDefaultKernelSignatureMap(); InitDefaultKernelSignatureMap();
ExecutionArgumentMappingContext arg_mapping_ctx(ctx); ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
...@@ -1926,7 +1926,7 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( ...@@ -1926,7 +1926,7 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
arg_mapping_ctx); arg_mapping_ctx);
} }
Scope* OperatorWithKernel::PreparePtenData( Scope* OperatorWithKernel::PreparePhiData(
const Scope& scope, const phi::Kernel& pt_kernel, const Scope& scope, const phi::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const { const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
auto& input_names = std::get<0>(pt_kernel_signature.args); auto& input_names = std::get<0>(pt_kernel_signature.args);
...@@ -1981,12 +1981,12 @@ Scope* OperatorWithKernel::PreparePtenData( ...@@ -1981,12 +1981,12 @@ Scope* OperatorWithKernel::PreparePtenData(
if (in_def.backend == phi::Backend::ALL_BACKEND) { if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue; continue;
} }
auto expected_place = phi::TransToPtenPlace(in_def.backend); auto expected_place = phi::TransToPhiPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) { if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue; continue;
} }
VLOG(3) << "PTen Transform Variable " << input_names[i] << " from " VLOG(3) << "phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place; << tensor_in->place() << " to " << expected_place;
if (!new_scope) { if (!new_scope) {
...@@ -2007,7 +2007,7 @@ Scope* OperatorWithKernel::PreparePtenData( ...@@ -2007,7 +2007,7 @@ Scope* OperatorWithKernel::PreparePtenData(
return new_scope; return new_scope;
} }
void OperatorWithKernel::BuildPtenKernelContext( void OperatorWithKernel::BuildPhiKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx, const RuntimeContext& ctx, platform::DeviceContext* dev_ctx,
phi::KernelContext* pt_kernel_context) const { phi::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx); pt_kernel_context->SetDeviceContext(dev_ctx);
...@@ -2111,7 +2111,7 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2111,7 +2111,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
experimental::ResetTensorDtypeAndLayoutByArgDef(tensor_out, experimental::ResetTensorDtypeAndLayoutByArgDef(tensor_out,
output_defs.at(i)); output_defs.at(i));
SetAllocationForOutputTenosr( SetAllocationForOutputTenosr(
tensor_out, phi::TransToPtenPlace(output_defs.at(i).backend)); tensor_out, phi::TransToPhiPlace(output_defs.at(i).backend));
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
...@@ -2145,10 +2145,10 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2145,10 +2145,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
auto& ins_vector = ctx.inputs.at(attr_names[i]); auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(*ins_vector.front()))); experimental::MakePhiScalarArrayFromVar(*ins_vector.front())));
} else { // ShapeTensorList } else { // ShapeTensorList
pt_kernel_context->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(ins_vector))); experimental::MakePhiScalarArrayFromVarList(ins_vector)));
} }
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
...@@ -2178,8 +2178,8 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2178,8 +2178,8 @@ void OperatorWithKernel::BuildPtenKernelContext(
} }
} else { } else {
auto& ins_vector = ctx.inputs.at(attr_names[i]); auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(
experimental::MakePtenScalarFromVar(*ins_vector.front()))); std::move(experimental::MakePhiScalarFromVar(*ins_vector.front())));
} }
} else { } else {
...@@ -2198,7 +2198,7 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2198,7 +2198,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) { std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPtenDataType( auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
pt_kernel_context->EmplaceBackAttr(data_type); pt_kernel_context->EmplaceBackAttr(data_type);
...@@ -2206,7 +2206,7 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2206,7 +2206,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) == if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) { std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end()); vector_int_attr.end());
......
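The ShapeTensor/ShapeTensorList branch in BuildPhiKernelContext above keys off how many Variables are bound to the attribute-like input: one variable means a single tensor holds the whole array, several mean one element per tensor. A hedged restatement of that dispatch, using only the helpers visible in this diff:

// ins_vector: the Variables bound to one attribute-like input slot.
if (ins_vector.size() == 1) {
  // ShapeTensor: a single 1-D tensor carries every element of the array.
  pt_kernel_context->EmplaceBackAttr(
      experimental::MakePhiScalarArrayFromVar(*ins_vector.front()));
} else {
  // ShapeTensorList: each tensor carries exactly one element.
  pt_kernel_context->EmplaceBackAttr(
      experimental::MakePhiScalarArrayFromVarList(ins_vector));
}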
...@@ -30,7 +30,7 @@ limitations under the License. */ ...@@ -30,7 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -423,7 +423,7 @@ class ExecutionContext { ...@@ -423,7 +423,7 @@ class ExecutionContext {
"size(%d).", "size(%d).",
allocation_ptr->size(), phi::product(dim) * sizeof(T))); allocation_ptr->size(), phi::product(dim) * sizeof(T)));
paddle::framework::Tensor temp_tensor(framework::TransToPtenDataType( paddle::framework::Tensor temp_tensor(framework::TransToPhiDataType(
framework::ToDataType(std::type_index(typeid(T))))); framework::ToDataType(std::type_index(typeid(T)))));
temp_tensor.Resize(dim); temp_tensor.Resize(dim);
temp_tensor.ResetHolder(std::move(shared_allocation)); temp_tensor.ResetHolder(std::move(shared_allocation));
...@@ -538,14 +538,14 @@ class OperatorWithKernel : public OperatorBase { ...@@ -538,14 +538,14 @@ class OperatorWithKernel : public OperatorBase {
} }
bool SupportGPU() const override { bool SupportGPU() const override {
auto pten_kernels = phi::KernelFactory::Instance().SelectKernelMap( auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPtenKernelName(type_)); phi::TransToPhiKernelName(type_));
auto has_pten_kernel = auto has_phi_kernel =
std::any_of(pten_kernels.begin(), pten_kernels.end(), std::any_of(phi_kernels.begin(), phi_kernels.end(),
[](phi::KernelKeyMap::const_reference kern_pair) { [](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::GPU; return kern_pair.first.backend() == phi::Backend::GPU;
}); });
if (has_pten_kernel) { if (has_phi_kernel) {
return true; return true;
} else { } else {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
...@@ -558,7 +558,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -558,7 +558,7 @@ class OperatorWithKernel : public OperatorBase {
} }
bool SupportNPU() const override { bool SupportNPU() const override {
// TODO(zhiqiu): support pten if needed? // TODO(zhiqiu): support phi if needed?
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(), return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) { [](OpKernelMap::const_reference kern_pair) {
...@@ -566,7 +566,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -566,7 +566,7 @@ class OperatorWithKernel : public OperatorBase {
}); });
} }
bool SupportMLU() const override { bool SupportMLU() const override {
// TODO(zhiqiu): support pten if needed? // TODO(zhiqiu): support phi if needed?
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(), return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) { [](OpKernelMap::const_reference kern_pair) {
...@@ -603,39 +603,39 @@ class OperatorWithKernel : public OperatorBase { ...@@ -603,39 +603,39 @@ class OperatorWithKernel : public OperatorBase {
return kernel_type_->place_; return kernel_type_->place_;
} }
/* member functions for adapting to pten lib */ /* member functions for adapting to phi lib */
/** In the Tensor calculation library, the new Kernel adopts a clearer and /** In the Tensor calculation library, the new Kernel adopts a clearer and
* more streamlined design. The arguments of the Kernel and the input and * more streamlined design. The arguments of the Kernel and the input and
* output arguments registered in the original OpMaker do not match in some * output arguments registered in the original OpMaker do not match in some
* cases, so we use a map to record the arguments required by the kernel. * cases, so we use a map to record the arguments required by the kernel.
* When selecting Kernel during Op execution, select the arguments of the * When selecting Kernel during Op execution, select the arguments of the
* original Op according to the GetExpectedPtenKernelArgs returned arguments. * original Op according to the GetExpectedPhiKernelArgs returned arguments.
*/ */
phi::KernelSignature GetExpectedPtenKernelArgs( phi::KernelSignature GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const; const ExecutionContext& ctx) const;
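To make the mapping concrete, here is a hedged sketch of the kind of signature such a function returns; the `scale` op and its argument lists are illustrative only (the authoritative mappings live under phi/ops/compat), and SmallVector initializer-list construction is assumed to be available:

// Hypothetical signature for a scale-like op: name, then inputs, attrs,
// outputs in the order KernelSignature expects (see GetKernelSignature
// in phi_utils.cc below).
phi::KernelSignature sig("scale",
                         {"X"},                                  // inputs
                         {"scale", "bias", "bias_after_scale"},  // attrs
                         {"Out"});                               // outputs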
/* member functions for adapting to pten lib */ /* member functions for adapting to phi lib */
phi::KernelKey ChoosePtenKernel(const ExecutionContext& ctx) const; phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const;
/** /**
* Transfer data place for pten kernel * Transfer data place for phi kernel
* Is this really needed? * Is this really needed?
*/ */
Scope* PreparePtenData(const Scope& scope, const phi::Kernel& pt_kernel, Scope* PreparePhiData(const Scope& scope, const phi::Kernel& pt_kernel,
const phi::KernelSignature& pt_kernel_signature, const phi::KernelSignature& pt_kernel_signature,
RuntimeContext* ctx) const; RuntimeContext* ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx, void BuildPhiKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
phi::KernelContext* pt_kernel_context) const; phi::KernelContext* pt_kernel_context) const;
phi::KernelSignature* PtenKernelSignature() const { phi::KernelSignature* PhiKernelSignature() const {
return pt_kernel_signature_.get(); return pt_kernel_signature_.get();
} }
phi::Kernel* PtenKernel() const { return pt_kernel_.get(); } phi::Kernel* PhiKernel() const { return pt_kernel_.get(); }
void ResetPtenKernel(phi::Kernel* kernel) const { void ResetPhiKernel(phi::Kernel* kernel) const {
return pt_kernel_.reset(kernel); return pt_kernel_.reset(kernel);
} }
...@@ -692,9 +692,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -692,9 +692,9 @@ class OperatorWithKernel : public OperatorBase {
mutable std::mutex cache_update_mutex_; mutable std::mutex cache_update_mutex_;
mutable bool enable_cache_transfer_scope_ = false; mutable bool enable_cache_transfer_scope_ = false;
// NOTE(chenweihang): Similar op members are used to adapt to // NOTE(chenweihang): Similar op members are used to adapt to
// new pten kernel, if there is a better design in the future, // new phi kernel, if there is a better design in the future,
// we may polish the implementation here // we may polish the implementation here
mutable bool run_pten_kernel_ = false; mutable bool run_phi_kernel_ = false;
mutable bool run_kp_kernel = false; mutable bool run_kp_kernel = false;
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_; mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<phi::Kernel> pt_kernel_; mutable std::unique_ptr<phi::Kernel> pt_kernel_;
......
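SupportGPU() above probes the phi kernel registry before falling back to the fluid kernel map. A standalone sketch of that probe; HasPhiGpuKernel is a hypothetical wrapper name, while the calls themselves all appear in the diff:

// Look up every phi kernel registered under the op's phi name and ask
// whether any kernel key targets the GPU backend.
bool HasPhiGpuKernel(const std::string& op_type) {
  auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
      phi::TransToPhiKernelName(op_type));
  return std::any_of(phi_kernels.begin(), phi_kernels.end(),
                     [](phi::KernelKeyMap::const_reference kern_pair) {
                       return kern_pair.first.backend() == phi::Backend::GPU;
                     });
}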
...@@ -44,11 +44,6 @@ DECLARE_string(deny_cinn_ops); ...@@ -44,11 +44,6 @@ DECLARE_string(deny_cinn_ops);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
class MemOptVarInfo;
} // namespace ir
namespace paddle2cinn { namespace paddle2cinn {
using framework::ir::Graph; using framework::ir::Graph;
...@@ -398,9 +393,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -398,9 +393,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
kNoNeedBufferFeeds, no_need_buffer_feeds.release()); kNoNeedBufferFeeds, no_need_buffer_feeds.release());
// initialize empty map for kMemOptVarInfoFromMainGraph attribute, // initialize empty map for kMemOptVarInfoFromMainGraph attribute,
// it will be filled on the share_mem_opt_info_to_subgraph pass // it will be filled on the share_mem_opt_info_to_subgraph pass
subgraph->GetOrInit<std::unordered_map< subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
std::string, std::shared_ptr<framework::ir::MemOptVarInfo>>>(
kMemOptVarInfoFromMainGraph);
return subgraph; return subgraph;
} }
......
...@@ -18,6 +18,10 @@ limitations under the License. */ ...@@ -18,6 +18,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
class MemOptVarInfo;
} // namespace ir
namespace paddle2cinn { namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "cinn_launch"; constexpr char kCinnLaunchOp[] = "cinn_launch";
...@@ -27,6 +31,9 @@ constexpr char kInternalVars[] = "InternalVars"; ...@@ -27,6 +31,9 @@ constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars"; constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] = constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph"; "mem_opt_var_info_from_main_graph";
using Name2VarInfoMap =
std::unordered_map<std::string,
std::shared_ptr<framework::ir::MemOptVarInfo>>;
// A pass named BuildCinnPass, the function of this pass is: // A pass named BuildCinnPass, the function of this pass is:
// //
......
...@@ -255,7 +255,9 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -255,7 +255,9 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_EQ( ASSERT_EQ(
std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()), std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
std::unordered_set<Node*>({v0, v1, v2, v4})); std::unordered_set<Node*>({v0, v1, v2, v4}));
ASSERT_EQ(cinn_op->outputs, std::vector<Node*>({v6, v7})); ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
cinn_op->outputs.end()),
std::unordered_set<Node*>({v6, v7}));
ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op})); ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op})); ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));
......
...@@ -248,10 +248,10 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph( ...@@ -248,10 +248,10 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
*compiled_obj = {std::move(graph_compiler), *compiled_obj = {std::move(graph_compiler),
std::move(compiled_res.runtime_program), scope, std::move(compiled_res.runtime_program), scope,
symbol.var_model_to_program_map()}; symbol.var_model_to_program_map()};
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(
compiled_obj->paddle2cinn_varmap, compiled_obj->scope);
compiled_obj->cached_index = compiled_num; compiled_obj->cached_index = compiled_num;
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(graph,
*compiled_obj);
return compiled_obj; return compiled_obj;
} }
......
...@@ -209,7 +209,7 @@ class CinnGraphSymbolizationTest : public ::testing::Test { ...@@ -209,7 +209,7 @@ class CinnGraphSymbolizationTest : public ::testing::Test {
tensor.Resize(dims); tensor.Resize(dims);
tensor.mutable_data( tensor.mutable_data(
platform::CPUPlace(), platform::CPUPlace(),
framework::TransToPtenDataType(framework::proto::VarType::FP32)); framework::TransToPhiDataType(framework::proto::VarType::FP32));
return tensor; return tensor;
}; };
#define FillFeedList(Name) feed_targets[#Name] = create_tensor(); #define FillFeedList(Name) feed_targets[#Name] = create_tensor();
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <sstream> #include <sstream>
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -57,17 +57,16 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { ...@@ -57,17 +57,16 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
paddle::SmallVector<std::string> attr_names_; paddle::SmallVector<std::string> attr_names_;
}; };
OpKernelType TransPtenKernelKeyToOpKernelType( OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
const phi::KernelKey& kernel_key) {
proto::VarType::Type data_type = proto::VarType::Type data_type =
paddle::framework::TransToProtoVarType(kernel_key.dtype()); paddle::framework::TransToProtoVarType(kernel_key.dtype());
// no need to set current device id here // no need to set current device id here
platform::Place place = phi::TransToPtenPlace(kernel_key.backend(), false); platform::Place place = phi::TransToPhiPlace(kernel_key.backend(), false);
DataLayout data_layout = kernel_key.layout(); DataLayout data_layout = kernel_key.layout();
LibraryType library_type = LibraryType::kPlain; LibraryType library_type = LibraryType::kPlain;
if (kernel_key.backend() == phi::Backend::MKLDNN) { if (kernel_key.backend() == phi::Backend::MKLDNN) {
library_type = LibraryType::kMKLDNN; library_type = LibraryType::kMKLDNN;
} else if (kernel_key.backend() == phi::Backend::CUDNN) { } else if (kernel_key.backend() == phi::Backend::GPUDNN) {
library_type = LibraryType::kCUDNN; library_type = LibraryType::kCUDNN;
} else { } else {
// do nothing // do nothing
...@@ -76,19 +75,19 @@ OpKernelType TransPtenKernelKeyToOpKernelType( ...@@ -76,19 +75,19 @@ OpKernelType TransPtenKernelKeyToOpKernelType(
return OpKernelType(data_type, place, data_layout, library_type); return OpKernelType(data_type, place, data_layout, library_type);
} }
phi::KernelKey TransOpKernelTypeToPtenKernelKey( phi::KernelKey TransOpKernelTypeToPhiKernelKey(
const OpKernelType& kernel_type) { const OpKernelType& kernel_type) {
phi::Backend backend = phi::TransToPtenBackend(kernel_type.place_); phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
if (kernel_type.library_type_ == LibraryType::kMKLDNN) { if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
backend = phi::Backend::MKLDNN; backend = phi::Backend::MKLDNN;
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) { } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
backend = phi::Backend::CUDNN; backend = phi::Backend::GPUDNN;
} else { } else {
// do nothing // do nothing
} }
paddle::experimental::DataLayout layout = kernel_type.data_layout_; paddle::experimental::DataLayout layout = kernel_type.data_layout_;
paddle::experimental::DataType dtype = paddle::experimental::DataType dtype =
paddle::framework::TransToPtenDataType(kernel_type.data_type_); paddle::framework::TransToPhiDataType(kernel_type.data_type_);
return phi::KernelKey(backend, layout, dtype); return phi::KernelKey(backend, layout, dtype);
} }
...@@ -98,8 +97,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -98,8 +97,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(expected_kernel_key.place_) || if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) { paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "pten missing XPU kernel: " << op.Type() VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << "phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -107,8 +106,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -107,8 +106,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#endif #endif
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) { if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing NPU kernel: " << op.Type() VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << "phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -116,8 +115,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -116,8 +115,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#endif #endif
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) { if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing MLU kernel: " << op.Type() VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << "phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -132,17 +131,17 @@ KernelArgsNameMakerByOpProto::GetInputArgsNames() { ...@@ -132,17 +131,17 @@ KernelArgsNameMakerByOpProto::GetInputArgsNames() {
auto& in = op_proto_->inputs()[i]; auto& in = op_proto_->inputs()[i];
auto& in_name = in.name(); auto& in_name = in.name();
if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) { if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
VLOG(6) << "Parse PtenKernel input: skip extra & quant input - " VLOG(6) << "Parse PhiKernel input: skip extra & quant input - "
<< in_name; << in_name;
continue; continue;
} }
// If an op contains a dispensable input, we should override the // If an op contains a dispensable input, we should override the
// OpArgumentMapping method ourselves in the phi/ops/compat dir // OpArgumentMapping method ourselves in the phi/ops/compat dir
if (in.has_dispensable() && in.dispensable()) { if (in.has_dispensable() && in.dispensable()) {
VLOG(6) << "Parse PtenKernel input: skip dispensable input - " << in_name; VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name;
continue; continue;
} }
VLOG(6) << "Parse PtenKernel input: " << in_name; VLOG(6) << "Parse PhiKernel input: " << in_name;
input_names_.emplace_back(in_name); input_names_.emplace_back(in_name);
} }
return input_names_; return input_names_;
...@@ -154,11 +153,11 @@ KernelArgsNameMakerByOpProto::GetOutputArgsNames() { ...@@ -154,11 +153,11 @@ KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
auto& out = op_proto_->outputs()[i]; auto& out = op_proto_->outputs()[i];
auto& out_name = out.name(); auto& out_name = out.name();
if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) { if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
VLOG(6) << "Parse PtenKernel output: skip extra & quant output - " VLOG(6) << "Parse PhiKernel output: skip extra & quant output - "
<< out_name; << out_name;
continue; continue;
} }
VLOG(6) << "Parse PtenKernel output: " << out_name; VLOG(6) << "Parse PhiKernel output: " << out_name;
output_names_.emplace_back(out_name); output_names_.emplace_back(out_name);
} }
return output_names_; return output_names_;
...@@ -173,17 +172,17 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -173,17 +172,17 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
attr_name == "op_role" || attr_name == "op_role_var" || attr_name == "op_role" || attr_name == "op_role_var" ||
attr_name == "op_namescope" || attr_name == "op_callstack" || attr_name == "op_namescope" || attr_name == "op_callstack" ||
attr_name == "op_device") { attr_name == "op_device") {
VLOG(6) << "Parse PtenKernel attribute: skip needless attr - " VLOG(6) << "Parse PhiKernel attribute: skip needless attr - "
<< attr_name; << attr_name;
continue; continue;
} }
if ((attr.has_extra() && attr.extra()) || if ((attr.has_extra() && attr.extra()) ||
(attr.has_quant() && attr.quant())) { (attr.has_quant() && attr.quant())) {
VLOG(6) << "Parse PtenKernel attribute: skip extra & quant attr - " VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - "
<< attr_name; << attr_name;
continue; continue;
} }
VLOG(6) << "Parse PtenKernel attribute: " << attr_name; VLOG(6) << "Parse PhiKernel attribute: " << attr_name;
attr_names_.emplace_back(attr_name); attr_names_.emplace_back(attr_name);
} }
...@@ -191,7 +190,7 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -191,7 +190,7 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
} }
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(phi::TransToPtenKernelName(op_proto_->type()), return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()),
GetInputArgsNames(), GetAttrsArgsNames(), GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames()); GetOutputArgsNames());
} }
...@@ -203,7 +202,7 @@ void InitDefaultKernelSignatureMap() { ...@@ -203,7 +202,7 @@ void InitDefaultKernelSignatureMap() {
for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) { for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
const auto& op_type = pair.first; const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_; const auto* op_proto = pair.second.proto_;
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) && if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
op_proto) { op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto); paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type; VLOG(10) << "Register kernel signature for " << op_type;
......
...@@ -44,9 +44,8 @@ using KernelSignature = phi::KernelSignature; ...@@ -44,9 +44,8 @@ using KernelSignature = phi::KernelSignature;
/* Kernel Key translate */ /* Kernel Key translate */
OpKernelType TransPtenKernelKeyToOpKernelType(const phi::KernelKey& kernel_key); OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
phi::KernelKey TransOpKernelTypeToPtenKernelKey( phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type);
const OpKernelType& kernel_type);
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
const phi::KernelKey& kernel_key, const phi::KernelKey& kernel_key,
const framework::OperatorBase& op); const framework::OperatorBase& op);
...@@ -68,25 +67,25 @@ void SetAllocationForOutputTenosr(phi::TensorBase* tensor, ...@@ -68,25 +67,25 @@ void SetAllocationForOutputTenosr(phi::TensorBase* tensor,
// TODO(Wilber): support others device context. // TODO(Wilber): support others device context.
template <typename T> template <typename T>
struct ConvertToPtenContext { struct ConvertToPhiContext {
using TYPE = T; using TYPE = T;
}; };
template <> template <>
struct ConvertToPtenContext<platform::CPUDeviceContext> { struct ConvertToPhiContext<platform::CPUDeviceContext> {
using TYPE = phi::CPUContext; using TYPE = phi::CPUContext;
}; };
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <> template <>
struct ConvertToPtenContext<platform::CUDADeviceContext> { struct ConvertToPhiContext<platform::CUDADeviceContext> {
using TYPE = phi::GPUContext; using TYPE = phi::GPUContext;
}; };
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
template <> template <>
struct ConvertToPtenContext<platform::XPUDeviceContext> { struct ConvertToPhiContext<platform::XPUDeviceContext> {
using TYPE = phi::XPUContext; using TYPE = phi::XPUContext;
}; };
#endif #endif
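A minimal sketch of how this trait is typically consumed; AsPhiContext is a hypothetical helper, and the cast assumes (as the fluid/phi unification does) that each fluid device context is layout-compatible with the phi context the trait maps it to:

// Convert a fluid device-context reference into the matching phi context.
template <typename DeviceContext>
const typename ConvertToPhiContext<DeviceContext>::TYPE& AsPhiContext(
    const DeviceContext& ctx) {
  // e.g. platform::CPUDeviceContext -> phi::CPUContext (assumption above).
  return static_cast<const typename ConvertToPhiContext<DeviceContext>::TYPE&>(
      ctx);
}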
......
...@@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) { TEST(PhiUtils, TransPhiKernelKeyToOpKernelType) {
phi::KernelKey kernel_key(phi::Backend::CPU, phi::DataLayout::NCHW, phi::KernelKey kernel_key(phi::Backend::CPU, phi::DataLayout::NCHW,
phi::DataType::FLOAT32); phi::DataType::FLOAT32);
auto op_kernel_type = auto op_kernel_type =
paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key); paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32); ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW); ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW);
ASSERT_TRUE(paddle::platform::is_cpu_place(op_kernel_type.place_)); ASSERT_TRUE(paddle::platform::is_cpu_place(op_kernel_type.place_));
...@@ -33,7 +33,7 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) { ...@@ -33,7 +33,7 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
phi::KernelKey kernel_key_mkldnn(phi::Backend::MKLDNN, phi::DataLayout::NCHW, phi::KernelKey kernel_key_mkldnn(phi::Backend::MKLDNN, phi::DataLayout::NCHW,
phi::DataType::FLOAT32); phi::DataType::FLOAT32);
op_kernel_type = op_kernel_type =
paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key_mkldnn); paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key_mkldnn);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32); ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW); ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW);
ASSERT_TRUE(paddle::platform::is_cpu_place(op_kernel_type.place_)); ASSERT_TRUE(paddle::platform::is_cpu_place(op_kernel_type.place_));
...@@ -42,10 +42,10 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) { ...@@ -42,10 +42,10 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
#endif #endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
phi::KernelKey kernel_key_cudnn(phi::Backend::CUDNN, phi::DataLayout::NCHW, phi::KernelKey kernel_key_cudnn(phi::Backend::GPUDNN, phi::DataLayout::NCHW,
phi::DataType::FLOAT32); phi::DataType::FLOAT32);
op_kernel_type = op_kernel_type =
paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key_cudnn); paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key_cudnn);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32); ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW); ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW);
ASSERT_TRUE(paddle::platform::is_gpu_place(op_kernel_type.place_)); ASSERT_TRUE(paddle::platform::is_gpu_place(op_kernel_type.place_));
...@@ -53,3 +53,38 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) { ...@@ -53,3 +53,38 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
paddle::framework::LibraryType::kCUDNN); paddle::framework::LibraryType::kCUDNN);
#endif #endif
} }
TEST(PhiUtils, TransOpKernelTypeToPhiKernelKey) {
paddle::framework::OpKernelType op_kernel_type(
paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
paddle::framework::DataLayout::kNCHW);
auto kernel_key =
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type);
ASSERT_EQ(kernel_key.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key.layout(), phi::DataLayout::NCHW);
ASSERT_EQ(kernel_key.backend(), phi::Backend::CPU);
#ifdef PADDLE_WITH_MKLDNN
paddle::framework::OpKernelType op_kernel_type_mkldnn(
paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
paddle::framework::DataLayout::kMKLDNN,
paddle::framework::LibraryType::kMKLDNN);
auto kernel_key_mkldnn =
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn);
ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::MKLDNN);
ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::MKLDNN);
#endif
#ifdef PADDLE_WITH_CUDA
paddle::framework::OpKernelType op_kernel_type_cudnn(
paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
paddle::framework::DataLayout::kNCHW,
paddle::framework::LibraryType::kCUDNN);
auto kernel_key_cudnn =
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_cudnn);
ASSERT_EQ(kernel_key_cudnn.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key_cudnn.layout(), phi::DataLayout::NCHW);
ASSERT_EQ(kernel_key_cudnn.backend(), phi::Backend::GPUDNN);
#endif
}
...@@ -1457,7 +1457,7 @@ std::ostream& print_tensor<paddle::platform::complex<double>>( ...@@ -1457,7 +1457,7 @@ std::ostream& print_tensor<paddle::platform::complex<double>>(
std::ostream& operator<<(std::ostream& os, const LoD& lod) { std::ostream& operator<<(std::ostream& os, const LoD& lod) {
// NOTE(xiongkun): // NOTE(xiongkun):
// https://stackoverflow.com/questions/5195512/namespaces-and-operator-resolution // https://stackoverflow.com/questions/5195512/namespaces-and-operator-resolution
// if we don't redefine, the operator << of pten / framework LoD is not found. // if we don't redefine, the operator << of phi / framework LoD is not found.
paddle::string::operator<<(os, lod); paddle::string::operator<<(os, lod);
return os; return os;
} }
......
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags) cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
cc_library(var_helper SRCS var_helper.cc DEPS tensor pten_api) cc_library(var_helper SRCS var_helper.cc DEPS tensor phi_api)
IF(WITH_XPU) IF(WITH_XPU)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten_api pten pten_utils var_helper) cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils phi_api phi phi_utils var_helper)
ELSE() ELSE()
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten_api pten pten_utils var_helper) cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils phi_api phi phi_utils var_helper)
ENDIF() ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry var_helper pten_api) cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry var_helper phi_api)
add_subdirectory(jit) add_subdirectory(jit)
cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper) cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper) cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper)
...@@ -47,9 +47,9 @@ if(WITH_GLOO) ...@@ -47,9 +47,9 @@ if(WITH_GLOO)
endif() endif()
if(NOT WITH_ASCEND_CL) if(NOT WITH_ASCEND_CL)
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function pten_tensor) cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function phi_tensor)
else() else()
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function npu_op_runner pten_tensor) cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function npu_op_runner phi_tensor)
endif() endif()
add_subdirectory(tests) add_subdirectory(tests)
...@@ -70,12 +70,12 @@ OpSupportedInfos(const std::string& place, ...@@ -70,12 +70,12 @@ OpSupportedInfos(const std::string& place,
} }
} }
auto pten_kernels = phi::KernelFactory::Instance().kernels(); auto phi_kernels = phi::KernelFactory::Instance().kernels();
for (auto& kernel_pair : pten_kernels) { for (auto& kernel_pair : phi_kernels) {
auto op_type = phi::TransToFluidOpName(kernel_pair.first); auto op_type = phi::TransToFluidOpName(kernel_pair.first);
for (auto& info_pair : kernel_pair.second) { for (auto& info_pair : kernel_pair.second) {
framework::OpKernelType kernel_type = framework::OpKernelType kernel_type =
framework::TransPtenKernelKeyToOpKernelType(info_pair.first); framework::TransPhiKernelKeyToOpKernelType(info_pair.first);
if (is_target_place[query_place](kernel_type.place_) && if (is_target_place[query_place](kernel_type.place_) &&
kernel_type.data_type_ == dtype && all_ops.count(op_type)) { kernel_type.data_type_ == dtype && all_ops.count(op_type)) {
VLOG(4) << op_type << " " << supported_ops.size(); VLOG(4) << op_type << " " << supported_ops.size();
...@@ -273,8 +273,9 @@ static inline std::shared_ptr<VarType> CastToBF16( ...@@ -273,8 +273,9 @@ static inline std::shared_ptr<VarType> CastToBF16(
template <typename VarType> template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType( static inline framework::proto::VarType::Type GetPromoteType(
const std::string& op_type, const NameVarMap<VarType>& ins) { const std::string& op_type, const NameVarMap<VarType>& ins,
auto dst_type = framework::proto::VarType::FP16; const framework::proto::VarType::Type amp_dtype) {
auto dst_type = amp_dtype;
for (const auto& pair : ins) { for (const auto& pair : ins) {
for (const auto& var : pair.second) { for (const auto& var : pair.second) {
if (GetDataType<VarType>(var) == framework::proto::VarType::FP32) { if (GetDataType<VarType>(var) == framework::proto::VarType::FP32) {
...@@ -337,7 +338,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, ...@@ -337,7 +338,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
} }
return new_ins; return new_ins;
} else { } else {
auto dst_type = GetPromoteType<VarType>(op_type, ins); auto dst_type =
GetPromoteType<VarType>(op_type, ins, framework::proto::VarType::FP16);
// NOTE(zhiqiu): if the op has no fp16 kernel, fall back to fp32. // NOTE(zhiqiu): if the op has no fp16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::FP16 && if (dst_type == framework::proto::VarType::FP16 &&
...@@ -435,7 +437,7 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type, ...@@ -435,7 +437,7 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
} }
} }
return new_ins; return new_ins;
} else { } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
for (auto& pair : new_ins) { for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float"; << GetDtypeStr(*pair.second.cbegin()) << " to float";
...@@ -444,6 +446,26 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type, ...@@ -444,6 +446,26 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
} }
} }
return new_ins; return new_ins;
} else {
auto dst_type =
GetPromoteType<VarType>(op_type, ins, framework::proto::VarType::BF16);
// NOTE(zhangbo): if the op has no bf16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::BF16 &&
AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(
op_type)) {
dst_type = framework::proto::VarType::FP32;
}
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32
? CastToFP32<VarType>(var)
: CastToBF16<VarType>(var));
}
}
return new_ins;
} }
return new_ins; return new_ins;
} }
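The change above threads the target AMP dtype into GetPromoteType instead of hard-coding FP16, letting the new BF16 branch reuse the same rule. A standalone sketch of that rule; PromoteDtype is a hypothetical free function mirroring the loop shown in the diff:

// Start from the requested low-precision dtype (FP16 or BF16) and widen to
// FP32 as soon as any input is FP32.
framework::proto::VarType::Type PromoteDtype(
    const std::vector<framework::proto::VarType::Type>& input_dtypes,
    framework::proto::VarType::Type amp_dtype) {
  auto dst_type = amp_dtype;
  for (auto dtype : input_dtypes) {
    if (dtype == framework::proto::VarType::FP32) {
      dst_type = dtype;
      break;
    }
  }
  return dst_type;
}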
......
...@@ -154,7 +154,7 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) { ...@@ -154,7 +154,7 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
// Here, we use the type of the corresponding forward datatype. // Here, we use the type of the corresponding forward datatype.
tensor->mutable_data( tensor->mutable_data(
op.place(), framework::TransToPtenDataType(var->ForwardDataType())); op.place(), framework::TransToPhiDataType(var->ForwardDataType()));
VLOG(6) << "Set ungenerated Grad: " << var->Name() VLOG(6) << "Set ungenerated Grad: " << var->Name()
<< " as zero with dtype " << " as zero with dtype "
<< framework::DataTypeToString(var->ForwardDataType()); << framework::DataTypeToString(var->ForwardDataType());
......
...@@ -791,13 +791,13 @@ void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, ...@@ -791,13 +791,13 @@ void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
<< var->Var().Get<framework::LoDTensor>().dims(); << var->Var().Get<framework::LoDTensor>().dims();
tensor->Resize(var->Var().Get<framework::LoDTensor>().dims()); tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType())); framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0); phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
} else { } else {
auto* tensor = auto* tensor =
dst_var->MutableVar()->GetMutable<framework::LoDTensor>(); dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType())); framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0); phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
} }
} }
...@@ -925,13 +925,13 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, ...@@ -925,13 +925,13 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
<< var->Var().Get<framework::LoDTensor>().dims(); << var->Var().Get<framework::LoDTensor>().dims();
tensor->Resize(var->Var().Get<framework::LoDTensor>().dims()); tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType())); framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0); phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
} else { } else {
auto* tensor = auto* tensor =
dst_var->MutableVar()->GetMutable<framework::LoDTensor>(); dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
tensor->mutable_data(place, tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType())); framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0); phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
} }
} }
......
...@@ -314,10 +314,10 @@ static void FillConstantLike(const VariableWrapper &ref_var, ...@@ -314,10 +314,10 @@ static void FillConstantLike(const VariableWrapper &ref_var,
// default data_type for now. // default data_type for now.
if (ref_var.ForwardDataType() != -1) { if (ref_var.ForwardDataType() != -1) {
dst_tensor->mutable_data( dst_tensor->mutable_data(
place, framework::TransToPtenDataType(ref_var.ForwardDataType())); place, framework::TransToPhiDataType(ref_var.ForwardDataType()));
} else { } else {
dst_tensor->mutable_data( dst_tensor->mutable_data(place,
place, framework::TransToPtenDataType(ref_var.DataType())); framework::TransToPhiDataType(ref_var.DataType()));
} }
phi::funcs::set_constant(*dev_ctx, dst_tensor, value); phi::funcs::set_constant(*dev_ctx, dst_tensor, value);
} }
......
...@@ -121,7 +121,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -121,7 +121,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
kernel_type_(kernel_type), kernel_type_(kernel_type),
func_(nullptr), func_(nullptr),
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx),
run_pten_kernel_(true), run_phi_kernel_(true),
pt_kernel_signature_(kernel_signature), pt_kernel_signature_(kernel_signature),
pt_kernel_(pt_kernel) {} pt_kernel_(pt_kernel) {}
...@@ -151,7 +151,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -151,7 +151,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
// NOTE(zhiqiu): for kernels on given device, for example NPU, the order to // NOTE(zhiqiu): for kernels on given device, for example NPU, the order to
// choose is: // choose is:
// pten npu kernel > fluid npu kernel > pten cpu kernel > fluid cpu kernel // phi npu kernel > fluid npu kernel > phi cpu kernel > fluid cpu kernel
// 1. get expected kernel key // 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>( auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
...@@ -168,12 +168,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -168,12 +168,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
expected_kernel_key) || expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type()); paddle::platform::is_in_xpu_black_list(op.Type());
#endif #endif
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx); pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx);
VLOG(6) << pt_kernel_signature; VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = pt_kernel_signature.name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key); pt_kernel_key);
...@@ -195,7 +195,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -195,7 +195,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx); pt_kernel, dev_ctx);
} else { } else {
VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found."; << "` not found.";
} }
} }
...@@ -211,7 +211,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -211,7 +211,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
|| is_xpu_unsupport || is_xpu_unsupport
#endif #endif
) { ) {
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key = auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op); FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel( auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
...@@ -423,12 +423,12 @@ static void PreparedOpRunPtImpl( ...@@ -423,12 +423,12 @@ static void PreparedOpRunPtImpl(
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
PreparePtenData<VarType>(pt_kernel, pt_kernel_signature, ins); PreparePhiData<VarType>(pt_kernel, pt_kernel_signature, ins);
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins, BuildDygraphPhiKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
outs, attrs, default_attrs, dev_ctx, outs, attrs, default_attrs, dev_ctx,
&pt_kernel_context); &pt_kernel_context);
pt_kernel(&pt_kernel_context); pt_kernel(&pt_kernel_context);
} }
...@@ -451,7 +451,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins, ...@@ -451,7 +451,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) { if (run_phi_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_, PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, dev_ctx_, ins, outs, attrs, pt_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs); default_attrs);
...@@ -465,7 +465,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -465,7 +465,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) { if (run_phi_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>( PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs); outs, attrs, default_attrs);
...@@ -479,7 +479,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins, ...@@ -479,7 +479,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerVariable>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) { if (run_phi_kernel_) {
PreparedOpRunPtImpl<egr::EagerVariable>( PreparedOpRunPtImpl<egr::EagerVariable>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs); outs, attrs, default_attrs);
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
...@@ -201,9 +201,9 @@ class PreparedOp { ...@@ -201,9 +201,9 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelFunc func_; framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_; platform::DeviceContext* dev_ctx_;
// NOTE(chenweihang): Similar op members are used to adapt to // NOTE(chenweihang): Similar op members are used to adapt to
// new pten kernel, if there is a better design in the future, // new phi kernel, if there is a better design in the future,
// we may polish the implementation here // we may polish the implementation here
bool run_pten_kernel_{false}; bool run_phi_kernel_{false};
bool run_kp_kernel_{false}; bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_; framework::KernelSignature pt_kernel_signature_;
phi::Kernel pt_kernel_; phi::Kernel pt_kernel_;
...@@ -225,7 +225,7 @@ const inline framework::Attribute& GetAttr( ...@@ -225,7 +225,7 @@ const inline framework::Attribute& GetAttr(
} }
template <typename VarType> template <typename VarType>
void BuildDygraphPtenKernelContext( void BuildDygraphPhiKernelContext(
const framework::KernelSignature& pt_kernel_signature, const framework::KernelSignature& pt_kernel_signature,
const phi::Kernel& pt_kernel, const NameVarMap<VarType>& ins, const phi::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
...@@ -327,7 +327,7 @@ void BuildDygraphPtenKernelContext( ...@@ -327,7 +327,7 @@ void BuildDygraphPtenKernelContext(
experimental::ResetTensorDtypeAndLayoutByArgDef(tensor_out, experimental::ResetTensorDtypeAndLayoutByArgDef(tensor_out,
output_defs.at(i)); output_defs.at(i));
framework::SetAllocationForOutputTenosr( framework::SetAllocationForOutputTenosr(
tensor_out, phi::TransToPtenPlace(output_defs.at(i).backend)); tensor_out, phi::TransToPhiPlace(output_defs.at(i).backend));
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
...@@ -369,7 +369,7 @@ void BuildDygraphPtenKernelContext( ...@@ -369,7 +369,7 @@ void BuildDygraphPtenKernelContext(
auto& ins_vector = ins.at(attr_names[i]); auto& ins_vector = ins.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor if (ins_vector.size() == 1) { // ShapeTensor
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var()))); experimental::MakePhiScalarArrayFromVar(ins_vector[0]->Var())));
} else { // ShapeTensorList } else { // ShapeTensorList
std::vector<framework::Variable*> variables; std::vector<framework::Variable*> variables;
variables.reserve(ins_vector.size()); variables.reserve(ins_vector.size());
...@@ -377,7 +377,7 @@ void BuildDygraphPtenKernelContext( ...@@ -377,7 +377,7 @@ void BuildDygraphPtenKernelContext(
variables.push_back(var_base->MutableVar()); variables.push_back(var_base->MutableVar());
} }
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(variables))); experimental::MakePhiScalarArrayFromVarList(variables)));
} }
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
...@@ -409,7 +409,7 @@ void BuildDygraphPtenKernelContext( ...@@ -409,7 +409,7 @@ void BuildDygraphPtenKernelContext(
} else { // scalar is in the input } else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]); auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(ins_vector[0]->Var()))); experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
} }
} else { } else {
...@@ -428,7 +428,7 @@ void BuildDygraphPtenKernelContext( ...@@ -428,7 +428,7 @@ void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) { std::type_index(typeid(phi::DataType))) {
auto data_type = framework::TransToPtenDataType( auto data_type = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
kernel_ctx->EmplaceBackAttr(data_type); kernel_ctx->EmplaceBackAttr(data_type);
...@@ -436,7 +436,7 @@ void BuildDygraphPtenKernelContext( ...@@ -436,7 +436,7 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) == if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) { std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end()); vector_int_attr.end());
...@@ -456,9 +456,9 @@ void BuildDygraphPtenKernelContext( ...@@ -456,9 +456,9 @@ void BuildDygraphPtenKernelContext(
} }
template <typename VarType> template <typename VarType>
void PreparePtenData(const phi::Kernel& pt_kernel, void PreparePhiData(const phi::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature, const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) { const NameVarMap<VarType>& ins) {
auto& input_names = std::get<0>(pt_kernel_signature.args); auto& input_names = std::get<0>(pt_kernel_signature.args);
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = pt_kernel.args_def().input_defs();
...@@ -482,12 +482,12 @@ void PreparePtenData(const phi::Kernel& pt_kernel, ...@@ -482,12 +482,12 @@ void PreparePtenData(const phi::Kernel& pt_kernel,
if (in_def.backend == phi::Backend::ALL_BACKEND) { if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue; continue;
} }
auto expected_place = phi::TransToPtenPlace(in_def.backend); auto expected_place = phi::TransToPhiPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) { if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue; continue;
} }
VLOG(3) << "Pten Transform Variable " << input_names[i] << " from " VLOG(3) << "Phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place; << tensor_in->place() << " to " << expected_place;
framework::Tensor tmp_tensor; framework::Tensor tmp_tensor;
......
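The renames in the hunks above are mechanical (Pten -> Phi), but the dispatch pattern they sit in is worth spelling out: BuildDygraphPhiKernelContext matches each kernel-declared argument type against the attribute actually stored on the op via std::type_index and converts where the representations differ. A minimal standalone sketch of that dispatch, using std::any in place of Paddle's framework::Attribute (an assumption made purely for illustration):

#include <any>
#include <cstdint>
#include <iostream>
#include <typeindex>
#include <vector>

// Mirrors the attr_defs[i].type_index comparison above: the kernel asks
// for std::vector<int64_t>, and a stored std::vector<int> is widened on
// the fly, exactly as done for phi kernel args.
std::vector<int64_t> ToKernelAttr(const std::any& attr,
                                  const std::type_index& expected) {
  if (expected == std::type_index(typeid(std::vector<int64_t>)) &&
      std::type_index(attr.type()) ==
          std::type_index(typeid(std::vector<int>))) {
    const auto& v = std::any_cast<const std::vector<int>&>(attr);
    return std::vector<int64_t>(v.begin(), v.end());
  }
  return std::any_cast<std::vector<int64_t>>(attr);
}

int main() {
  std::any attr = std::vector<int>{2, 3, 5};
  auto v = ToKernelAttr(attr, std::type_index(typeid(std::vector<int64_t>)));
  std::cout << v.size() << "\n";  // prints 3
  return 0;
}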
...@@ -446,7 +446,7 @@ void Reducer::InitializeGroups( ...@@ -446,7 +446,7 @@ void Reducer::InitializeGroups(
InitializeDenseGroups(variable_indices_, &group); InitializeDenseGroups(variable_indices_, &group);
auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>(); auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(phi::make_ddim({group.all_length_})) tensor->Resize(phi::make_ddim({group.all_length_}))
.mutable_data(place_, framework::TransToPtenDataType(group.dtype_)); .mutable_data(place_, framework::TransToPhiDataType(group.dtype_));
} }
// map variables to this group by VariableLocator // map variables to this group by VariableLocator
...@@ -738,7 +738,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { ...@@ -738,7 +738,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
if (!group_tensor.IsInitialized()) { if (!group_tensor.IsInitialized()) {
group_tensor.Resize({static_cast<int64_t>(length)}); group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(place_, group_tensor.mutable_data(place_,
framework::TransToPtenDataType(group.dtype_)); framework::TransToPhiDataType(group.dtype_));
} }
#ifdef PADDLE_WITH_XPU_BKCL #ifdef PADDLE_WITH_XPU_BKCL
......
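The reducer hunks above only swap TransToPtenDataType for TransToPhiDataType; the flat-buffer scheme is unchanged: every variable in a group is packed into one dense_contents_ tensor of all_length_ elements so a single collective covers the whole group. A toy sketch of the offset bookkeeping behind that (plain C++, none of Paddle's types):

#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Per-variable element counts in one group (assumed values).
  std::vector<size_t> lengths = {4, 2, 6};
  size_t all_length =
      std::accumulate(lengths.begin(), lengths.end(), size_t{0});
  std::vector<float> dense_contents(all_length, 0.f);  // one flat buffer
  // Each variable owns the slice [offset, offset + len) of the buffer;
  // concat copies into it, split views back out of it.
  size_t offset = 0;
  for (size_t i = 0; i < lengths.size(); ++i) {
    std::cout << "var " << i << " -> [" << offset << ", "
              << offset + lengths[i] << ")\n";
    offset += lengths[i];
  }
  return 0;
}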
...@@ -15,7 +15,7 @@ else() ...@@ -15,7 +15,7 @@ else()
endif(WIN32) endif(WIN32)
cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows_utils selected_rows_functor gradient_accumulator math_function pten_tensor pten_api pten_api_utils) cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows_utils selected_rows_functor gradient_accumulator math_function phi_tensor phi_api phi_api_utils)
cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy) cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place) cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
......
...@@ -96,7 +96,7 @@ void GroupConcatSplit(Place place, size_t size) { ...@@ -96,7 +96,7 @@ void GroupConcatSplit(Place place, size_t size) {
{ // concat { // concat
auto* tensor = group.dense_contents_.GetMutable<framework::LoDTensor>(); auto* tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(phi::make_ddim({group.all_length_})) tensor->Resize(phi::make_ddim({group.all_length_}))
.mutable_data(place, framework::TransToPtenDataType(group.dtype_)); .mutable_data(place, framework::TransToPhiDataType(group.dtype_));
group.ConcatTensors(*dev_ctx); group.ConcatTensors(*dev_ctx);
group.DivNRanks(*dev_ctx, 1); group.DivNRanks(*dev_ctx, 1);
......
...@@ -175,7 +175,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins, ...@@ -175,7 +175,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
paddle::framework::AttributeMap* passed_default_attrs_, paddle::framework::AttributeMap* passed_default_attrs_,
bool use_default_attr_map) { bool use_default_attr_map) {
platform::RecordEvent op_type_record_event( platform::RecordEvent op_type_record_event(
type, platform::TracerEventType::Operator, 2); type, platform::TracerEventType::Operator, 1);
platform::ScopedFlushDenormal flush; platform::ScopedFlushDenormal flush;
VLOG(1) << "Trace Op: " << type; VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) { if (FLAGS_use_mkldnn) {
...@@ -205,17 +205,19 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins, ...@@ -205,17 +205,19 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
NameVarMap<VarType> new_ins = ins; NameVarMap<VarType> new_ins = ins;
if (amp_level_ == AmpLevel::O1) { if (amp_level_ == AmpLevel::O1) {
VLOG(5) << "Auto mixed precision run operator: " << type;
if (amp_dtype_ == phi::DataType::FLOAT16) { if (amp_dtype_ == phi::DataType::FLOAT16) {
VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, ins); new_ins = AutoCastInputs<VarType>(type, ins);
} else if (amp_dtype_ == phi::DataType::BFLOAT16) { } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastBF16Inputs<VarType>(type, ins); new_ins = AutoCastBF16Inputs<VarType>(type, ins);
} }
} else if (amp_level_ == AmpLevel::O2) { } else if (amp_level_ == AmpLevel::O2) {
VLOG(5) << "Pure fp16 run operator: " << type;
if (amp_dtype_ == phi::DataType::FLOAT16) { if (amp_dtype_ == phi::DataType::FLOAT16) {
VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, ins); new_ins = CastPureFp16Inputs<VarType>(type, ins);
} else if (amp_dtype_ == phi::DataType::BFLOAT16) { } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureBf16Inputs<VarType>(type, ins); new_ins = CastPureBf16Inputs<VarType>(type, ins);
} }
} }
......
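The tracer hunk above also makes each VLOG line name both the dtype and the O-level. The selection itself is a two-step dispatch, AMP level first and AMP dtype second; a minimal sketch of that shape, where the returned strings stand in for the real AutoCastInputs / AutoCastBF16Inputs / CastPureFp16Inputs / CastPureBf16Inputs helpers:

#include <iostream>
#include <string>

enum class AmpLevel { O0, O1, O2 };
enum class AmpDtype { FLOAT16, BFLOAT16 };

// O1 auto-casts inputs op by op; O2 casts inputs to the low-precision
// dtype outright. Returns the name of the routine that would run.
std::string PickCastRoutine(AmpLevel level, AmpDtype dtype) {
  if (level == AmpLevel::O1) {
    return dtype == AmpDtype::FLOAT16 ? "AutoCastInputs"
                                      : "AutoCastBF16Inputs";
  }
  if (level == AmpLevel::O2) {
    return dtype == AmpDtype::FLOAT16 ? "CastPureFp16Inputs"
                                      : "CastPureBf16Inputs";
  }
  return "no cast";  // O0: run in the original dtype
}

int main() {
  std::cout << PickCastRoutine(AmpLevel::O2, AmpDtype::BFLOAT16) << "\n";
  // prints CastPureBf16Inputs, matching the branch added above
  return 0;
}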
...@@ -35,7 +35,7 @@ endif() ...@@ -35,7 +35,7 @@ endif()
# fluid_modules exclude API-interface of inference/api and inference/capi_exp # fluid_modules exclude API-interface of inference/api and inference/capi_exp
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES) get_property(phi_modules GLOBAL PROPERTY PHI_MODULES)
set(utils_modules stringpiece pretty_log string_helper) set(utils_modules stringpiece pretty_log string_helper)
add_subdirectory(api) add_subdirectory(api)
...@@ -47,11 +47,11 @@ set(STATIC_INFERENCE_API paddle_inference_api analysis_predictor ...@@ -47,11 +47,11 @@ set(STATIC_INFERENCE_API paddle_inference_api analysis_predictor
analysis_config paddle_pass_builder activation_functions ${mkldnn_quantizer_cfg}) analysis_config paddle_pass_builder activation_functions ${mkldnn_quantizer_cfg})
#TODO(wilber, T8T9): Do we still need to support windows gpu static library? #TODO(wilber, T8T9): Do we still need to support windows gpu static library?
if(WIN32 AND WITH_GPU) if(WIN32 AND WITH_GPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API} ${utils_modules}) cc_library(paddle_inference DEPS ${fluid_modules} ${phi_modules} ${STATIC_INFERENCE_API} ${utils_modules})
elseif(WITH_IPU) elseif(WITH_IPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API} ${utils_modules} paddle_ipu) cc_library(paddle_inference DEPS ${fluid_modules} ${phi_modules} ${STATIC_INFERENCE_API} ${utils_modules} paddle_ipu)
else() else()
create_static_lib(paddle_inference ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API} ${utils_modules}) create_static_lib(paddle_inference ${fluid_modules} ${phi_modules} ${STATIC_INFERENCE_API} ${utils_modules})
endif() endif()
if(NOT APPLE) if(NOT APPLE)
...@@ -81,7 +81,7 @@ set(SHARED_INFERENCE_SRCS ...@@ -81,7 +81,7 @@ set(SHARED_INFERENCE_SRCS
${PADDLE_CUSTOM_OP_SRCS}) ${PADDLE_CUSTOM_OP_SRCS})
# shared inference library deps # shared inference library deps
set(SHARED_INFERENCE_DEPS ${fluid_modules} ${pten_modules} analysis_predictor) set(SHARED_INFERENCE_DEPS ${fluid_modules} ${phi_modules} analysis_predictor)
if (WITH_CRYPTO) if (WITH_CRYPTO)
set(SHARED_INFERENCE_DEPS ${SHARED_INFERENCE_DEPS} paddle_crypto) set(SHARED_INFERENCE_DEPS ${SHARED_INFERENCE_DEPS} paddle_crypto)
......
...@@ -56,8 +56,10 @@ cc_test(test_paddle_inference_api SRCS api_tester.cc DEPS paddle_inference_api) ...@@ -56,8 +56,10 @@ cc_test(test_paddle_inference_api SRCS api_tester.cc DEPS paddle_inference_api)
if(WITH_TESTING) if(WITH_TESTING)
if (NOT APPLE AND NOT WIN32) if (NOT APPLE AND NOT WIN32)
inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS paddle_inference_shared if (WITH_GPU)
ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR}) inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS paddle_inference_shared
ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
endif()
elseif(WIN32) elseif(WIN32)
inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS ${inference_deps} inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS ${inference_deps}
ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR}) ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
......
...@@ -26,7 +26,7 @@ limitations under the License. */ ...@@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
// pten // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
DEFINE_string(devices, "", "The devices to be used which is joined by comma."); DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
......
...@@ -198,7 +198,7 @@ void InitDstTensor(framework::LoDTensor* dst, ...@@ -198,7 +198,7 @@ void InitDstTensor(framework::LoDTensor* dst,
const paddle::lite_api::Tensor& src) { const paddle::lite_api::Tensor& src) {
dst->mutable_data( dst->mutable_data(
inference::lite::utils::GetNativePlace(src.target()), inference::lite::utils::GetNativePlace(src.target()),
framework::TransToPtenDataType(GetNativePrecisionType(src.precision()))); framework::TransToPhiDataType(GetNativePrecisionType(src.precision())));
SetLoD(dst->mutable_lod(), src.lod()); SetLoD(dst->mutable_lod(), src.lod());
} }
...@@ -269,7 +269,7 @@ void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) { ...@@ -269,7 +269,7 @@ void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) {
SetLoD(dst->mutable_lod(), src->lod()); SetLoD(dst->mutable_lod(), src->lod());
dst->ResetHolderWithType( dst->ResetHolderWithType(
holder, holder,
framework::TransToPtenDataType(GetNativePrecisionType(src->precision()))); framework::TransToPhiDataType(GetNativePrecisionType(src->precision())));
} }
} // namespace utils } // namespace utils
......
...@@ -88,5 +88,5 @@ class SoftMaxOpConverter : public OpConverter { ...@@ -88,5 +88,5 @@ class SoftMaxOpConverter : public OpConverter {
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(softmax); USE_OP_ITSELF(softmax);
REGISTER_TRT_OP_CONVERTER(softmax, SoftMaxOpConverter); REGISTER_TRT_OP_CONVERTER(softmax, SoftMaxOpConverter);
...@@ -45,4 +45,4 @@ TEST(SoftMaxOpConverter, main) { ...@@ -45,4 +45,4 @@ TEST(SoftMaxOpConverter, main) {
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(softmax); USE_OP_ITSELF(softmax);
...@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType( ...@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template <typename T> template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input, __global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) { const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x * gridDim.y; const int hidden = blockDim.x * gridDim.x;
const int batch = blockIdx.x; const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;
const int local_idx = blockIdx.y * blockDim.y + threadIdx.x; const int batch_id = blockIdx.y;
output[batch * hidden + local_idx] = output[batch_id * hidden + hidden_id] =
slice_input[cu_seqlens[batch] * hidden + local_idx]; slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
} }
int SpecialSlicePluginDynamic::enqueue( int SpecialSlicePluginDynamic::enqueue(
...@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue( ...@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue(
"hidden should be multiple of 128.")); "hidden should be multiple of 128."));
constexpr int num_threads = 128; constexpr int num_threads = 128;
const dim3 blocks(out_dims.d[0], hidden / num_threads);
const half* slice_input = static_cast<const half*>(inputs[0]); const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]); const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]); half* output = static_cast<half*>(outputs[0]);
SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input, const int32_t num_blocks_x = hidden / num_threads;
cu_seqlens, output); const int32_t num_blocks_y = out_dims.d[0]; // batches
const dim3 num_blocks(num_blocks_x, num_blocks_y); // blocks
SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
slice_input, cu_seqlens, output);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
......
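The rewrite above fixes the kernel's index math: the old code derived hidden from blockDim.x * gridDim.y while launching blocks(batch, hidden / num_threads), and mixed blockIdx.y with threadIdx.x; the new launch puts the hidden dimension on (blockIdx.x, threadIdx.x) and the batch on blockIdx.y. A host-side sketch (sizes are assumptions; no CUDA required) that replays the corrected geometry and checks every element of the [batch, hidden] output slice is written exactly once:

#include <cassert>
#include <vector>

int main() {
  const int num_threads = 128;  // blockDim.x in the kernel above
  const int hidden = 768;       // must be a multiple of 128
  const int batch = 4;          // out_dims.d[0]
  const int grid_x = hidden / num_threads;  // num_blocks_x
  std::vector<int> hits(batch * hidden, 0);
  for (int block_y = 0; block_y < batch; ++block_y) {      // blockIdx.y
    for (int block_x = 0; block_x < grid_x; ++block_x) {   // blockIdx.x
      for (int tid = 0; tid < num_threads; ++tid) {        // threadIdx.x
        const int hidden_id = block_x * num_threads + tid;
        const int batch_id = block_y;
        ++hits[batch_id * hidden + hidden_id];
      }
    }
  }
  for (int h : hits) assert(h == 1);  // one-to-one coverage
  return 0;
}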
...@@ -299,7 +299,9 @@ inference_analysis_api_test(test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR} ...@@ -299,7 +299,9 @@ inference_analysis_api_test(test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR}
set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie") set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie")
download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" aa59192dd41ed377f9f168e3a1309fa6 "Ernie_data.txt.tar.gz" 5396e63548edad7ca561e7e26a9476d1) download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" aa59192dd41ed377f9f168e3a1309fa6 "Ernie_data.txt.tar.gz" 5396e63548edad7ca561e7e26a9476d1)
download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz" 73beea65abda2edb61c1662cd3180c62) download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz" 73beea65abda2edb61c1662cd3180c62)
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc) if (WITH_GPU)
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc)
endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR} analyzer_ernie_int8_tester.cc) inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR} analyzer_ernie_int8_tester.cc)
# Ernie large # Ernie large
...@@ -551,7 +553,9 @@ endif() ...@@ -551,7 +553,9 @@ endif()
# bert, max_len=20, embedding_dim=128 # bert, max_len=20, embedding_dim=128
set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert_emb128") set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert_emb128")
download_model_and_data_without_verify(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_data_len20.txt.tar.gz") download_model_and_data_without_verify(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_data_len20.txt.tar.gz")
inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc) if (WITH_GPU)
inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)
endif()
# multiple models prediction # multiple models prediction
set(MMP_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/multi_model_prediction") set(MMP_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/multi_model_prediction")
...@@ -741,13 +745,15 @@ set_tests_properties(lite_resnet50_test PROPERTIES TIMEOUT 120) ...@@ -741,13 +745,15 @@ set_tests_properties(lite_resnet50_test PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_mobilenet_transpose PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_mobilenet_transpose PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_resnet50 PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_resnet50 PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_ner PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_ner PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_ernie_int8 PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_ernie_int8 PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_googlenet PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_googlenet PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_small_dam PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_small_dam PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_transformer PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_transformer PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_bert PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_mobilenet_depthwise_conv PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_mobilenet_depthwise_conv PROPERTIES TIMEOUT 120)
if (WITH_GPU)
set_tests_properties(test_analyzer_bert PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120)
endif()
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 120) set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 120)
if(WITH_MKLDNN) if(WITH_MKLDNN)
......
...@@ -493,7 +493,8 @@ class AllocatorFacadePrivate { ...@@ -493,7 +493,8 @@ class AllocatorFacadePrivate {
"support allocating managed memory.\n" "support allocating managed memory.\n"
"If you don't actually need to use managed memory, please disable " "If you don't actually need to use managed memory, please disable "
"it with command `export FLAGS_use_cuda_managed_memory=false`.\n" "it with command `export FLAGS_use_cuda_managed_memory=false`.\n"
"Or you must use the gpu device that supports managed memory.")); "Or you must use the gpu device that supports managed memory.",
p.device));
} }
return std::make_shared<CUDAManagedAllocator>(p); return std::make_shared<CUDAManagedAllocator>(p);
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/aligned_allocator.h" #include "paddle/fluid/memory/allocation/aligned_allocator.h"
#include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
PADDLE_DEFINE_EXPORTED_READONLY_bool( PADDLE_DEFINE_EXPORTED_READONLY_bool(
free_idle_chunk, false, free_idle_chunk, false,
...@@ -47,6 +48,8 @@ AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator( ...@@ -47,6 +48,8 @@ AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator(
phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl( phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl(
size_t unaligned_size) { size_t unaligned_size) {
platform::RecordEvent("AutoGrowthBestFitAllocator::Allocate",
platform::TracerEventType::UserDefined, 9 /*level*/);
size_t size = AlignedSize(unaligned_size, alignment_); size_t size = AlignedSize(unaligned_size, alignment_);
VLOG(10) << "Allocate " << unaligned_size << " bytes, aligned to " << size; VLOG(10) << "Allocate " << unaligned_size << " bytes, aligned to " << size;
...@@ -108,6 +111,8 @@ phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl( ...@@ -108,6 +111,8 @@ phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl(
} }
void AutoGrowthBestFitAllocator::FreeImpl(phi::Allocation *allocation) { void AutoGrowthBestFitAllocator::FreeImpl(phi::Allocation *allocation) {
platform::RecordEvent("AutoGrowthBestFitAllocator::Free",
platform::TracerEventType::UserDefined, 9 /*level*/);
VLOG(10) << "Free " << allocation->size() VLOG(10) << "Free " << allocation->size()
<< " bytes, ptr = " << allocation->ptr(); << " bytes, ptr = " << allocation->ptr();
std::lock_guard<SpinLock> guard(spinlock_); std::lock_guard<SpinLock> guard(spinlock_);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
...@@ -117,6 +118,8 @@ StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() { ...@@ -117,6 +118,8 @@ StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() {
bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; } bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; }
phi::Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) { phi::Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) {
platform::RecordEvent("StreamSafeCUDAAllocator::Allocate",
platform::TracerEventType::UserDefined, 9 /*level*/);
ProcessUnfreedAllocations(); ProcessUnfreedAllocations();
VLOG(8) << "Try allocate " << size << " bytes"; VLOG(8) << "Try allocate " << size << " bytes";
AllocationPtr underlying_allocation; AllocationPtr underlying_allocation;
...@@ -144,6 +147,8 @@ phi::Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) { ...@@ -144,6 +147,8 @@ phi::Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) {
} }
void StreamSafeCUDAAllocator::FreeImpl(phi::Allocation* allocation) { void StreamSafeCUDAAllocator::FreeImpl(phi::Allocation* allocation) {
platform::RecordEvent("StreamSafeCUDAAllocator::Free",
platform::TracerEventType::UserDefined, 9 /*level*/);
StreamSafeCUDAAllocation* stream_safe_cuda_allocation = StreamSafeCUDAAllocation* stream_safe_cuda_allocation =
dynamic_cast<StreamSafeCUDAAllocation*>(allocation); dynamic_cast<StreamSafeCUDAAllocation*>(allocation);
PADDLE_ENFORCE_NOT_NULL(stream_safe_cuda_allocation, PADDLE_ENFORCE_NOT_NULL(stream_safe_cuda_allocation,
......
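The two allocator hunks above add paired RecordEvent markers (TracerEventType::UserDefined, level 9) on the hot Allocate/Free paths so allocation cost shows up in profiler traces. Assuming the intended semantics is a range covering the rest of the function body, this is the usual RAII scope-marker idiom; a generic sketch of the same idea (ScopedRange is hypothetical, not Paddle's profiler API):

#include <chrono>
#include <cstdio>

class ScopedRange {
 public:
  explicit ScopedRange(const char* name)
      : name_(name), start_(std::chrono::steady_clock::now()) {}
  ~ScopedRange() {  // range ends when the enclosing scope exits
    const auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                        std::chrono::steady_clock::now() - start_)
                        .count();
    std::printf("%s took %lld us\n", name_, static_cast<long long>(us));
  }

 private:
  const char* name_;
  std::chrono::steady_clock::time_point start_;
};

void AllocateImpl() {
  ScopedRange r("AutoGrowthBestFitAllocator::Allocate");
  // ... allocation work happens inside the recorded range ...
}

int main() {
  AllocateImpl();
  return 0;
}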
...@@ -128,6 +128,9 @@ TEST(ManagedMemoryTest, OversubscribeGPUMemoryTest) { ...@@ -128,6 +128,9 @@ TEST(ManagedMemoryTest, OversubscribeGPUMemoryTest) {
} }
TEST(ManagedMemoryTest, OOMExceptionTest) { TEST(ManagedMemoryTest, OOMExceptionTest) {
if (!platform::IsGPUManagedMemorySupported(0)) {
return;
}
EXPECT_THROW(Alloc(platform::CUDAPlace(0), size_t(1) << 60), EXPECT_THROW(Alloc(platform::CUDAPlace(0), size_t(1) << 60),
memory::allocation::BadAlloc); memory::allocation::BadAlloc);
} }
......
...@@ -100,7 +100,7 @@ else() ...@@ -100,7 +100,7 @@ else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif() endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils gather_scatter_kernel) set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
......
...@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/addmm_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -24,6 +27,8 @@ limitations under the License. */ ...@@ -24,6 +27,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr int kMULMKLDNNINT8 = 1;
using framework::OpKernelType; using framework::OpKernelType;
using framework::Tensor; using framework::Tensor;
...@@ -31,85 +36,6 @@ class AddMMOp : public framework::OperatorWithKernel { ...@@ -31,85 +36,6 @@ class AddMMOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
platform::errors::NotFound(
"Input(Input) of AddMMOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of AddMMOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of AddMMOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of AddMMOp should not be null."));
auto input_dims = ctx->GetInputDim("Input");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto ndim_input = input_dims.size();
auto ndim_x = x_dims.size();
auto ndim_y = y_dims.size();
float alpha = ctx->Attrs().Get<float>("Alpha");
float beta = ctx->Attrs().Get<float>("Beta");
VLOG(3) << "addmm operator input.shape=" << input_dims
<< " x.shape=" << x_dims << " y.shape=" << y_dims
<< " beta=" << beta << " alpha=" << alpha
<< " ndim_input=" << ndim_input << " ndim_x=" << ndim_x
<< " ndim_y=" << ndim_y;
PADDLE_ENFORCE_NE(phi::product(input_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable Input(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("Input").front()));
PADDLE_ENFORCE_NE(phi::product(x_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable X(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("X").front()));
PADDLE_ENFORCE_NE(phi::product(y_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable Y(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("Y").front()));
// dim check
PADDLE_ENFORCE_EQ(ndim_input, 2,
platform::errors::InvalidArgument(
"The input tensor input's dimension must be 2. "
"But received input's dimension = [%s].",
ndim_input));
PADDLE_ENFORCE_EQ(ndim_x, 2,
platform::errors::InvalidArgument(
"The input tensor x's dimension must be 2. "
"But received x's dimension = [%s].",
ndim_x));
PADDLE_ENFORCE_EQ(ndim_y, 2,
platform::errors::InvalidArgument(
"The input tensor y's dimension must be 2. "
"But received y's dimension = [%s].",
ndim_y));
std::vector<int64_t> output_dims;
output_dims.push_back(x_dims[0]);
output_dims.push_back(y_dims[1]);
ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
ctx->ShareLoD("Input", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain; framework::LibraryType library = framework::LibraryType::kPlain;
...@@ -221,17 +147,11 @@ class AddMMOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -221,17 +147,11 @@ class AddMMOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(addmm, AddmmInferShapeFunctor,
PT_INFER_META(phi::AddmmInferMeta));
REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker, REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker,
ops::AddMMOpGradMaker<paddle::framework::OpDesc>, ops::AddMMOpGradMaker<paddle::framework::OpDesc>,
ops::AddMMOpGradMaker<paddle::imperative::OpBase>); ops::AddMMOpGradMaker<paddle::imperative::OpBase>,
AddmmInferShapeFunctor);
REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp); REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp);
REGISTER_OP_CPU_KERNEL(
addmm, ops::AddMMKernel<paddle::platform::CPUDeviceContext, float>,
ops::AddMMKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
addmm_grad, ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, double>);
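Both operator rewrites in this commit follow the same recipe: the hand-written InferShape above is deleted and the operator is registered with a functor generated from the phi InferMeta function (via DELCARE_INFER_SHAPE_FUNCTOR / PT_INFER_META, spelled as in the source). Shape-wise, an InferMeta for addmm only has to reproduce the dimension logic that was removed; a hedged sketch of that core (not phi's actual signature):

#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>

// With x = [M, K] and y = [K, N], out is [M, N]; Input must be 2-D and
// broadcastable to it, which the removed checks enforced.
std::array<int64_t, 2> AddmmInferShape(const std::array<int64_t, 2>& x,
                                       const std::array<int64_t, 2>& y) {
  assert(x[1] == y[0]);  // GEMM compatibility, as in the removed check
  return {x[0], y[1]};
}

int main() {
  const auto out = AddmmInferShape({3, 4}, {4, 5});
  std::cout << out[0] << " x " << out[1] << "\n";  // prints 3 x 5
  return 0;
}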
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
using Tensor = framework::Tensor;
constexpr int kMULMKLDNNINT8 = 1;
template <typename DeviceContext, typename T>
class AddMMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
auto input_dims = input->dims();
auto x_dims = x->dims();
auto y_dims = y->dims();
// broadcast mode check
if (x_dims[0] != input_dims[0]) {
PADDLE_ENFORCE_EQ(input_dims[0], 1,
platform::errors::InvalidArgument(
"When x_dims[0] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[0]));
PADDLE_ENFORCE_EQ(
y_dims[1] == input_dims[1] || input_dims[1] == 1, true,
platform::errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims, x_dims, y_dims));
}
// broadcast mode check
if (y_dims[1] != input_dims[1]) {
PADDLE_ENFORCE_EQ(input_dims[1], 1,
platform::errors::InvalidArgument(
"When y_dims[1] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[1]));
PADDLE_ENFORCE_EQ(
x_dims[0] == input_dims[0] || input_dims[0] == 1, true,
platform::errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims, x_dims, y_dims));
}
// broadcast mode check
PADDLE_ENFORCE_EQ(
x_dims[1], y_dims[0],
platform::errors::InvalidArgument(
"The input tensor X's width must be equal with matrix Y' height. "
"But received X's shape = [%s], Y's shape = [%s].",
x_dims[1], y_dims[0]));
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>({x_dims[0], y_dims[1]}, context.GetPlace());
float alpha = context.template Attr<float>("Alpha");
float beta = context.template Attr<float>("Beta");
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
// calc broadcast dim
Array2 bcast_dims;
bcast_dims[0] = x_dims[0] / input_dims[0];
bcast_dims[1] = y_dims[1] / input_dims[1];
VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
// broadcast using eigen
auto eigen_input = EigenTensor<T, 2>::From(*input);
auto eigen_out = EigenTensor<T, 2>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);
blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
out->data<T>(), y_dims[1]);
}
};
template <typename DeviceContext, typename T>
class AddMMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto in_dims = ctx.Input<framework::LoDTensor>("Input")->dims();
auto* dinput =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
float alpha = ctx.Attr<float>("Alpha");
float beta = ctx.Attr<float>("Beta");
int total_elems = 0;
VLOG(3) << "alpha: " << alpha << " beta: " << beta;
if (dinput != nullptr) {
dinput->set_lod(dout->lod());
}
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
if (dinput) {
dinput->mutable_data<T>(ctx.GetPlace());
total_elems = in_dims[0] * in_dims[1];
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto eigen_dout = EigenTensor<T, 2>::From(*dout);
auto eigen_dinput = EigenTensor<T, 2>::From(*dinput);
bool row_compress = in_dims[0] != dout->dims()[0];
bool col_compress = in_dims[1] != dout->dims()[1];
auto eigen_dinput_shape = Array2(dinput->dims()[0], dinput->dims()[1]);
if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else if (row_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else if (col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
blas.VCOPY(total_elems, dout->data<T>(), dinput->data<T>());
}
blas.SCAL(total_elems, beta, dinput->data<T>());
}
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
total_elems = x->dims()[0] * x->dims()[1];
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(*dout, false, *y, true, dx);
blas.SCAL(total_elems, alpha, dx->data<T>());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
total_elems = x->dims()[1] * y->dims()[1];
// dy = x' * dout. dy : K x N, dout : M x N, x : M x K
blas.MatMul(*x, true, *dout, false, dy);
blas.SCAL(total_elems, alpha, dy->data<T>());
}
}
};
} // namespace operators
} // namespace paddle
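For reference while reading the deleted fluid kernels above (the functionality now lives in phi): with Input broadcast along rows and/or columns to [M, N], the forward and backward they implement are

  Out = \beta \,\mathrm{bcast}(Input) + \alpha \, X Y,
  \qquad X \in \mathbb{R}^{M \times K},\; Y \in \mathbb{R}^{K \times N},

  \frac{\partial L}{\partial Input} = \beta \,\mathrm{reduce\_sum}(dOut), \qquad
  \frac{\partial L}{\partial X} = \alpha \, dOut \, Y^{\top}, \qquad
  \frac{\partial L}{\partial Y} = \alpha \, X^{\top} dOut,

where reduce_sum contracts dOut over whichever axes Input was broadcast along (the row_compress / col_compress branches above).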
...@@ -24,7 +24,7 @@ limitations under the License. */ ...@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
// pten // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
namespace paddle { namespace paddle {
......
...@@ -12,84 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,84 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/bilinear_tensor_product_op.h" #include "paddle/fluid/framework/infershape_utils.h"
#include <memory> #include "paddle/fluid/framework/op_registry.h"
#include <string> #include "paddle/phi/core/infermeta_utils.h"
#include <vector> #include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor;
class BilinearTensorProductOp : public framework::OperatorWithKernel { class BilinearTensorProductOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::InvalidArgument("Input(Y) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Weight"), true,
platform::errors::InvalidArgument("Input(Weight) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument("Output(Out) should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto weight_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(
x_dims.size(), 2UL,
platform::errors::InvalidArgument("The input(X) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
y_dims.size(), 2UL,
platform::errors::InvalidArgument("The input(Y) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_dims.size(), 3UL,
platform::errors::InvalidArgument("Expected the input(Weight) is a 3D "
"tensor. But received %dD tensor.",
weight_dims.size()));
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0)) {
PADDLE_ENFORCE_EQ(
x_dims[0], y_dims[0],
platform::errors::InvalidArgument(
"The first dimension(batch_size) of input(X) must be "
"equal to the first dimension of the input(Y)."));
}
PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
platform::errors::InvalidArgument(
"The second dimension of input(X) must be equal to "
"the second dimension of the input(Weight)."));
PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2],
platform::errors::InvalidArgument(
"The second dimension of input(Y) must be equal to "
"the third dimension of the input(Weight)."));
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims.size(), 2UL,
platform::errors::InvalidArgument(
"The Input(Bias) must be a 2-D tensor with "
"the 2nd dimension fixed to 1 (a row vector)."));
PADDLE_ENFORCE_EQ(bias_dims[0], 1UL,
platform::errors::InvalidArgument(
"The Input(Bias) must be a 2-D tensor with "
"the 2nd dimension fixed to 1 (a row vector)."));
PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0],
platform::errors::InvalidArgument(
"The second dimension of input(Bias) must be equal "
"to the first dimension of the input(Weight)."));
}
ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]});
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -125,59 +59,6 @@ Where $W_i$ is the $i$-th slice of Input(Weight); ...@@ -125,59 +59,6 @@ Where $W_i$ is the $i$-th slice of Input(Weight);
class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::InvalidArgument("Input(Y) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Weight"), true,
platform::errors::InvalidArgument("Input(Weight) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto weight_dims = ctx->GetInputDim("Weight");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(out_dims.size(), 2UL,
platform::errors::InvalidArgument(
"The input(Out@GRAD) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0],
platform::errors::InvalidArgument(
"The first dimension(batch_size) of input(Out@GRAD) must be "
"equal to the first dimension of the Input(X)."));
PADDLE_ENFORCE_EQ(
weight_dims[0], out_dims[1],
platform::errors::InvalidArgument(
"The second dimension of input(Out@GRAD) must be equal to "
"the third dimension of the Input(Weight)."));
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) {
ctx->SetOutputDim(bias_grad_name, {1, out_dims[1]});
}
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
if (ctx->HasOutput(weight_grad_name)) {
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
}
}; };
template <typename T> template <typename T>
...@@ -208,21 +89,20 @@ class BilinearTensorProductGradOpMaker ...@@ -208,21 +89,20 @@ class BilinearTensorProductGradOpMaker
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(bilinear_tensor_product,
BilinearTensorProductInferShapeFunctor,
PT_INFER_META(phi::BilinearTensorProductInferMeta));
DELCARE_INFER_SHAPE_FUNCTOR(
bilinear_tensor_product_grad, BilinearTensorProductGradInferShapeFunctor,
PT_INFER_META(phi::BilinearTensorProductGradInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
bilinear_tensor_product, ops::BilinearTensorProductOp, bilinear_tensor_product, ops::BilinearTensorProductOp,
ops::BilinearTensorProductOpMaker, ops::BilinearTensorProductOpMaker,
ops::BilinearTensorProductGradOpMaker<paddle::framework::OpDesc>, ops::BilinearTensorProductGradOpMaker<paddle::framework::OpDesc>,
ops::BilinearTensorProductGradOpMaker<paddle::imperative::OpBase>); ops::BilinearTensorProductGradOpMaker<paddle::imperative::OpBase>,
BilinearTensorProductInferShapeFunctor);
REGISTER_OPERATOR(bilinear_tensor_product_grad, REGISTER_OPERATOR(bilinear_tensor_product_grad,
ops::BilinearTensorProductOpGrad); ops::BilinearTensorProductOpGrad,
REGISTER_OP_CPU_KERNEL( BilinearTensorProductGradInferShapeFunctor);
bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext, float>,
ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_tensor_product_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::CUDADeviceContext,
float>,
ops::BilinearTensorProductKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_CUDA_KERNEL(
bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class BilinearTensorProductKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto y_mat = EigenMatrix<T>::From(*y);
auto output_mat = EigenMatrix<T>::From(*out);
auto batch_size = x->dims()[0];
auto weight_dims = weight->dims();
int out_dim = weight_dims[0];
auto x_dim = weight_dims[1];
auto y_dim = weight_dims[2];
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Create the intermediate variable to calculate the result of
// Input(X) multiplied by Input(Weight_i); the formula is:
// left_mul = X Weight_i.
Tensor left_mul;
left_mul.mutable_data<T>(phi::make_ddim({batch_size, y_dim}),
ctx.GetPlace());
auto left_mul_mat = EigenMatrix<T>::From(left_mul);
for (int i = 0; i < out_dim; ++i) {
auto output_col_vec = output_mat.chip(i, 1);
Tensor weight_mat =
weight->Slice(i, i + 1).Resize(phi::make_ddim({x_dim, y_dim}));
phi::funcs::GetBlas<DeviceContext, T>(dev_ctx).GEMM(
CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
output_col_vec.device(place) =
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
}
if (bias) {
auto bias_vec = EigenMatrix<T>::From(*bias);
Eigen::DSizes<int, 2> bcast(batch_size, 1);
output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
}
}
};
template <typename DeviceContext, typename T>
class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor* weight = ctx.Input<Tensor>("Weight");
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* d_y = ctx.Output<Tensor>(framework::GradVarName("Y"));
Tensor* d_weight = ctx.Output<Tensor>(framework::GradVarName("Weight"));
Tensor* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto batch_size = x->dims()[0];
auto weight_dims = weight->dims();
int out_dim = weight_dims[0];
auto x_dim = weight_dims[1];
auto y_dim = weight_dims[2];
auto x_mat = EigenMatrix<T>::From(*x);
auto y_mat = EigenMatrix<T>::From(*y);
auto d_out_mat = EigenMatrix<T>::From(*d_out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Create the intermediate variable to calculate the Output(Y@Grad).
Tensor x_scale;
x_scale.mutable_data<T>(phi::make_ddim({batch_size, x_dim}),
ctx.GetPlace());
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
// Create the intermediate variable to calculate the Output(X@Grad).
Tensor y_scale;
y_scale.mutable_data<T>(phi::make_ddim({batch_size, y_dim}),
ctx.GetPlace());
auto y_scale_mat = EigenMatrix<T>::From(y_scale);
phi::funcs::SetConstant<DeviceContext, T> set_zero;
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_x, static_cast<T>(0));
}
if (d_y) {
d_y->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_y, static_cast<T>(0));
}
if (d_weight) {
d_weight->mutable_data<T>(ctx.GetPlace());
}
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
// Calculate the Output(X@Grad) and Output(Y@Grad).
if (d_x || d_y || d_weight) {
Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
for (int i = 0; i < out_dim; ++i) {
Tensor weight_i =
weight->Slice(i, i + 1).Resize(phi::make_ddim({x_dim, y_dim}));
auto output_vec = d_out_mat.chip(i, 1);
if (d_x) {
y_scale_mat.device(place) =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_x) *
y_mat;
blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
}
if (d_y || d_weight) {
auto output_vec_y =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_y);
x_scale_mat.device(place) = output_vec_y * x_mat;
if (d_y) {
blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
}
if (d_weight) {
Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
phi::make_ddim({x_dim, y_dim}));
blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
}
}
}
}
// calculate the gradient of Input(Bias).
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_mat = framework::EigenVector<T>::Flatten(*d_bias);
d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
}
}
};
} // namespace operators
} // namespace paddle
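Likewise for the deleted bilinear_tensor_product kernels: per batch row b and output channel i, with W_i the i-th slice of Weight (as the op doc above puts it), the forward and the per-slice gradient accumulation are

  out_{b,i} = x_b^{\top} W_i \, y_b + bias_i,

  \frac{\partial L}{\partial x_b} = \sum_i dout_{b,i} \, W_i \, y_b, \qquad
  \frac{\partial L}{\partial y_b} = \sum_i dout_{b,i} \, W_i^{\top} x_b,

  \frac{\partial L}{\partial W_i} = \sum_b dout_{b,i} \, x_b \, y_b^{\top}, \qquad
  \frac{\partial L}{\partial bias_i} = \sum_b dout_{b,i},

which is exactly what the GEMM calls on x_scale and y_scale compute slice by slice.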
...@@ -138,7 +138,7 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -138,7 +138,7 @@ class CastOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
// cast use pten kernel, so no need to REGISTER_OP_CPU_KERNEL here. // cast use phi kernel, so no need to REGISTER_OP_CPU_KERNEL here.
REGISTER_OPERATOR(cast, ops::CastOp, REGISTER_OPERATOR(cast, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>, ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>, ops::CastOpGradMaker<paddle::imperative::OpBase>,
......
The remaining file diffs in this commit are collapsed.