Unverified commit 3ab9aef1, authored by Leo Chen, committed by GitHub

[pten] remove deprecated fluid op kernel for pten (#38842)

* update cmake file to remove fluid kernel

* add pten declaration.h to where pybind.h used

* fix sync_bn and tensorrt_engine

* refine detection_library

* fix interpreter_core

* support eager legacy

* fit eager legacy for pten

* fall back to cpu if not found kernel

* fix compile problem

* fix compile problem

* refine fallback logic

* fit operator.run()

* fix xpu compile

* fit for new_exec

* add REGISTER_OP_WITHOUT_GRADIENT

* un-cache pt_kernel_context

* fix compile

* fix cudnn

* fix compiling with on_infer

* fix mkldnn

* fix isfinite_v2

* fix xpu problem

* fix op_device

* refine fallback for xpu

* fix xpu compile

* merge develop

* refine code format

* fix compile

* fix compile

* add data_transfer

* fix PreparePtenData

* fix cpu context

* merge develop

* fix compile

* fix error device context

* fix xpu

* fix dev_ctx
Parent 801159ce
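The core behavior this commit introduces is the kernel-selection fallback: a compatible pten kernel on the requested place is preferred, and when no kernel is registered for that place the framework falls back to a CPU pten kernel (see FallBackToCpu and the NOTE(zhiqiu) comment on selection order later in the diff). The following self-contained C++ sketch only illustrates that ordering; the types and names (Registry, Place, SelectKernel) are simplified stand-ins invented for illustration, not Paddle's actual API.

// Self-contained illustration (not Paddle's API): a stand-in registry of
// "pten" and "fluid" kernels per place, and the selection/fallback order.
#include <iostream>
#include <map>
#include <set>
#include <string>

enum class Place { CPU, GPU, XPU };

struct Registry {
  std::map<std::string, std::set<Place>> pten;   // op name -> places with a pten kernel
  std::map<std::string, std::set<Place>> fluid;  // op name -> places with a fluid kernel
};

// Selection order sketched here, mirroring the NOTE in the diff:
// pten kernel on requested place > fluid kernel on requested place
// > pten CPU kernel (fallback) > fluid CPU kernel (fallback).
std::string SelectKernel(const Registry& reg, const std::string& op, Place place) {
  auto has = [&](const std::map<std::string, std::set<Place>>& table, Place p) {
    auto it = table.find(op);
    return it != table.end() && it->second.count(p) > 0;
  };
  if (has(reg.pten, place)) return "pten kernel on requested place";
  if (has(reg.fluid, place)) return "fluid kernel on requested place";
  if (has(reg.pten, Place::CPU)) return "pten CPU kernel (fallback)";
  if (has(reg.fluid, Place::CPU)) return "fluid CPU kernel (fallback)";
  return "no kernel found";
}

int main() {
  Registry reg;
  reg.pten["scale"] = {Place::CPU};                // pten kernel registered only for CPU
  reg.fluid["conv2d"] = {Place::CPU, Place::GPU};  // fluid kernels only
  std::cout << SelectKernel(reg, "scale", Place::XPU) << "\n";   // -> pten CPU kernel (fallback)
  std::cout << SelectKernel(reg, "conv2d", Place::GPU) << "\n";  // -> fluid kernel on requested place
  return 0;
}

Running the sketch prints the fallback choice for an op that only has a CPU pten kernel and the fluid choice for an op that only has fluid kernels; the real change applies the same ordering in PrepareImpl (dygraph), build_op_func_list (new executor), and OperatorWithKernel::RunImpl (static graph) below.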
# CMake file `unity_build` is used to handle Unity Build compilation.
include(unity_build)
set(PART_CUDA_KERNEL_FILES)
function(find_register FILENAME PATTERN OUTPUT)
# find the op_name of REGISTER_OPERATOR(op_name, ...), REGISTER_OP_CPU_KERNEL(op_name, ...) , etc.
# set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
set(multiValueArgs "")
file(READ ${FILENAME} CONTENT)
# message ("number of arguments sent to function: ${ARGC}")
# message ("all function arguments: ${ARGV}")
# message("PATTERN ${PATTERN}")
string(REGEX MATCH "${PATTERN}\\([ \t\r\n]*[a-z0-9_]*," register "${CONTENT}")
if (NOT register STREQUAL "")
string(REPLACE "${PATTERN}(" "" register "${register}")
string(REPLACE "," "" register "${register}")
# [ \t\r\n]+ is used for blank characters.
# Here we use '+' instead of '*' since it is a REPLACE operation.
string(REGEX REPLACE "[ \t\r\n]+" "" register "${register}")
endif()
set(${OUTPUT} ${register} PARENT_SCOPE)
endfunction()
function(op_library TARGET)
# op_library is a function to create an op library. The interface is the same as
# cc_library, but it handles splitting GPU/CPU code and links some common libraries.
...@@ -119,16 +142,16 @@ function(op_library TARGET)
list(APPEND miopen_cu_cc_srcs ${src})
elseif(WITH_ROCM AND ${src} MATCHES ".*\\.cu.cc$")
list(APPEND hip_cc_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu$") elseif(WITH_GPU AND ${src} MATCHES ".*_cudnn_op.cu$")
list(APPEND cudnn_cu_srcs ${src}) list(APPEND cudnn_cu_srcs ${src})
elseif (${src} MATCHES ".*\\.cu$") elseif (WITH_GPU AND ${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src}) list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") elseif(WITH_GPU AND ${src} MATCHES ".*_cudnn_op.cu.cc$")
list(APPEND cudnn_cu_cc_srcs ${src}) list(APPEND cudnn_cu_cc_srcs ${src})
elseif(WITH_GPU AND ${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src})
elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$")
list(APPEND mkldnn_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src})
elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$")
list(APPEND xpu_cc_srcs ${src})
elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
...@@ -228,135 +251,136 @@ function(op_library TARGET)
endif()
endif()
list(LENGTH cu_srcs cu_srcs_len)
list(LENGTH hip_srcs hip_srcs_len)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)
list(LENGTH hip_cc_srcs hip_cc_srcs_len)
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
list(LENGTH npu_cc_srcs npu_cc_srcs_len)
list(LENGTH mlu_cc_srcs mlu_cc_srcs_len)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op" foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" if ("${TARGET}" STREQUAL "${manual_pybind_op}")
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" set(pybind_flag 1)
"fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op" "fused_feedforward_op") endif()
endforeach()
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
endforeach()
# For the registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
# Note that it's enough to just add one operator to pybind in a *_op.cc file.
# And for detailed pybind information, please see the generated paddle/pybind/pybind.h.
set(ORIGINAL_TARGET ${TARGET})
file(READ ${TARGET}.cc TARGET_CONTENT) string(REGEX REPLACE "_op" "" TARGET "${TARGET}")
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
# [ \t\r\n]* is used for blank characters
string(REGEX MATCH "REGISTER_OPERATOR\\([ \t\r\n]*[a-z0-9_]*," one_register "${multi_register}")
if (one_register STREQUAL "") foreach(cc_src ${cc_srcs})
string(REPLACE "_op" "" TARGET "${TARGET}") # pybind USE_OP_ITSELF
else () set(op_name "")
string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}") find_register(${cc_src} "REGISTER_OPERATOR" op_name)
string(REPLACE "," "" TARGET "${TARGET}") if(NOT ${op_name} EQUAL "")
# [ \t\r\n]+ is used for blank characters. file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
# Here we use '+' instead of '*' since it is a REPLACE operation. # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn
string(REGEX REPLACE "[ \t\r\n]+" "" TARGET "${TARGET}") set(TARGET ${op_name})
endif() set(pybind_flag 1)
endif()
set(op_name "")
find_register(${cc_src} "REGISTER_OP_WITHOUT_GRADIENT" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
# hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn
set(TARGET ${op_name})
set(pybind_flag 1)
endif()
# pybind USE_NO_KERNEL_OP # pybind USE_OP_DEVICE_KERNEL for CPU
# HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel set(op_name "")
string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}") find_register(${cc_src} "REGISTER_OP_CPU_KERNEL" op_name)
string(REPLACE "_op" "" TARGET "${TARGET}") if(NOT ${op_name} EQUAL "")
if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CPU);\n")
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n") # why change TARGET here?
set(pybind_flag 1) # when building paddle with on_infer, the REGISTER_OPERATOR(*_grad) will be removed before compiling (see details in remove_grad_op_and_kernel.py)
endif() # in elementwise_op.cc, it will find REGISTER_OPERATOR(grad_add) and set TARGET to grad_add
# and, in the following "mkldnn" part, it will add USE_OP_DEVICE_KERNEL(grad_add, MKLDNN) to pybind.h
# however, grad_add has no mkldnn kernel.
set(TARGET ${op_name})
set(pybind_flag 1)
endif()
endforeach()
# pybind USE_CPU_ONLY_OP # pybind USE_OP_DEVICE_KERNEL for CUDA
list(LENGTH cu_srcs cu_srcs_len) list (APPEND cu_srcs ${cu_cc_srcs})
list(LENGTH hip_srcs hip_srcs_len) # message("cu_srcs ${cu_srcs}")
list(LENGTH cu_cc_srcs cu_cc_srcs_len) foreach(cu_src ${cu_srcs})
list(LENGTH hip_cc_srcs hip_cc_srcs_len) set(op_name "")
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
list(LENGTH xpu_cc_srcs xpu_cc_srcs_len) if(NOT ${op_name} EQUAL "")
list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
list(LENGTH npu_cc_srcs npu_cc_srcs_len) set(pybind_flag 1)
list(LENGTH mlu_cc_srcs mlu_cc_srcs_len) endif()
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND endforeach()
${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND
${npu_cc_srcs_len} EQUAL 0 AND ${mlu_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
endif()
# pybind USE_OP_DEVICE_KERNEL for MIOPEN # pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN
list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs})
if (WITH_ROCM AND ${miopen_cu_cc_srcs_len} GREATER 0) list(APPEND cudnn_cu_srcs ${miopen_cu_cc_srcs})
if(${TARGET} STREQUAL "activation") list(APPEND cudnn_cu_srcs ${miopen_cu_srcs})
list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len)
#message("cudnn_cu_srcs ${cudnn_cu_srcs}")
if(${cudnn_cu_srcs_len} GREATER 0 AND ${ORIGINAL_TARGET} STREQUAL "activation_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") foreach(cudnn_src ${cudnn_cu_srcs})
endif() set(op_name "")
find_register(${cudnn_src} "REGISTER_OP_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDNN);\n")
set(pybind_flag 1)
endif()
endforeach()
endif() endif()
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len)
if (WITH_GPU AND ${cudnn_cu_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MIOPEN if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
list(LENGTH miopen_cu_srcs miopen_cu_srcs_len) if(${ORIGINAL_TARGET} STREQUAL "activation_op")
if (WITH_ROCM AND ${miopen_cu_srcs_len} GREATER 0) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, XPU);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") else()
foreach(xpu_src ${xpu_cc_srcs})
set(op_name "")
find_register(${xpu_src} "REGISTER_OP_XPU_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n")
set(pybind_flag 1)
endif()
endforeach()
endif() endif()
if (WITH_XPU AND ${pybind_flag} EQUAL 0 AND ${xpu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
endif() endif()
# pybind USE_OP_DEVICE_KERNEL for NPU
if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0) if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
file(READ ${ORIGINAL_TARGET}_npu.cc TARGET_NPU_CONTENT) foreach(npu_src ${npu_cc_srcs})
# It is different from the logic above, becareful set(op_name "")
string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\(.*" multi_npu_register "${TARGET_NPU_CONTENT}") find_register(${npu_src} "REGISTER_OP_NPU_KERNEL" op_name)
# [ \t\r\n]* is used for blank characters if(NOT ${op_name} EQUAL "")
string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_npu_register "${multi_npu_register}") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, NPU);\n")
set(pybind_flag 1)
if (one_npu_register STREQUAL "")
string(REPLACE "_op" "" NPU_TARGET "${TARGET}")
else ()
string(REPLACE "REGISTER_OP_NPU_KERNEL(" "" NPU_TARGET "${one_npu_register}")
string(REPLACE "," "" NPU_TARGET "${NPU_TARGET}")
# [ \t\r\n]+ is used for blank characters.
# Here we use '+' instead of '*' since it is a REPLACE operation.
string(REGEX REPLACE "[ \t\r\n]+" "" NPU_TARGET "${NPU_TARGET}")
endif() endif()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n") endforeach()
endif() endif()
if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0)
file(READ ${ORIGINAL_TARGET}_mlu.cc TARGET_MLU_CONTENT)
# It is different from the logic above, becareful
string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\(.*" multi_mlu_register "${TARGET_MLU_CONTENT}")
# [ \t\r\n]* is used for blank characters
string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_mlu_register "${multi_mlu_register}")
if (one_mlu_register STREQUAL "") # pybind USE_OP_DEVICE_KERNEL for MLU
string(REPLACE "_op" "" MLU_TARGET "${TARGET}") if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0)
else () foreach(mlu_src ${mlu_cc_srcs})
string(REPLACE "REGISTER_OP_MLU_KERNEL(" "" MLU_TARGET "${one_mlu_register}") set(op_name "")
string(REPLACE "," "" MLU_TARGET "${MLU_TARGET}") find_register(${mlu_src} "REGISTER_OP_MLU_KERNEL" op_name)
# [ \t\r\n]+ is used for blank characters. if(NOT ${op_name} EQUAL "")
# Here we use '+' instead of '*' since it is a REPLACE operation. file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MLU);\n")
string(REGEX REPLACE "[ \t\r\n]+" "" MLU_TARGET "${MLU_TARGET}") set(pybind_flag 1)
endif() endif()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MLU_TARGET}, MLU);\n") endforeach()
endif() endif()
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
...@@ -377,10 +401,26 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") foreach(mkldnn_src ${mkldnn_cc_srcs})
set(op_name "")
find_register(${mkldnn_src} "REGISTER_OP_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MKLDNN);\n")
set(pybind_flag 1)
endif()
endforeach()
endif()
endif()
# pybind USE_NO_KERNEL_OP
# HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
string(REPLACE "_op" "" TARGET "${TARGET}")
if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
# pybind USE_OP
if (${pybind_flag} EQUAL 0)
# NOTE(*): activation uses a macro to register the kernels, set use_op manually.
......
...@@ -27,6 +27,9 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
// pten
#include "paddle/pten/kernels/declarations.h"
#define NUM_CREATED_DUP_INPUTS 4
namespace paddle {
...@@ -535,7 +538,8 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Skip operator which is not inherited from OperatorWithKernel, like while,
// since only OperatorWithKernel can run in dygraph mode.
auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
if (!all_kernels.count(op_type)) { if (!all_kernels.count(op_type) &&
!pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
return false;
}
......
...@@ -93,6 +93,7 @@ void OpRunImpl(const paddle::framework::OperatorBase& op,
prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs);
}
VLOG(6) << "Run Prepared Op end";
// TODO(jiabin): Set the output var's grad Forward DataType
}
......
...@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/eager/legacy/prepared_operator.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/eager/legacy/infer_shape_context.h"
#include "paddle/fluid/framework/data_type_transform.h"
...@@ -71,6 +72,21 @@ PreparedOp::PreparedOp(
func_(func),
dev_ctx_(dev_ctx) {}
PreparedOp::PreparedOp(
const paddle::framework::OperatorBase& op,
const paddle::framework::RuntimeContext& ctx,
const paddle::framework::OpKernelType& kernel_type,
const paddle::framework::KernelSignature& kernel_signature,
const pten::Kernel& pt_kernel, paddle::platform::DeviceContext* dev_ctx)
: op_(op),
ctx_(ctx),
kernel_type_(kernel_type),
func_(nullptr),
dev_ctx_(dev_ctx),
run_pten_kernel_(true),
pt_kernel_signature_(kernel_signature),
pt_kernel_(pt_kernel) {}
PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs,
const paddle::framework::OperatorWithKernel& op,
const paddle::platform::Place& place,
...@@ -104,17 +120,71 @@ PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs,
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
// fit for pten
pten::KernelSignature pt_kernel_signature;
pten::KernelKey pt_kernel_key;
std::string pt_kernel_name;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
auto pt_kernel = pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
if (pt_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << pt_kernel;
// TODO(chenweihang): using CPUKernel when miss device kernel case
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}
// 2. check if op[type] has kernel registered.
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
if (kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end()
#ifdef PADDLE_WITH_XPU
||
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type())
#endif
) {
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_cpu_kernel, dev_ctx);
}
}
}
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
paddle::platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.",
op.Type()));
auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() ||
...@@ -202,11 +272,46 @@ static void PreparedOpRunImpl(
VLOG(6) << "Finish Runing Prepared Op";
}
static void PreparedOpRunPtImpl(
const paddle::framework::OperatorBase& op,
const paddle::framework::OpKernelType& kernel_type,
const paddle::framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, paddle::platform::DeviceContext* dev_ctx,
const NameTensorMap& ins, const NameTensorMap& outs,
const paddle::framework::AttributeMap& attrs,
const paddle::framework::AttributeMap& default_attrs) {
EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
op.Type());
static_cast<const paddle::framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
paddle::imperative::PreparePtenData<EagerTensor>(
pt_kernel, pt_kernel_signature,
static_cast<paddle::imperative::NameTensorMap>(ins));
pten::KernelContext pt_kernel_context;
paddle::imperative::BuildDygraphPtenKernelContext<EagerTensor>(
pt_kernel_signature, pt_kernel,
static_cast<paddle::imperative::NameTensorMap>(ins),
static_cast<paddle::imperative::NameTensorMap>(outs), attrs,
default_attrs, dev_ctx, &pt_kernel_context);
pt_kernel(&pt_kernel_context);
// TODO(chenweihang): add debug flags later
// TODO(chenweihang): deal with complex cases later
}
void PreparedOp::Run(const NameTensorMap& ins, const NameTensorMap& outs,
const paddle::framework::AttributeMap& attrs,
const paddle::framework::AttributeMap& default_attrs) {
PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, if (run_pten_kernel_) {
default_attrs); PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, pt_kernel_,
dev_ctx_, ins, outs, attrs, default_attrs);
} else {
PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs,
attrs, default_attrs);
}
}
std::shared_ptr<NameTensorMap> PrepareData(
......
...@@ -55,6 +55,13 @@ class PreparedOp {
const paddle::framework::OperatorWithKernel::OpKernelFunc& func,
paddle::platform::DeviceContext* dev_ctx);
PreparedOp(const paddle::framework::OperatorBase& op,
const paddle::framework::RuntimeContext& ctx,
const paddle::framework::OpKernelType& kernel_type,
const paddle::framework::KernelSignature& kernel_signature,
const pten::Kernel& pt_kernel,
paddle::platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(
const NameTensorMap& ins, const NameTensorMap& outs,
const paddle::framework::OperatorWithKernel& op,
...@@ -76,6 +83,13 @@ class PreparedOp {
paddle::framework::OpKernelType kernel_type_;
paddle::framework::OperatorWithKernel::OpKernelFunc func_;
paddle::platform::DeviceContext* dev_ctx_;
// NOTE(chenweihang): Similar op members are used to adapt to
// new pten kernel, if there is a better design in the future,
// we may polish the implementation here
bool run_pten_kernel_{false};
paddle::framework::KernelSignature pt_kernel_signature_;
pten::Kernel pt_kernel_;
};
} // namespace legacy
......
...@@ -185,10 +185,17 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_test(no_need_buffer_vars_inference_test SRCS no_need_buffer_vars_inference_test.cc DEPS no_need_buffer_vars_inference layer)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_vars_inference)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
IF(WITH_XPU)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info xpu_op_list)
ELSE()
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info)
ENDIF()
IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
...@@ -403,8 +410,8 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS pten_utils attribute shape_inference op_utils)
# Get the current working branch
execute_process(
......
...@@ -33,6 +33,9 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
// pten
#include "paddle/pten/kernels/declarations.h"
namespace paddle {
namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
......
...@@ -32,6 +32,9 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h"
// pten
#include "paddle/pten/kernels/declarations.h"
namespace paddle {
namespace framework {
......
...@@ -31,7 +31,7 @@ endif()
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
# skip win32 since wget is not installed by default on windows machine.
# skip COVERAGE_CI since the test runs slowly because of instrumentation.
if (WITH_TESTING AND NOT WIN32 AND NOT WITH_COVERAGE AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON") if (WITH_CUDA AND WITH_TESTING AND NOT WIN32 AND NOT WITH_COVERAGE AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
add_custom_target(
download_program
COMMAND wget -nc https://paddle-ci.gz.bcebos.com/new_exec/lm_main_program
......
...@@ -426,8 +426,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
(*instr_node.PtenKernel())(&pt_kernel_context);
op_with_kernel->WriteBackToOutputs(
instr_node.InnerRuntimeContext().get(), &pt_kernel_context);
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
......
...@@ -358,15 +358,6 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
op->Type()));
OpKernelMap& kernels = kernels_iter->second;
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
...@@ -404,26 +395,41 @@ void build_op_func_list(const platform::Place& place,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
op_func_node.dev_ctx_ = dev_ctx;
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
auto exec_ctx =
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE(
kernel_iter, kernels.end(),
platform::errors::NotFound(
"Operator (%s) does not have kernel for %s.", op->Type(),
KernelTypeToString(expected_kernel_key)));
auto run_pten_kernel = false; auto run_pten_kernel = false;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(
op_with_kernel->Type())) { op_with_kernel->Type())) {
op_with_kernel->ChoosePtenKernel(exec_ctx); auto pt_kernel_key = op_with_kernel->ChoosePtenKernel(exec_ctx);
run_pten_kernel = op_with_kernel->PtenKernel()->IsValid(); auto pt_kernel_name = op_with_kernel->PtenKernelSignature()->name;
}
if (op_with_kernel->PtenKernel()->IsValid()) {
run_pten_kernel = true;
} else {
auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
if (kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end()) {
auto pt_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, pt_kernel_key, *op_with_kernel);
op_with_kernel->ResetPtenKernel(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key)));
if (op_with_kernel->PtenKernel()->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: "
<< pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *(op_with_kernel->PtenKernel());
run_pten_kernel = true;
}
}
}
}
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
if (run_pten_kernel) {
pten::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
...@@ -431,9 +437,22 @@ void build_op_func_list(const platform::Place& place,
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
(*op_func_node.pt_kernel_)(&pt_kernel_context);
op_with_kernel->WriteBackToOutputs(&runtime_context,
&pt_kernel_context);
} else { } else {
auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
op->Type()));
OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE(
kernel_iter, kernels.end(),
platform::errors::NotFound(
"Operator (%s) does not have kernel for %s.", op->Type(),
KernelTypeToString(expected_kernel_key)));
// TODO(zhiqiu): add fallback logic
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
}
......
...@@ -313,7 +313,7 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
framework::DataLayout::kMKLDNN));
} catch (std::bad_cast exp) { } catch (std::bad_cast& exp) {
return false;
}
}
......
...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
...@@ -1144,22 +1145,80 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
#endif
auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx);
// using cache
if (kernel_type_.get()) {
dev_ctx = pool.Get(kernel_type_->place_);
}
// TODO(chenweihang): Now we are still reusing a lot of the original fluid
// implementation, this is a gradual replacement process
// TODO(chenweihang): in the first phase of the project, we only support CPU, CUDA
// and ROCM backend, the XPU, NPU and MKLDNN will be supported in the second
// phase
if (FLAGS_run_pten_kernel && pten::KernelKey pt_kernel_key;
pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { std::string pt_kernel_name;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
ChoosePtenKernel(exe_ctx); pt_kernel_signature_.reset(new KernelSignature(
std::move(this->GetExpectedPtenKernelArgs(exe_ctx))));
VLOG(6) << *pt_kernel_signature_.get();
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_);
pt_kernel_name = pt_kernel_signature_->name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));
if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode ChoosePtenKernel - kernel name: "
<< pt_kernel_name << " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_;
} else {
VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}
if (pt_kernel_->IsValid()) {
run_pten_kernel_ = true;
} else {
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(*kernel_type_.get()) ==
kernels_iter->second.end()
#ifdef PADDLE_WITH_XPU
||
paddle::platform::is_xpu_place(kernel_type_->place_) && // NOLINT
!paddle::platform::is_xpu_support_op(
type_, *kernel_type_.get()) // NOLINT
|| paddle::platform::is_in_xpu_black_list(type_)
#endif
) {
auto pt_cpu_kernel_key =
FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this);
pt_kernel_.reset(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key)));
dev_ctx = pool.Get(platform::CPUPlace());
if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *pt_kernel_;
run_pten_kernel_ = true;
}
}
} }
run_pten_kernel_ = pt_kernel_->IsValid();
} }
if (!run_pten_kernel_) {
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(exe_ctx);
dev_ctx = pool.Get(kernel_type_->place_);
}
}
...@@ -1178,10 +1237,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
const Scope& exec_scope =
(transfer_scope == nullptr ? scope : *transfer_scope);
if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(kernel_type_->place_);
}
if (!all_kernels_must_compute_runtime_shape_) {
platform::RecordEvent record_event("infer_shape",
platform::EventRole::kInnerOp);
...@@ -1201,6 +1256,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (run_pten_kernel_) {
pten::KernelContext pt_kernel_context;
// Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack
PreparePtenData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
runtime_ctx);
BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
...@@ -1289,7 +1345,8 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
return expected_kernel_key;
}
void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const { pten::KernelKey OperatorWithKernel::ChoosePtenKernel(
const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
VLOG(6) << *pt_kernel_signature_.get();
...@@ -1311,6 +1368,7 @@ void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
return pt_kernel_key;
}
void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
...@@ -1839,25 +1897,21 @@ Scope* OperatorWithKernel::PreparePtenData(
continue;
}
// TODO(zyfncg): Now there is no kernel which need to transform input VLOG(3) << "PTen Transform Variable " << input_names[i] << " from "
// data, so we commented out following code temporarily, << tensor_in->place() << " to " << expected_place;
// and it will be used in the future.
// VLOG(3) << "PTen Transform Variable " << input_names[i] << " from " if (!new_scope) {
// << tensor_in->place() << " to " << expected_place; new_scope = &scope.NewScope();
}
// if (!new_scope) {
// new_scope = &scope.NewScope();
// }
// // Create new var with the same name in transfer scopes // Create new var with the same name in transfer scopes
// auto* trans_var = new_scope->Var(input_names[i]); auto* trans_var = new_scope->Var(input_names[i]);
// ins_vector[i] = trans_var; ins_vector[offset] = trans_var;
// // Do transfer // Do transfer
// Tensor out; Tensor out;
// framework::TensorCopySync(*tensor_in, expected_place, &out); framework::TensorCopySync(*tensor_in, expected_place, &out);
// SetTensorToVariable(*var, out, trans_var); SetTensorToVariable(*var, out, trans_var);
} }
} }
......
...@@ -525,13 +525,27 @@ class OperatorWithKernel : public OperatorBase {
}
bool SupportGPU() const override { bool SupportGPU() const override {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); auto pten_kernels = pten::KernelFactory::Instance().SelectKernelMap(
return std::any_of(op_kernels.begin(), op_kernels.end(), pten::TransToPtenKernelName(type_));
[](OpKernelMap::const_reference kern_pair) { auto has_pten_kernel = std::any_of(
return platform::is_gpu_place(kern_pair.first.place_); pten_kernels.begin(), pten_kernels.end(),
}); [](pten::KernelFactory::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == pten::Backend::GPU;
});
if (has_pten_kernel) {
return true;
} else {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(
op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
} }
bool SupportNPU() const override {
// TODO(zhiqiu): support pten if needed?
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
...@@ -539,6 +553,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
bool SupportMLU() const override {
// TODO(zhiqiu): support pten if needed?
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
...@@ -583,18 +598,18 @@ class OperatorWithKernel : public OperatorBase {
* When selecting Kernel during Op execution, select the arguments of the
* original Op according to the GetExpectedPtenKernelArgs returned arguments.
*/
virtual KernelSignature GetExpectedPtenKernelArgs( virtual pten::KernelSignature GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const; const ExecutionContext& ctx) const;
/* member functions for adapting to pten lib */
void ChoosePtenKernel(const ExecutionContext& ctx) const; pten::KernelKey ChoosePtenKernel(const ExecutionContext& ctx) const;
/**
* Transfer data place for pten kernel
* Is this really needed?
*/
Scope* PreparePtenData(const Scope& scope, const pten::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, const pten::KernelSignature& pt_kernel_signature,
RuntimeContext* ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx,
...@@ -604,8 +619,16 @@ class OperatorWithKernel : public OperatorBase {
void WriteBackToOutputs(RuntimeContext* ctx,
pten::KernelContext* pt_kernel_context) const;
pten::KernelSignature* PtenKernelSignature() const {
return pt_kernel_signature_.get();
}
pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
void ResetPtenKernel(pten::Kernel* kernel) const {
return pt_kernel_.reset(kernel);
}
const OpKernelType* kernel_type() const { return kernel_type_.get(); }
private:
...@@ -662,7 +685,7 @@ class OperatorWithKernel : public OperatorBase {
// new pten kernel, if there is a better design in the future,
// we may polish the implementation here
mutable bool run_pten_kernel_ = false;
mutable std::unique_ptr<KernelSignature> pt_kernel_signature_; mutable std::unique_ptr<pten::KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<pten::Kernel> pt_kernel_;
};
......
...@@ -90,6 +90,40 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
return pten::KernelKey(backend, layout, dtype);
}
pten::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
const pten::KernelKey& kernel_key,
const framework::OperatorBase& op) {
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "pten missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return pten::KernelKey(pten::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing NPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return pten::KernelKey(pten::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing MLU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return pten::KernelKey(pten::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
}
#endif
return pten::KernelKey();
}
const paddle::SmallVector<std::string>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) {
......
...@@ -24,12 +24,18 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif
namespace paddle {
namespace framework {
...@@ -41,6 +47,9 @@ OpKernelType TransPtenKernelKeyToOpKernelType(
const pten::KernelKey& kernel_key);
pten::KernelKey TransOpKernelTypeToPtenKernelKey(
const OpKernelType& kernel_type);
pten::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
const pten::KernelKey& kernel_key,
const framework::OperatorBase& op);
/* Kernel Args parse */
......
...@@ -55,21 +55,6 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
}
}
static const framework::Attribute& GetAttr(
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const std::string& name) {
auto it = attrs.find(name);
bool found = it != attrs.end();
if (!found) {
it = default_attrs.find(name);
found = it != default_attrs.end();
}
PADDLE_ENFORCE_EQ(
found, true,
platform::errors::NotFound("(%s) is not found in AttributeMap.", name));
return it->second;
}
template <typename VarType>
static void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
for (auto& pair : outs) {
...@@ -152,6 +137,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
}
}
#endif
// NOTE(zhiqiu): for kernels on given device, for example NPU, the order to
// choose is:
// pten npu kernel > fluid npu kernel > pten cpu kernel > fluid cpu kernel
// 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
...@@ -159,13 +147,15 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
if (FLAGS_run_pten_kernel && framework::KernelSignature pt_kernel_signature;
pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { pten::KernelKey pt_kernel_key;
auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx); std::string pt_kernel_name;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
VLOG(6) << pt_kernel_signature; VLOG(6) << pt_kernel_signature;
auto pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = pt_kernel_signature.name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
auto pt_kernel = pten::KernelFactory::Instance().SelectKernel( auto pt_kernel = pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key); pt_kernel_name, pt_kernel_key);
...@@ -191,14 +181,42 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
// 2. check if op[type] has kernel registered.
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
if ((kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end())
#ifdef PADDLE_WITH_XPU
||
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type())
#endif
) {
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_cpu_kernel, cpu_ctx);
}
}
}
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(), kernels_iter, all_op_kernels.end(),
platform::errors::NotFound( platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.", "There are no kernels which are registered in the %s operator.",
op.Type())); op.Type()));
auto& kernels = kernels_iter->second; auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() || (kernel_iter == kernels.end() ||
...@@ -264,237 +282,6 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, ...@@ -264,237 +282,6 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
default_attrs); default_attrs);
} }
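The fallback above follows the order spelled out in the NOTE: a pten kernel on the expected place is preferred, then a fluid kernel on that place, and only if neither is registered (or the op is unsupported or blacklisted on XPU) does PrepareImpl retry with a pten CPU kernel before the remaining fluid CPU path. The following stand-alone sketch models that ordering with simplified stand-in types; Place, Key, KernelTable, Select and Prepare are hypothetical, not the real KernelFactory interfaces.

// Simplified model of the kernel fallback order; all types here are
// illustrative stand-ins, not the real Paddle classes.
#include <iostream>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>

enum class Place { kDevice, kCPU };

struct Key {
  Place place;
  bool operator<(const Key& o) const { return place < o.place; }
};

using KernelTable = std::map<std::string, std::map<Key, std::string>>;

std::optional<std::string> Select(const KernelTable& table,
                                  const std::string& op, Key key) {
  auto it = table.find(op);
  if (it == table.end()) return std::nullopt;
  auto kit = it->second.find(key);
  if (kit == it->second.end()) return std::nullopt;
  return kit->second;
}

std::string Prepare(const KernelTable& pten_kernels,
                    const KernelTable& fluid_kernels, const std::string& op,
                    Place expected) {
  // pten kernel on the expected place > fluid kernel on the expected place
  // > pten CPU kernel > fluid CPU kernel.
  for (Place p : {expected, Place::kCPU}) {
    if (auto k = Select(pten_kernels, op, Key{p})) return *k;
    if (auto k = Select(fluid_kernels, op, Key{p})) return *k;
  }
  throw std::runtime_error("no kernel registered for " + op);
}

int main() {
  KernelTable pten_kernels{{"cast", {{Key{Place::kCPU}, "pten cast cpu"}}}};
  KernelTable fluid_kernels;
  // No device kernel is registered, so the call falls back to the pten CPU one.
  std::cout << Prepare(pten_kernels, fluid_kernels, "cast", Place::kDevice)
            << "\n";
}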
template <typename VarType>
void PreparePtenData(const pten::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) {
auto& input_names = std::get<0>(pt_kernel_signature.args);
auto& input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
"the size of inputs_args names (%d) must be equal to "
"the size of kernel input_defs (%d).",
input_names.size(), input_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) {
auto& in_def = input_defs.at(i);
auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto var_base = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var_base->Var());
if (tensor_in && tensor_in->IsInitialized()) {
auto expected_place = pten::TransToFluidPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue;
}
// TODO(zyfncg): Now there is no kernel which need to transform input
// data, so we commented out following code temporarily,
// and it will be used in the future.
// VLOG(3) << "Pten Transform Variable " << var_base->Name() << " from "
// << tensor_in->place() << " to " << expected_place;
// framework::Tensor tmp_tensor;
// framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor);
// SetTensorToVariable(var_base->Var(), tmp_tensor,
// var_base->MutableVar());
}
}
}
}
template <typename VarType>
static void BuildDygraphPtenKernelContext(
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) {
kernel_ctx->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature.args);
auto& attr_names = std::get<1>(pt_kernel_signature.args);
auto& output_names = std::get<2>(pt_kernel_signature.args);
auto& input_defs = pt_kernel.args_def().input_defs();
auto& output_defs = pt_kernel.args_def().output_defs();
auto& attr_defs = pt_kernel.args_def().attribute_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
"the size of inputs_args names (%d) must be equal to "
"the size of kernel input_defs (%d).",
input_names.size(), input_defs.size()));
PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(),
platform::errors::InvalidArgument(
"the size of outputs_args names (%d) must be equal to "
"the size of kernel output_defs (%d).",
output_names.size(), output_defs.size()));
PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(),
platform::errors::InvalidArgument(
"the size of attribute_args names (%d) must be equal "
"to the size of kernel attribute_defs (%d).",
attr_names.size(), attr_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) {
auto& ins_vector = ins.at(input_names[i]);
size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
const auto* tensor_in = GetTensorFromVar(ins_vector[offset]->Var());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
}
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
for (size_t i = 0; i < output_names.size(); ++i) {
size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
auto iter = outs.find(output_names[i]);
if (iter == outs.end()) {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
i);
continue;
}
auto& outs_vector = iter->second;
size_t end_idx = start_idx + outs_vector.size();
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (outs_vector[offset] == nullptr) {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
continue;
}
auto* var = outs_vector[offset]->MutableVar();
framework::Tensor* tensor_out = nullptr;
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
} // TODO(zyfncg): Add support for SelectedRows
experimental::ResetTensorByArgDef(tensor_out, output_defs.at(i));
framework::SetAllocationForOutputTenosr(
tensor_out, pten::TransToFluidPlace(output_defs.at(i).backend));
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}
for (size_t i = 0; i < attr_names.size(); ++i) {
if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) {
if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to VectorTensor when "
"construct KernelContext.",
attr_names[i]));
}
} else { // shape is in the input
auto& ins_vector = ins.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var())));
} else { // ShapeTensorList
std::vector<framework::Variable*> variables;
variables.reserve(ins_vector.size());
for (const auto& var_base : ins_vector) {
variables.push_back(var_base->MutableVar());
}
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(variables)));
}
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (attrs.find(attr_names[i]) != attrs.end() ||
default_attrs.find(attr_names[i]) !=
default_attrs.end()) { // scalar is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(int, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(ins_vector[0]->Var())));
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::DataType))) {
auto data_type = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
kernel_ctx->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
kernel_ctx->EmplaceBackAttr(vector_int64_attr);
}
// TODO(YuanRisheng) Need support vector<int64_t> attr
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
}
}
}
template <typename VarType> template <typename VarType>
static void PreparedOpRunImpl( static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
......
...@@ -194,5 +194,247 @@ class PreparedOp { ...@@ -194,5 +194,247 @@ class PreparedOp {
pten::Kernel pt_kernel_; pten::Kernel pt_kernel_;
}; };
const inline framework::Attribute& GetAttr(
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const std::string& name) {
auto it = attrs.find(name);
bool found = it != attrs.end();
if (!found) {
it = default_attrs.find(name);
found = it != default_attrs.end();
}
PADDLE_ENFORCE_EQ(
found, true,
platform::errors::NotFound("(%s) is not found in AttributeMap.", name));
return it->second;
}
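GetAttr resolves a name against the op's own attributes first and falls back to the default attributes, failing only when both lookups miss. A minimal self-contained illustration of that two-level lookup, with plain std::string values standing in for framework::Attribute, might look like this:

// Hypothetical stand-alone version of the attrs -> default_attrs lookup.
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

using AttributeMap = std::map<std::string, std::string>;

const std::string& GetAttr(const AttributeMap& attrs,
                           const AttributeMap& default_attrs,
                           const std::string& name) {
  auto it = attrs.find(name);
  if (it == attrs.end()) {
    it = default_attrs.find(name);
    if (it == default_attrs.end()) {
      throw std::out_of_range("(" + name + ") is not found in AttributeMap.");
    }
  }
  return it->second;
}

int main() {
  AttributeMap attrs{{"axis", "1"}};
  AttributeMap defaults{{"axis", "0"}, {"dtype", "float32"}};
  assert(GetAttr(attrs, defaults, "axis") == "1");        // explicit value wins
  assert(GetAttr(attrs, defaults, "dtype") == "float32"); // falls back to default
}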
template <typename VarType>
void BuildDygraphPtenKernelContext(
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) {
kernel_ctx->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature.args);
auto& attr_names = std::get<1>(pt_kernel_signature.args);
auto& output_names = std::get<2>(pt_kernel_signature.args);
auto& input_defs = pt_kernel.args_def().input_defs();
auto& output_defs = pt_kernel.args_def().output_defs();
auto& attr_defs = pt_kernel.args_def().attribute_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
"the size of inputs_args names (%d) must be equal to "
"the size of kernel input_defs (%d).",
input_names.size(), input_defs.size()));
PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(),
platform::errors::InvalidArgument(
"the size of outputs_args names (%d) must be equal to "
"the size of kernel output_defs (%d).",
output_names.size(), output_defs.size()));
PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(),
platform::errors::InvalidArgument(
"the size of attribute_args names (%d) must be equal "
"to the size of kernel attribute_defs (%d).",
attr_names.size(), attr_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) {
auto& ins_vector = ins.at(input_names[i]);
size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
const auto* tensor_in = GetTensorFromVar(ins_vector[offset]->Var());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
}
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
for (size_t i = 0; i < output_names.size(); ++i) {
size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
auto iter = outs.find(output_names[i]);
if (iter == outs.end()) {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
i);
continue;
}
auto& outs_vector = iter->second;
size_t end_idx = start_idx + outs_vector.size();
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (outs_vector[offset] == nullptr) {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
continue;
}
auto* var = outs_vector[offset]->MutableVar();
framework::Tensor* tensor_out = nullptr;
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
} // TODO(zyfncg): Add support for SelectedRows
experimental::ResetTensorByArgDef(tensor_out, output_defs.at(i));
framework::SetAllocationForOutputTenosr(
tensor_out, pten::TransToFluidPlace(output_defs.at(i).backend));
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}
for (size_t i = 0; i < attr_names.size(); ++i) {
if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) {
if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to VectorTensor when "
"construct KernelContext.",
attr_names[i]));
}
} else { // shape is in the input
auto& ins_vector = ins.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var())));
} else { // ShapeTensorList
std::vector<framework::Variable*> variables;
variables.reserve(ins_vector.size());
for (const auto& var_base : ins_vector) {
variables.push_back(var_base->MutableVar());
}
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(variables)));
}
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (attrs.find(attr_names[i]) != attrs.end() ||
default_attrs.find(attr_names[i]) !=
default_attrs.end()) { // scalar is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(int, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(ins_vector[0]->Var())));
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::DataType))) {
auto data_type = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
kernel_ctx->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
kernel_ctx->EmplaceBackAttr(vector_int64_attr);
}
// TODO(YuanRisheng) Need support vector<int64_t> attr
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
}
}
}
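The attribute loop above dispatches on the kernel's declared attribute type (attr_defs[i].type_index) and, for the ScalarArray and vector cases, also on the runtime type of the stored attribute, widening vector<int> to vector<int64_t> when the kernel expects 64-bit values. A reduced stand-alone sketch of that dispatch pattern, using std::any in place of the framework Attribute variant, is:

// Stand-alone sketch of type_index-based attribute dispatch; Attribute and
// EmplaceAttr are simplified stand-ins for the framework types.
#include <any>
#include <cstdint>
#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <vector>

using Attribute = std::any;

void EmplaceAttr(std::type_index expected, const Attribute& attr) {
  if (expected == std::type_index(typeid(int))) {
    std::cout << "int attr: " << std::any_cast<int>(attr) << "\n";
  } else if (expected == std::type_index(typeid(std::vector<int64_t>))) {
    // The kernel declares int64 elements but ops commonly store int32,
    // so widen element-wise, as the last branch above does.
    if (std::type_index(attr.type()) ==
        std::type_index(typeid(std::vector<int>))) {
      const auto& v32 = std::any_cast<const std::vector<int>&>(attr);
      std::vector<int64_t> v64(v32.begin(), v32.end());
      std::cout << "vector<int64_t> attr of size " << v64.size() << "\n";
    }
  } else {
    std::cout << "unsupported attribute type\n";
  }
}

int main() {
  EmplaceAttr(std::type_index(typeid(int)), Attribute(3));
  EmplaceAttr(std::type_index(typeid(std::vector<int64_t>)),
              Attribute(std::vector<int>{2, 3, 4}));
}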
template <typename VarType>
void PreparePtenData(const pten::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) {
auto& input_names = std::get<0>(pt_kernel_signature.args);
auto& input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
"the size of inputs_args names (%d) must be equal to "
"the size of kernel input_defs (%d).",
input_names.size(), input_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) {
auto& in_def = input_defs.at(i);
auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto var_base = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var_base->Var());
if (tensor_in && tensor_in->IsInitialized()) {
auto expected_place = pten::TransToFluidPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue;
}
VLOG(3) << "Pten Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place;
framework::Tensor tmp_tensor;
framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor);
SetTensorToVariable(var_base->Var(), tmp_tensor,
var_base->MutableVar());
}
}
}
}
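Unlike the copy of this function removed from the .cc file above, the transform here is active: an initialized input that does not already live on the backend the pten kernel expects is copied synchronously to that place and written back to the variable. A toy version of the same check-then-copy flow follows; Tensor, CopyTo and PrepareData are illustrative stand-ins, not the fluid APIs.

// Toy model of the "transform input to the kernel's expected place" step.
#include <iostream>
#include <string>
#include <vector>

struct Tensor {
  std::string place;
  bool initialized = true;
};

Tensor CopyTo(const Tensor& src, const std::string& place) {
  return Tensor{place, src.initialized};  // models a synchronous copy
}

void PrepareData(std::vector<Tensor>* inputs, const std::string& expected) {
  for (auto& t : *inputs) {
    if (!t.initialized || t.place == expected) continue;  // nothing to do
    std::cout << "transform input from " << t.place << " to " << expected
              << "\n";
    t = CopyTo(t, expected);  // write the transformed tensor back
  }
}

int main() {
  std::vector<Tensor> ins{{"CPU"}, {"GPU:0"}};
  PrepareData(&ins, "GPU:0");  // only the CPU tensor is copied
}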
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -26,6 +26,9 @@ limitations under the License. */ ...@@ -26,6 +26,9 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
// pten
#include "paddle/pten/kernels/declarations.h"
DEFINE_string(devices, "", "The devices to be used which is joined by comma."); DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
DEFINE_int32(math_num_threads, 1, DEFINE_int32(math_num_threads, 1,
"Number of threads used to run math functions."); "Number of threads used to run math functions.");
......
...@@ -115,10 +115,8 @@ if (WITH_GPU OR WITH_ROCM) ...@@ -115,10 +115,8 @@ if (WITH_GPU OR WITH_ROCM)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
op_library(sync_batch_norm_op) op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.3) ) if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.3) )
op_library(sparse_attention_op) op_library(sparse_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n")
endif() endif()
else() else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
...@@ -142,7 +140,6 @@ endif() ...@@ -142,7 +140,6 @@ endif()
if (WITH_ASCEND_CL) if (WITH_ASCEND_CL)
op_library(sync_batch_norm_op) op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(sync_batch_norm);\n")
endif() endif()
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
...@@ -153,7 +150,6 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) ...@@ -153,7 +150,6 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
if (WITH_DGC) if (WITH_DGC)
op_library(dgc_op DEPS dgc) op_library(dgc_op DEPS dgc)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n")
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
endif() endif()
......
...@@ -24,6 +24,9 @@ limitations under the License. */ ...@@ -24,6 +24,9 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
// pten
#include "paddle/pten/kernels/declarations.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace benchmark { namespace benchmark {
......
...@@ -133,23 +133,27 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -133,23 +133,27 @@ class CastOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
#define REGISTER_CAST_CPU_BASE(op_name, ...)                                  \
  REGISTER_OPERATOR(op_name, ops::CastOp,                                     \
                    ops::CastOpGradMaker<paddle::framework::OpDesc>,          \
                    ops::CastOpGradMaker<paddle::imperative::OpBase>,         \
                    ops::CastOpProtoMaker);                                   \
  REGISTER_OP_CPU_KERNEL(                                                     \
      op_name, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>, \
      ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>,           \
      ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int16_t>,           \
      ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>,          \
      ops::CastOpKernel<CPU, paddle::platform::float16>,                      \
      ops::CastOpKernel<CPU, paddle::platform::bfloat16>,                     \
      ops::CastOpKernel<CPU, paddle::platform::complex<float>>,               \
      ops::CastOpKernel<CPU, paddle::platform::complex<double>>);

REGISTER_CAST_CPU_BASE(cast)
// [ why register transfer_dtype_op alias with cast_op? ]
// In case of InterpreterCore, if we reuse cast_op, we cannot distinguish
// which cast_op is inserted by new executor when we do profiling.
REGISTER_CAST_CPU_BASE(transfer_dtype)

// cast use pten kernel, so no need to REGISTER_OP_CPU_KERNEL here.
REGISTER_OPERATOR(cast, ops::CastOp,
                  ops::CastOpGradMaker<paddle::framework::OpDesc>,
                  ops::CastOpGradMaker<paddle::imperative::OpBase>,
                  ops::CastOpProtoMaker);
// [ why register transfer_dtype_op alias with cast_op? ]
// In case of InterpreterCore, if we reuse cast_op, we cannot distinguish
// which cast_op is inserted by new executor when we do profiling.
REGISTER_OPERATOR(transfer_dtype, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
transfer_dtype, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>, ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>, ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int16_t>, ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
...@@ -30,10 +30,8 @@ using CUDA = paddle::platform::CUDADeviceContext; ...@@ -30,10 +30,8 @@ using CUDA = paddle::platform::CUDADeviceContext;
ops::CastOpKernel<CUDA, plat::complex<double>>, ##__VA_ARGS__); ops::CastOpKernel<CUDA, plat::complex<double>>, ##__VA_ARGS__);
#if !defined(PADDLE_WITH_HIP) #if !defined(PADDLE_WITH_HIP)
REGISTER_CAST_CUDA_BASE(cast, ops::CastOpKernel<CUDA, plat::bfloat16>)
// See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc // See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc
REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel<CUDA, plat::bfloat16>) REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel<CUDA, plat::bfloat16>)
#else #else
REGISTER_CAST_CUDA_BASE(cast)
REGISTER_CAST_CUDA_BASE(transfer_dtype) REGISTER_CAST_CUDA_BASE(transfer_dtype)
#endif #endif
...@@ -22,9 +22,3 @@ endif() ...@@ -22,9 +22,3 @@ endif()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n") file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n")
file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n")
if(WITH_XPU)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(equal, XPU);\nUSE_OP_DEVICE_KERNEL(not_equal, XPU);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(less_than, XPU);\nUSE_OP_DEVICE_KERNEL(less_equal, XPU);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(greater_than, XPU);\nUSE_OP_DEVICE_KERNEL(greater_equal, XPU);\n")
endif()
...@@ -8,7 +8,20 @@ function(detection_library TARGET_NAME) ...@@ -8,7 +8,20 @@ function(detection_library TARGET_NAME)
set(pybind_flag 0) set(pybind_flag 0)
cmake_parse_arguments(detection_library "${options}" "${oneValueArgs}" cmake_parse_arguments(detection_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN}) "${multiValueArgs}" ${ARGN})
  op_library(${TARGET_NAME} SRCS ${detection_library_SRCS} DEPS ${common_deps} ${detection_library_DEPS})
  set(srcs)
# filter cuda source file when not build with cuda/rocm
foreach(src ${detection_library_SRCS})
if (NOT WITH_GPU AND NOT WITH_ROCM)
if(${src} MATCHES ".*\\.cc$")
list(APPEND srcs ${src})
endif()
else()
list(APPEND srcs ${src})
endif()
endforeach()
op_library(${TARGET_NAME} SRCS ${srcs} DEPS ${common_deps} ${detection_library_DEPS})
set(LOCAL_DETECTION_LIBS set(LOCAL_DETECTION_LIBS
${TARGET_NAME} ${TARGET_NAME}
${LOCAL_DETECTION_LIBS} ${LOCAL_DETECTION_LIBS}
......
...@@ -24,8 +24,6 @@ register_operators(EXCLUDES ...@@ -24,8 +24,6 @@ register_operators(EXCLUDES
# fusion_gru_op does not have CUDA kernel # fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op) op_library(fusion_gru_op)
op_library(fusion_lstm_op) op_library(fusion_lstm_op)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n")
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_lstm);\n")
if (WITH_GPU OR WITH_ROCM) if (WITH_GPU OR WITH_ROCM)
...@@ -33,46 +31,36 @@ if (WITH_GPU OR WITH_ROCM) ...@@ -33,46 +31,36 @@ if (WITH_GPU OR WITH_ROCM)
# HIP not support bn act fuse in MIOPEN # HIP not support bn act fuse in MIOPEN
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401)) if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401))
op_library(fused_bn_activation_op) op_library(fused_bn_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
endif() endif()
# conv_fusion_op needs cudnn 7 above # conv_fusion_op needs cudnn 7 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100) if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
op_library(conv_fusion_op) op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif() endif()
# fusion_transpose_flatten_concat_op # fusion_transpose_flatten_concat_op
# HIP not support cudnnTransformTensor # HIP not support cudnnTransformTensor
if(NOT WITH_ROCM) if(NOT WITH_ROCM)
op_library(fusion_transpose_flatten_concat_op) op_library(fusion_transpose_flatten_concat_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
endif() endif()
# fusion_conv_inception_op needs cudnn 7 above # fusion_conv_inception_op needs cudnn 7 above
# HIP not support cudnnConvolutionBiasActivationForward # HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100)) if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
op_library(fusion_conv_inception_op) op_library(fusion_conv_inception_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n")
endif() endif()
# fused_fc_elementwise_layernorm_op # fused_fc_elementwise_layernorm_op
op_library(fused_fc_elementwise_layernorm_op) op_library(fused_fc_elementwise_layernorm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_fc_elementwise_layernorm);\n")
# multihead_matmul_op # multihead_matmul_op
op_library(multihead_matmul_op) op_library(multihead_matmul_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
op_library(skip_layernorm_op) op_library(skip_layernorm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(skip_layernorm);\n")
op_library(fused_embedding_eltwise_layernorm_op) op_library(fused_embedding_eltwise_layernorm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n")
# fusion_group # fusion_group
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
op_library(fusion_group_op DEPS device_code) op_library(fusion_group_op DEPS device_code)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_group);\n")
cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op) cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
endif() endif()
# fused_bn_add_activation # fused_bn_add_activation
# HIP not support bn act fuse in MIOPEN # HIP not support bn act fuse in MIOPEN
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401)) if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401))
op_library(fused_bn_add_activation_op) op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
endif() endif()
# fused_dropout # fused_dropout
# only support CUDA # only support CUDA
...@@ -82,15 +70,12 @@ if (WITH_GPU OR WITH_ROCM) ...@@ -82,15 +70,12 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
op_library(fused_feedforward_op) op_library(fused_feedforward_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_feedforward);\n")
# fused_attention_op # fused_attention_op
op_library(fused_attention_op) op_library(fused_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
endif() endif()
# resnet_unit needs cudnn 8.0 above # resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
op_library(resnet_unit_op) op_library(resnet_unit_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory) cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory) cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
endif() endif()
......
...@@ -103,39 +103,69 @@ element of X as a tensor. ...@@ -103,39 +103,69 @@ element of X as a tensor.
namespace ops = paddle::operators; namespace ops = paddle::operators;
#define REGISTER_V2OP_MAKER(op_type, comment) \ #define REGISTER_V2OP_MAKER(op_type, comment) \
namespace paddle { \ namespace paddle { \
namespace operators { \ namespace operators { \
class _##op_type##OverflowV2OpMaker \ class _##op_type##OverflowV2OpMaker \
: public ::paddle::operators::OverflowV2OpMaker { \ : public ::paddle::operators::OverflowV2OpMaker { \
protected: \ protected: \
std::string GetName() const { return #op_type; } \ std::string GetName() const { return #op_type; } \
std::string GetComments() const { return comment; } \ std::string GetComments() const { return comment; } \
}; \ }; \
} \ } \
  }                                                                     \
  REGISTER_OPERATOR(                                                    \
      op_type, ops::OverflowV2Op, ops::_##op_type##OverflowV2OpMaker,   \
      paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,   \
      paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
  }

REGISTER_V2OP_MAKER(isinf_v2, "isinfv2(X)")
REGISTER_V2OP_MAKER(isnan_v2, "isnanv2(X)")
#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \
REGISTER_OP_CPU_KERNEL( \
op_type, ops::OverflowKernel<paddle::platform::CPUDeviceContext, int, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CPUDeviceContext, int64_t, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CPUDeviceContext, float, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CPUDeviceContext, double, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CPUDeviceContext, plat::float16, \
ops::functor>);
REGISTER_V2OP_MAKER(isinf_v2, "isinfv2(X)");
REGISTER_V2OP_MAKER(isnan_v2, "isnanv2(X)");
REGISTER_V2OP_MAKER(isfinite_v2, "isfinitev2(X)");

REGISTER_OVERFLOW_CPU_KERNEL(isinf_v2, InfinityV2Functor);
REGISTER_OVERFLOW_CPU_KERNEL(isnan_v2, NANV2Functor);
REGISTER_OVERFLOW_CPU_KERNEL(isfinite_v2, IsfiniteV2Functor);

REGISTER_OPERATOR(
    isinf_v2, ops::OverflowV2Op, ops::_isinf_v2OverflowV2OpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
isnan_v2, ops::OverflowV2Op, ops::_isnan_v2OverflowV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
isfinite_v2, ops::OverflowV2Op, ops::_isfinite_v2OverflowV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(isnan_v2,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::NANV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::NANV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::NANV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::NANV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
plat::float16, ops::NANV2Functor>);
REGISTER_OP_CPU_KERNEL(
isinf_v2, ops::OverflowKernel<paddle::platform::CPUDeviceContext, int,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, float,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, double,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, plat::float16,
ops::InfinityV2Functor>);
REGISTER_OP_CPU_KERNEL(
isfinite_v2, ops::OverflowKernel<paddle::platform::CPUDeviceContext, int,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, float,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, double,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext, plat::float16,
ops::IsfiniteV2Functor>);
...@@ -18,19 +18,38 @@ ...@@ -18,19 +18,38 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor)                       \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      op_type, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,  \
                                   ops::functor>,                             \
      ops::OverflowKernel<paddle::platform::CUDADeviceContext, int64_t,       \
                          ops::functor>,                                      \
      ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,         \
                          ops::functor>,                                      \
      ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,        \
                          ops::functor>,                                      \
      ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16, \
                          ops::functor>);

REGISTER_OVERFLOW_CUDA_KERNEL(isinf_v2, InfinityV2Functor);
REGISTER_OVERFLOW_CUDA_KERNEL(isnan_v2, NANV2Functor);
REGISTER_OVERFLOW_CUDA_KERNEL(isfinite_v2, IsfiniteV2Functor);

REGISTER_OP_CUDA_KERNEL(isnan_v2,
                        ops::OverflowKernel<paddle::platform::CUDADeviceContext,
                                            int, ops::NANV2Functor>,
                        ops::OverflowKernel<paddle::platform::CUDADeviceContext,
                                            int64_t, ops::NANV2Functor>,
                        ops::OverflowKernel<paddle::platform::CUDADeviceContext,
                                            float, ops::NANV2Functor>,
                        ops::OverflowKernel<paddle::platform::CUDADeviceContext,
                                            double, ops::NANV2Functor>,
                        ops::OverflowKernel<paddle::platform::CUDADeviceContext,
                                            plat::float16, ops::NANV2Functor>);

REGISTER_OP_CUDA_KERNEL(
    isinf_v2, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
                                  ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::InfinityV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::InfinityV2Functor>);
REGISTER_OP_CUDA_KERNEL(
isfinite_v2, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::IsfiniteV2Functor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::IsfiniteV2Functor>);
...@@ -12,7 +12,6 @@ endif() ...@@ -12,7 +12,6 @@ endif()
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
op_library(nccl_op DEPS nccl_common) op_library(nccl_op DEPS nccl_common)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n")
set(OPERATOR_DEPS ${OPERATOR_DEPS} nccl_common PARENT_SCOPE) set(OPERATOR_DEPS ${OPERATOR_DEPS} nccl_common PARENT_SCOPE)
endif() endif()
......
...@@ -13,24 +13,6 @@ else() ...@@ -13,24 +13,6 @@ else()
register_operators() register_operators()
endif() endif()
if(WITH_GPU OR WITH_ROCM)
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.part.cu")
string(REPLACE ".part.cu" "" OPS "${OPS}")
foreach(src ${OPS})
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu)
set(CUDA_KERNEL_FILE ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu)
file(READ ${CUDA_KERNEL_FILE} TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OP_CUDA_KERNEL\\(\\n?([^,]+),.*" MATCHED ${TARGET_CONTENT})
if (MATCHED)
string(STRIP ${CMAKE_MATCH_1} MATCHED)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MATCHED}, CUDA);\n")
endif()
endif()
endforeach()
endif()
if(WITH_GPU) if(WITH_GPU)
if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor cub) nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor cub)
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sign_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/infermeta_utils.h" #include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
...@@ -65,13 +65,3 @@ REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>, ...@@ -65,13 +65,3 @@ REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
ops::SignGradMaker<paddle::framework::OpDesc>, ops::SignGradMaker<paddle::framework::OpDesc>,
ops::SignGradMaker<paddle::imperative::OpBase>, ops::SignGradMaker<paddle::imperative::OpBase>,
SignInferShapeFunctor); SignInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
sign, ops::SignKernel<paddle::platform::CPUDeviceContext, float>,
ops::SignKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
sign,
paddle::operators::SignKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::SignKernel<paddle::platform::CUDADeviceContext, double>,
paddle::operators::SignKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
...@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/sign_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter infer_io_utils analysis_helper) op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter infer_io_utils analysis_helper)
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(tensorrt_engine);\n")
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op DEPS tensorrt_engine_op
analysis) analysis)
...@@ -4,7 +4,8 @@ endif() ...@@ -4,7 +4,8 @@ endif()
set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl) set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl)
cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib device_context place pten_xpu_info) cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib device_context place pten_xpu_info)
cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context)
cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context op_kernel_type)
add_subdirectory(tests) add_subdirectory(tests)
...@@ -32,6 +32,9 @@ ...@@ -32,6 +32,9 @@
#endif #endif
#include "paddle/fluid/pybind/op_function_generator.h" #include "paddle/fluid/pybind/op_function_generator.h"
// pten
#include "paddle/pten/kernels/declarations.h"
// clang-format off // clang-format off
const char* OUT_INITIALIZER_TEMPLATE = const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"; R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
......
...@@ -32,6 +32,9 @@ ...@@ -32,6 +32,9 @@
#include "paddle/fluid/framework/fleet/ascend_wrapper.h" #include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif #endif
// pten
#include "paddle/pten/kernels/declarations.h"
// NOTE(pangyoki): Inplace OP with duplicable input. // NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input. // The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy // The first Varbase in input needs to be specified for the inplace strategy
...@@ -395,7 +398,7 @@ GenerateOpFunctions() { ...@@ -395,7 +398,7 @@ GenerateOpFunctions() {
continue; continue;
} }
auto& op_type = op_proto->type(); auto& op_type = op_proto->type();
// Skip ooerator which is not inherit form OperatorWithKernel, like while, // Skip operator which is not inherit form OperatorWithKernel, like while,
// since only OperatorWithKernel can run in dygraph mode. // since only OperatorWithKernel can run in dygraph mode.
// if the pten lib contains op kernel, we still generate ops method // if the pten lib contains op kernel, we still generate ops method
if (!all_kernels.count(op_type) && if (!all_kernels.count(op_type) &&
......
...@@ -208,14 +208,14 @@ class Kernel { ...@@ -208,14 +208,14 @@ class Kernel {
*/ */
class KernelFactory { class KernelFactory {
public: public:
  // replaced by paddle::flat_hash_map later
  using KernelMap = paddle::flat_hash_map<
      std::string,
      paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>>;
  using KernelKeyMap =
      paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;

  using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;

  static KernelFactory& Instance();

  KernelMap& kernels() { return kernels_; }
  KernelNameMap& kernels() { return kernels_; }
bool HasCompatiblePtenKernel(const std::string& op_type) const { bool HasCompatiblePtenKernel(const std::string& op_type) const {
return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end(); return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end();
...@@ -232,13 +232,12 @@ class KernelFactory { ...@@ -232,13 +232,12 @@ class KernelFactory {
Kernel SelectKernel(const std::string& kernel_name, Kernel SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
  paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash> SelectKernelMap(
      const std::string& kernel_name) const;
  KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
private: private:
KernelFactory() = default; KernelFactory() = default;
  KernelMap kernels_;
  KernelNameMap kernels_;
}; };
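With these aliases the registry reads as a two-level map, kernel name -> (KernelKey -> Kernel), so SelectKernel as used in PrepareImpl is essentially two find calls. A minimal stand-alone equivalent, with an int standing in for KernelKey and a string for Kernel, is:

// Simplified two-level registry mirroring KernelNameMap/KernelKeyMap;
// the key and kernel types are reduced to int and std::string for illustration.
#include <iostream>
#include <string>
#include <unordered_map>

using KernelKey = int;
using Kernel = std::string;
using KernelKeyMap = std::unordered_map<KernelKey, Kernel>;
using KernelNameMap = std::unordered_map<std::string, KernelKeyMap>;

Kernel SelectKernel(const KernelNameMap& kernels, const std::string& name,
                    KernelKey key) {
  auto it = kernels.find(name);
  if (it == kernels.end()) return Kernel{};  // empty kernel plays the !IsValid() role
  auto kit = it->second.find(key);
  return kit == it->second.end() ? Kernel{} : kit->second;
}

int main() {
  KernelNameMap kernels{
      {"sign", {{0, "sign CPU kernel"}, {1, "sign GPU kernel"}}}};
  std::cout << SelectKernel(kernels, "sign", 1) << "\n";          // found
  std::cout << SelectKernel(kernels, "sign", 2).empty() << "\n";  // falls through
}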
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
......
...@@ -23,6 +23,9 @@ limitations under the License. */ ...@@ -23,6 +23,9 @@ limitations under the License. */
#include <boost/variant.hpp> #include <boost/variant.hpp>
namespace egr {
class EagerTensor;
}
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// The order should be as same as framework.proto // The order should be as same as framework.proto
...@@ -71,6 +74,13 @@ template <> ...@@ -71,6 +74,13 @@ template <>
struct NameVarMapTrait<VariableWrapper> { struct NameVarMapTrait<VariableWrapper> {
using Type = std::map<std::string, SavedVariableWrapperList>; using Type = std::map<std::string, SavedVariableWrapperList>;
}; };
template <>
struct NameVarMapTrait<egr::EagerTensor> {
using Type =
std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>>;
};
} // namespace details } // namespace details
template <typename T> template <typename T>
...@@ -78,6 +88,7 @@ using NameVarMap = typename details::NameVarMapTrait<T>::Type; ...@@ -78,6 +88,7 @@ using NameVarMap = typename details::NameVarMapTrait<T>::Type;
using NameVarBaseMap = NameVarMap<VarBase>; using NameVarBaseMap = NameVarMap<VarBase>;
using NameVariableWrapperMap = NameVarMap<VariableWrapper>; using NameVariableWrapperMap = NameVarMap<VariableWrapper>;
using NameTensorMap = NameVarMap<egr::EagerTensor>;
using VariableWrapperList = std::vector<std::shared_ptr<VariableWrapper>>; using VariableWrapperList = std::vector<std::shared_ptr<VariableWrapper>>;
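NameVarMap<T> simply resolves to NameVarMapTrait<T>::Type, so the new specialization makes NameTensorMap a map from parameter name to a vector of shared EagerTensor pointers. A reduced model of the trait mechanism, with a placeholder EagerTensorLike type instead of the real egr::EagerTensor, is:

// Reduced model of the NameVarMapTrait pattern; EagerTensorLike is a
// placeholder type, not the real framework class.
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

struct EagerTensorLike {};

template <typename T>
struct NameVarMapTrait {
  using Type = std::map<std::string, std::vector<std::shared_ptr<T>>>;
};

// The alias resolves through the trait, so a specialization (as done for
// VariableWrapper and EagerTensor) can swap in a different container type.
template <typename T>
using NameVarMap = typename NameVarMapTrait<T>::Type;

using NameTensorMap = NameVarMap<EagerTensorLike>;

static_assert(
    std::is_same<NameTensorMap,
                 std::map<std::string,
                          std::vector<std::shared_ptr<EagerTensorLike>>>>::value,
    "NameTensorMap maps parameter names to vectors of shared tensor pointers");

int main() {
  NameTensorMap outs;
  outs["Out"].push_back(std::make_shared<EagerTensorLike>());
  return outs.size() == 1 ? 0 : 1;
}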
......
set(kernel_declare_file ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h.tmp CACHE INTERNAL "declarations.h file") set(kernel_declare_file ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h.tmp CACHE INTERNAL "declarations.h file")
set(kernel_declare_file_final ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h) set(kernel_declare_file_final ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h)
file(WRITE ${kernel_declare_file} "// Generated by the paddle/pten/kernels/CMakeLists.txt. DO NOT EDIT!\n\n#pragma once\n\n") file(WRITE ${kernel_declare_file} "// Generated by the paddle/pten/kernels/CMakeLists.txt. DO NOT EDIT!\n\n#pragma once\n\n")
file(APPEND ${kernel_declare_file} "#include \"paddle/pten/core/kernel_registry.h\"\n\n")
# pten functors and functions called by kernels # pten functors and functions called by kernels
add_subdirectory(funcs) add_subdirectory(funcs)
......