diff --git a/cmake/operators.cmake b/cmake/operators.cmake new file mode 100644 index 0000000000000000000000000000000000000000..3d8a6aa23e676411093f775ed516e8ece5647580 --- /dev/null +++ b/cmake/operators.cmake @@ -0,0 +1,219 @@ +set(PART_CUDA_KERNEL_FILES) +function(op_library TARGET) + # op_library is a function to create op library. The interface is same as + # cc_library. But it handle split GPU/CPU code and link some common library + # for ops. + set(cc_srcs) + set(cu_srcs) + set(hip_cu_srcs) + set(miopen_hip_cc_srcs) + set(cu_cc_srcs) + set(cudnn_cu_cc_srcs) + set(CUDNN_FILE) + set(mkldnn_cc_srcs) + set(MKLDNN_FILE) + set(op_common_deps operator op_registry math_function) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS) + set(pybind_flag 0) + cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + list(LENGTH op_library_SRCS op_library_SRCS_len) + if (${op_library_SRCS_len} EQUAL 0) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) + list(APPEND cc_srcs ${TARGET}.cc) + endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) + list(APPEND cu_cc_srcs ${TARGET}.cu.cc) + endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) + list(APPEND cu_srcs ${TARGET}.cu) + endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu) + set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu + ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE) + list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu) + endif() + + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu) + list(APPEND hip_cu_srcs ${TARGET}.hip.cu) + endif() + string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) + list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) + endif() + if(WITH_AMD_GPU) + string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc) + list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc) + endif() + endif() + if(WITH_MKLDNN) + string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc) + list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc) + endif() + endif() + else() + foreach(src ${op_library_SRCS}) + if (${src} MATCHES ".*\\.hip.cu$") + list(APPEND hip_cu_srcs ${src}) + elseif (${src} MATCHES ".*\\.cu$") + list(APPEND cu_srcs ${src}) + elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") + list(APPEND cudnn_cu_cc_srcs ${src}) + elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$") + list(APPEND miopen_hip_cc_srcs ${src}) + elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") + list(APPEND mkldnn_cc_srcs ${src}) + elseif(${src} MATCHES ".*\\.cu.cc$") + list(APPEND cu_cc_srcs ${src}) + elseif(${src} MATCHES ".*\\.cc$") + list(APPEND cc_srcs ${src}) + else() + message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu") + endif() + endforeach() + endif() + + list(LENGTH cc_srcs cc_srcs_len) + if (${cc_srcs_len} EQUAL 0) + message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file") + endif() + if (WIN32) + # remove windows unsupported op, because windows has no nccl, no warpctc such ops. + foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op" + "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op" + "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op") + if ("${TARGET}" STREQUAL "${windows_unsupport_op}") + return() + endif() + endforeach() + endif(WIN32) + set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs") + + list(LENGTH op_library_DEPS op_library_DEPS_len) + if (${op_library_DEPS_len} GREATER 0) + set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) + endif() + if (WITH_GPU) + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + ${op_common_deps}) + elseif (WITH_AMD_GPU) + hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} + ${op_common_deps}) + else() + cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} + ${op_common_deps}) + endif() + + # Define operators that don't need pybind here. + foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" +"tensor_array_read_write_op" "tensorrt_engine_op") + if ("${TARGET}" STREQUAL "${manual_pybind_op}") + set(pybind_flag 1) + endif() + endforeach() + + # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. + # Note that it's enough to just adding one operator to pybind in a *_op.cc file. + # And for detail pybind information, please see generated paddle/pybind/pybind.h. + file(READ ${TARGET}.cc TARGET_CONTENT) + string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}") + string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}") + if (one_register STREQUAL "") + string(REPLACE "_op" "" TARGET "${TARGET}") + else () + string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}") + string(REPLACE "," "" TARGET "${TARGET}") + endif() + + # pybind USE_NO_KERNEL_OP + # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel + string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}") + string(REPLACE "_op" "" TARGET "${TARGET}") + if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") + file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n") + set(pybind_flag 1) + endif() + + # pybind USE_CPU_ONLY_OP + list(LENGTH cu_srcs cu_srcs_len) + list(LENGTH cu_cc_srcs cu_cc_srcs_len) + list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) + list(LENGTH hip_cu_srcs hip_cu_srcs_len) + list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len) + if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND + ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0) + file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") + set(pybind_flag 1) + endif() + + # pybind USE_OP_DEVICE_KERNEL for CUDNN + list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len) + if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") + endif() + + # pybind USE_OP_DEVICE_KERNEL for MIOPEN + if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n") + endif() + + # pybind USE_OP_DEVICE_KERNEL for MKLDNN + if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) + # Append first implemented MKLDNN activation operator + if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") + else() + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") + endif() + endif() + + # pybind USE_OP + if (${pybind_flag} EQUAL 0) + # NOTE(*): activation use macro to regist the kernels, set use_op manually. + if(${TARGET} STREQUAL "activation") + file(APPEND ${pybind_file} "USE_OP(relu);\n") + elseif(${TARGET} STREQUAL "fake_dequantize") + file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") + elseif(${TARGET} STREQUAL "fake_quantize") + file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n") + elseif(${TARGET} STREQUAL "tensorrt_engine_op") + message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") + elseif(${TARGET} STREQUAL "fc") + # HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition + file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") + else() + file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") + endif() + endif() +endfunction() + + +function(register_operators) + set(options "") + set(oneValueArgs "") + set(multiValueArgs EXCLUDES DEPS) + cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") + string(REPLACE "_mkldnn" "" OPS "${OPS}") + string(REPLACE ".cc" "" OPS "${OPS}") + list(REMOVE_DUPLICATES OPS) + list(LENGTH register_operators_DEPS register_operators_DEPS_len) + + foreach(src ${OPS}) + list(FIND register_operators_EXCLUDES ${src} _index) + if (${_index} EQUAL -1) + if (${register_operators_DEPS_len} GREATER 0) + op_library(${src} DEPS ${register_operators_DEPS}) + else() + op_library(${src}) + endif() + endif() + endforeach() +endfunction() diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index da835b33051c310064a906612ae7f9362f95c7d5..da8941c351571a8ff43974321490065079c2c0b4 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -93,11 +93,11 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)) paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)) paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times'], varargs=None, keywords=None, defaults=(0, False)) +paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False)) paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) -paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None)) +paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu index 21e0cb3f91cc0ae05513c3bbd470650ca71194d7..2d2323edc3a6636bec72ea2ae7329ebd4e619348 100644 --- a/paddle/fluid/framework/data_device_transform_test.cu +++ b/paddle/fluid/framework/data_device_transform_test.cu @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/init.h" diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 6dc7d84764ddebccd22dc8a177afdff010098e4b..bd0223007e2dad3a55f2db665955b1053b17ffef 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -15,7 +15,7 @@ cc_library(inference_io SRCS io.cc) # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal? cc_library(paddle_fluid_api SRCS io.cc - DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) diff --git a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc index 233bfd6a42b7f123813d4ef5cecf353f7e88d208..38e9b1c5e7c19c89f94ce55324507b02da0c5160 100644 --- a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc @@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { std::unordered_set teller_set( {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", - "elementwise_add", "dropout", "split"}); + "elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"}); if (!node->IsOp()) return false; if (teller_set.count(node->Op()->Type())) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 76d205b737aeb456f242037f2b375d9c537b39f3..d19505877bbc1110fcf5787fffc1436d242a7cdc 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(split); +USE_TRT_CONVERTER(prelu); +USE_TRT_CONVERTER(conv2d_transpose); #endif diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index e09705e3c69eb2b2370bd1ad2d9cf178ef041ee6..17f6c6d9f10abf99fd93364d1356e2b3ef1b3934 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,4 +1,4 @@ -nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context) +nv_library(tensorrt_engine SRCS engine.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine) add_subdirectory(plugin) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index ed4c398cee518af3211cab4e982082c46ebb36c2..85ad5ffe7875cdc205b5bdff28cc90ef01b236a4 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -2,35 +2,38 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc -pad_op.cc split_op.cc +pad_op.cc split_op.cc prelu_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS - ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter) + ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine mul_op SERIAL) nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine mul_op SERIAL) nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op SERIAL) nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine conv_op conv_transpose_op SERIAL) nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op SERIAL) nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine elementwise_add_op SERIAL) nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine softmax_op SERIAL) nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine batch_norm_op SERIAL) nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine concat_op SERIAL) nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine dropout_op SERIAL) nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pad_op SERIAL) nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin -split_op concat_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin + split_op concat_op SERIAL) +nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin + prelu_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 43950b8c048b4e1b8974956948caa639812b2f78..7900f56c9ce17ffc7c62c85a42c62ba326dea16e 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -18,92 +18,139 @@ namespace paddle { namespace inference { namespace tensorrt { -bool to_skip_merging_optimize(TensorRTEngine* engine_, +bool to_skip_merging_optimize(TensorRTEngine* engine, const std::vector& filters, const std::vector& strides, const std::vector& paddings, std::string input_name) { - if (engine_->itensor_quote_num[input_name] > 0) { + if (engine->itensor_quote_num[input_name] > 0) { return true; } if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) - engine_->itensor_quote_num[input_name] += 1; + engine->itensor_quote_num[input_name] += 1; return false; } +template +void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode, + RegistFunc fadd_layer, SetDilationFunc fset_dilation, + const std::string& name) { + VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias"; + + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight + PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); + + PADDLE_ENFORCE(engine != nullptr); + auto* X = engine->GetITensor(op_desc.Input("Input").front()); + + // Declare weights + auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); + PADDLE_ENFORCE_NOT_NULL(Y_v); + auto* Y_t = Y_v->GetMutable(); + + platform::CPUPlace cpu_place; + std::unique_ptr weight_tensor( + new framework::LoDTensor()); + weight_tensor->Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); + + auto* weight_data = weight_tensor->mutable_data(platform::CPUPlace()); + + PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); + const int n_output = weight_tensor->dims()[0]; + const int n_input = weight_tensor->dims()[1]; + const int filter_h = weight_tensor->dims()[2]; + const int filter_w = weight_tensor->dims()[3]; + const int groups = boost::get(op_desc.GetAttr("groups")); + const std::vector dilations = + boost::get>(op_desc.GetAttr("dilations")); + const std::vector strides = + boost::get>(op_desc.GetAttr("strides")); + const std::vector paddings = + boost::get>(op_desc.GetAttr("paddings")); + + nvinfer1::DimsHW nv_ksize(filter_h, filter_w); + nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); + nvinfer1::DimsHW nv_strides(strides[0], strides[1]); + nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_tensor->numel())}; + + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto* layer = fadd_layer(const_cast(X), n_output, n_input, + nv_ksize, weight, bias); + PADDLE_ENFORCE(layer != nullptr); + layer->setStride(nv_strides); + layer->setPadding(nv_paddings); + layer->setNbGroups(groups); + // set dilations + fset_dilation(layer, nv_dilations); + + auto output_name = op_desc.Output("Output").front(); + layer->setName((name + " (Output: " + output_name + ")").c_str()); + engine->weight_map[op_desc.Input("Filter").front()] = + std::move(weight_tensor); + layer->getOutput(0)->setName(output_name.c_str()); + engine->SetITensor(output_name, layer->getOutput(0)); + + if (test_mode || + to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings, + op_desc.Input("Input").front())) { + engine->DeclareOutput(output_name); + } +} + class Conv2dOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a fluid conv2d op to tensorrt conv layer without bias"; - - framework::OpDesc op_desc(op, nullptr); - PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); - PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight - PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); - - auto* X = engine_->GetITensor(op_desc.Input("Input").front()); - - // Declare weights - auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); - PADDLE_ENFORCE_NOT_NULL(Y_v); - auto* Y_t = Y_v->GetMutable(); - - platform::CPUPlace cpu_place; - std::unique_ptr weight_tensor( - new framework::LoDTensor()); - weight_tensor->Resize(Y_t->dims()); - TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); - - auto* weight_data = - weight_tensor->mutable_data(platform::CPUPlace()); - - PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); - const int n_output = weight_tensor->dims()[0]; - const int filter_h = weight_tensor->dims()[2]; - const int filter_w = weight_tensor->dims()[3]; - - const int groups = boost::get(op_desc.GetAttr("groups")); - const std::vector dilations = - boost::get>(op_desc.GetAttr("dilations")); - const std::vector strides = - boost::get>(op_desc.GetAttr("strides")); - const std::vector paddings = - boost::get>(op_desc.GetAttr("paddings")); - - nvinfer1::DimsHW nv_ksize(filter_h, filter_w); - nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); - nvinfer1::DimsHW nv_strides(strides[0], strides[1]); - nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); - - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - weight_tensor->memory_size() / sizeof(float)}; - - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; - auto* layer = TRT_ENGINE_ADD_LAYER( - engine_, Convolution, *const_cast(X), n_output, - nv_ksize, weight.get(), bias.get()); - PADDLE_ENFORCE(layer != nullptr); - layer->setStride(nv_strides); - layer->setPadding(nv_paddings); - layer->setDilation(nv_dilations); - layer->setNbGroups(groups); - - auto output_name = op_desc.Output("Output").front(); - layer->setName(("conv2d (Output: " + output_name + ")").c_str()); - engine_->weight_map[op_desc.Input("Filter").front()] = - std::move(weight_tensor); - layer->getOutput(0)->setName(output_name.c_str()); - engine_->SetITensor(output_name, layer->getOutput(0)); - - if (test_mode || - to_skip_merging_optimize(engine_, {filter_h, filter_w}, strides, - paddings, op_desc.Input("Input").front())) { - engine_->DeclareOutput(output_name); - } + ConvertConv2d( + engine_, op, scope, test_mode, + [&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */ + int n_input, /* Conv input maps */ + nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* { + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, + ksize, weight.get(), bias.get()); + return layer; + }, + [](nvinfer1::IConvolutionLayer* layer, nvinfer1::DimsHW& dilations) { + layer->setDilation(dilations); + }, + "conv2d"); + } +}; + +class Deconv2dOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + ConvertConv2d( + engine_, op, scope, test_mode, + [&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */ + int n_input, /* Deconv output maps */ + nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* { + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input, + ksize, weight.get(), bias.get()); + return layer; + }, + [](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::DimsHW& dilations) { + PADDLE_ENFORCE( + dilations.d[0] == 1 && dilations.d[1] == 1, + "Dilations must be (1, 1) for tensorRT, but given (%d, %d)", + dilations.d[0], dilations.d[1]); + }, + "conv2d_transpose"); } }; @@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter { } // namespace paddle REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); +REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 671bcd8aa9a9fff34644a056499961cf6ab81287..1af091fabd2aea03a85b2d19fd556b18cdd65e3b 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* X = engine_->GetITensor(op_desc.Input("X").front()); nvinfer1::Dims dims_x = X->getDimensions(); - PADDLE_ENFORCE(dims_x.nbDims >= 3); + PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims experts 3, but %d is given.", + dims_x.nbDims); auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..337885e6baa578d1f733e40f09f0586eba393333 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * PRelu converter from fluid to tensorRT. + */ +class PReluOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid prelu op to tensorrt prelu layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + int input_num = op_desc.Input("X").size(); + PADDLE_ENFORCE(input_num == 1); + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + // Get output + size_t output_num = op_desc.Output("Out").size(); + PADDLE_ENFORCE(output_num == 1); + // Get attrs + std::string mode = boost::get(op_desc.GetAttr("mode")); + // + auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]); + PADDLE_ENFORCE_NOT_NULL(alpha_var); + auto* alpha_tensor = alpha_var->GetMutable(); + + platform::CUDAPlace place; + std::unique_ptr alpha_tensor_device( + new framework::LoDTensor()); + alpha_tensor_device->Resize(alpha_tensor->dims()); + TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get()); + float* alpha_data = alpha_tensor_device->mutable_data(place); + + // Transform alpha to TensorRTEngine::Weight + TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT, + static_cast(alpha_data), + alpha_tensor_device->numel()); + PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode); + nvinfer1::IPluginLayer* layer = + engine_->AddPlugin(&input, input_num, plugin); + // keep alpha tensor to avoid release it's memory + engine_->weight_map[op_desc.Input("Alpha")[0]] = + std::move(alpha_tensor_device); + + std::string layer_name = "prelu (Output: "; + auto output_name = op_desc.Output("Out")[0]; + layer->getOutput(0)->setName(output_name.c_str()); + engine_->SetITensor(output_name, layer->getOutput(0)); + layer_name += output_name; + if (test_mode) { + engine_->DeclareOutput(output_name); + } + layer->setName((layer_name + ")").c_str()); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc index 12179cccc76f8b0f595f41c135290dc0f3b50ad7..159854ab593fbbfa1e08a9ca148f1b3a636d668c 100644 --- a/paddle/fluid/inference/tensorrt/convert/split_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc @@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(40) << "convert a fluid split op to tensorrt split layer"; + VLOG(4) << "convert a fluid split op to tensorrt split layer"; framework::OpDesc op_desc(op, nullptr); // Declare inputs diff --git a/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc index f8711c6b60d74639529624c25429bc245de46479..95916746d6fcb528d26a8f8bb39980b55c4f3704 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" +USE_OP(conv2d); +USE_OP(conv2d_transpose); + namespace paddle { namespace inference { namespace tensorrt { @@ -51,7 +54,37 @@ TEST(conv2d_op, test) { validator.Execute(3); } +TEST(conv2d_transpose_op, test) { + std::unordered_set parameters({"deconv2d-Y"}); + framework::Scope scope; + TRTConvertValidation validator(5, parameters, scope, 1 << 15); + + validator.DeclInputVar("deconv2d-X", nvinfer1::Dims3(3, 5, 5)); + validator.DeclParamVar("deconv2d-Y", nvinfer1::Dims4(3, 2, 3, 3)); + validator.DeclOutputVar("deconv2d-Out", nvinfer1::Dims3(2, 5, 5)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("conv2d_transpose"); + desc.SetInput("Input", {"deconv2d-X"}); + desc.SetInput("Filter", {"deconv2d-Y"}); + desc.SetOutput("Output", {"deconv2d-Out"}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + desc.SetAttr("strides", strides); + desc.SetAttr("paddings", paddings); + desc.SetAttr("dilations", dilations); + desc.SetAttr("groups", groups); + + validator.SetOp(*desc.Proto()); + + validator.Execute(3); +} + } // namespace tensorrt } // namespace inference } // namespace paddle -USE_OP(conv2d); diff --git a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..453f222f1f1e3f3b9ee8fa7bd49f4cab2286e7ea --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(prelu_op, test_channel_wise) { + std::unordered_set parameters({"prelu_alpha"}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1)); + validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("prelu"); + desc.SetInput("X", {"prelu_input"}); + desc.SetInput("Alpha", {"prelu_alpha"}); + desc.SetOutput("Out", {"prelu_out"}); + + desc.SetAttr("mode", std::string("channel")); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +TEST(prelu_op, test_element_wise) { + std::unordered_set parameters({"prelu_alpha"}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2)); + validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("prelu"); + desc.SetInput("X", {"prelu_input"}); + desc.SetInput("Alpha", {"prelu_alpha"}); + desc.SetOutput("Out", {"prelu_out"}); + + desc.SetAttr("mode", std::string("element")); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +TEST(prelu_op, test_scalar) { + std::unordered_set parameters({"prelu_alpha"}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1)); + validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("prelu"); + desc.SetInput("X", {"prelu_input"}); + desc.SetInput("Alpha", {"prelu_alpha"}); + desc.SetOutput("Out", {"prelu_out"}); + + desc.SetAttr("mode", std::string("all")); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +// USE_OP(prelu); +USE_CPU_ONLY_OP(prelu); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index fdd8b56b0ce5c9b5cb6395bcb437aae5ae27829b..208bd12b83aa19f01de9bcf4ada630c87defad5d 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst, Buffer &TensorRTEngine::buffer(const std::string &name) { PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); auto it = buffer_sizes_.find(name); - PADDLE_ENFORCE(it != buffer_sizes_.end()); + PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s", + name); auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); return buffers_[slot_offset]; } diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 335acdf653e55cc7f3ceccdba88992851c8e0310..99420f19ba17d0bebf6dde3800d57c912256dc6b 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -40,6 +40,7 @@ class TensorRTEngine : public EngineBase { // Weight is model parameter. class Weight { public: + Weight() = default; Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) { w_.type = dtype; w_.values = value; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 71b7a551619a43e5300ad3205418d1174c7019ff..6611e2e4b35ee51e58cb6b7b088ad160acb659ad 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1 +1 @@ -nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce) +nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce) diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..0f1ca112955afeecbf82b26324b77aa8def2ad9f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -0,0 +1,131 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +static const int CUDA_NUM_THREADS = 1024; +static const int CUDA_MAX_NUM_BLOCKS = 65535; +inline static int GET_NUM_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +__global__ void PReluChannelWiseKernel(const float *input, const float *alpha, + float *output, int channel, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const float *in = input + offset; + float *out = output + offset; + float scale = alpha[blockIdx.x % channel]; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + float x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +__global__ void PReluElementWiseKernel(const float *input, const float *alpha, + float *output, size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const float *in = input + offset; + const float *scale = alpha + offset; + float *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + float x = in[i]; + out[i] = (x > 0) ? x : scale[i] * x; + } +} + +__global__ void PReluScalarKernel(const float *input, const float *alpha, + float *output, size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const float *in = input + offset; + float scale = *alpha; + float *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + float x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +static inline void PReluChannelWise(cudaStream_t stream, const float *input, + const float *alpha, float *output, + int batch_size, + const nvinfer1::Dims &dims) { + size_t unroll = batch_size * dims.d[0]; + size_t spatial_size = dims.d[1] * dims.d[2]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluChannelWiseKernel<<>>( + input, alpha, output, dims.d[0], spatial_size); +} + +static inline void PReluElementWise(cudaStream_t stream, const float *input, + const float *alpha, float *output, + int batch_size, + const nvinfer1::Dims &dims) { + size_t unroll = batch_size * dims.d[0]; + size_t spatial_size = dims.d[1] * dims.d[2]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluElementWiseKernel<<>>( + input, alpha, output, spatial_size); +} + +static inline void PReluScalar(cudaStream_t stream, const float *input, + const float *alpha, float *output, + int batch_size, const nvinfer1::Dims &dims) { + size_t unroll = batch_size * dims.d[0]; + size_t spatial_size = dims.d[1] * dims.d[2]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluScalarKernel<<>>( + input, alpha, output, spatial_size); +} + +nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, + const nvinfer1::Dims *inputDims, + int nbInputs) { + assert(nbInputs == 1); + assert(index < this->getNbOutputs()); + nvinfer1::Dims const &input_dims = inputDims[0]; + nvinfer1::Dims output_dims = input_dims; + return output_dims; +} + +int PReluPlugin::enqueue(int batchSize, const void *const *inputs, + void **outputs, void *workspace, cudaStream_t stream) { + // input dims is CHW. + const auto &input_dims = this->getInputDims(0); + const float *input = reinterpret_cast(inputs[0]); + const float *alpha = reinterpret_cast(alpha_.get().values); + float *output = reinterpret_cast(outputs)[0]; + if (mode_ == "channel") { + PReluChannelWise(stream, input, alpha, output, batchSize, input_dims); + } else if (mode_ == "element") { + PReluElementWise(stream, input, alpha, output, batchSize, input_dims); + } else { + PReluScalar(stream, input, alpha, output, batchSize, input_dims); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..aa0f865c89be2dc20d3a30314ec02fd0b425b2fe --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -0,0 +1,68 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PReluPlugin : public PluginTensorRT { + TensorRTEngine::Weight alpha_; + std::string mode_; + + protected: + size_t getSerializationSize() override { + // return getBaseSerializationSize(alpha_) + SerializedSize(mode_); + return 0; + } + + // TRT will call this func when we need to serialize the configuration of + // tensorrt. + // It should not be called by users. + void serialize(void *buffer) override { + // serializeBase(buffer); + // SerializeValue(&buffer, alpha_); + // SerializeValue(&buffer, mode_); + } + + public: + PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode) + : alpha_(alpha), mode_(mode) {} + + // It was used for tensorrt deserialization. + // It should not be called by users. + PReluPlugin(void const *serialData, size_t serialLength) { + // deserializeBase(serialData, serialLength); + // DeserializeValue(&serialData, &serialLength, &alpha_); + // DeserializeValue(&serialData, &serialLength, &mode_); + } + + PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); } + + const char *getPluginType() const override { return "prelu"; } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + int nbInputDims) override; + int enqueue(int batchSize, const void *const *inputs, void **outputs, + void *workspace, cudaStream_t stream) override; +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index f06ef199d165455047a602f7ddec23534b99108e..df2a3e7aa635c9ba41dad85ccb8316823f875639 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -1,359 +1,73 @@ -file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") -string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}") -string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") -list(REMOVE_DUPLICATES GENERAL_OPS) -set(DEPS_OPS "") -set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h) -file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operator/CMakeLists.txt. DO NOT EDIT!\n\n") - -set(PART_CUDA_KERNEL_FILES) -function(op_library TARGET) - # op_library is a function to create op library. The interface is same as - # cc_library. But it handle split GPU/CPU code and link some common library - # for ops. - set(cc_srcs) - set(cu_srcs) - set(hip_cu_srcs) - set(miopen_hip_cc_srcs) - set(cu_cc_srcs) - set(cudnn_cu_cc_srcs) - set(CUDNN_FILE) - set(mkldnn_cc_srcs) - set(MKLDNN_FILE) - set(op_common_deps operator op_registry math_function) - set(options "") - set(oneValueArgs "") - set(multiValueArgs SRCS DEPS) - set(pybind_flag 0) - cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN}) - - list(LENGTH op_library_SRCS op_library_SRCS_len) - if (${op_library_SRCS_len} EQUAL 0) - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) - list(APPEND cc_srcs ${TARGET}.cc) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) - list(APPEND cu_cc_srcs ${TARGET}.cu.cc) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) - list(APPEND cu_srcs ${TARGET}.cu) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu) - set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu - ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE) - list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu) - endif() - - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu) - list(APPEND hip_cu_srcs ${TARGET}.hip.cu) - endif() - string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}") - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) - list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) - endif() - if(WITH_AMD_GPU) - string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}") - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc) - list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc) - endif() - endif() - if(WITH_MKLDNN) - string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc) - list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc) - endif() - endif() - else() - foreach(src ${op_library_SRCS}) - if (${src} MATCHES ".*\\.hip.cu$") - list(APPEND hip_cu_srcs ${src}) - elseif (${src} MATCHES ".*\\.cu$") - list(APPEND cu_srcs ${src}) - elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") - list(APPEND cudnn_cu_cc_srcs ${src}) - elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$") - list(APPEND miopen_hip_cc_srcs ${src}) - elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") - list(APPEND mkldnn_cc_srcs ${src}) - elseif(${src} MATCHES ".*\\.cu.cc$") - list(APPEND cu_cc_srcs ${src}) - elseif(${src} MATCHES ".*\\.cc$") - list(APPEND cc_srcs ${src}) - else() - message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu") - endif() - endforeach() - endif() - - list(LENGTH cc_srcs cc_srcs_len) - if (${cc_srcs_len} EQUAL 0) - message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file") - endif() - if (WIN32) - # remove windows unsupported op, because windows has no nccl, no warpctc such ops. - foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op" - "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op" - "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op" - "fusion_seqexpand_concat_fc_op" "attention_lstm_op" "fused_embedding_fc_lstm_op" "fc_op") - if ("${TARGET}" STREQUAL "${windows_unsupport_op}") - return() - endif() - endforeach() - endif(WIN32) - set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE) +include(operators) - list(LENGTH op_library_DEPS op_library_DEPS_len) - if (${op_library_DEPS_len} GREATER 0) - set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) - endif() - if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} - ${op_common_deps}) - elseif (WITH_AMD_GPU) - hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} - ${op_common_deps}) - else() - cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} - ${op_common_deps}) - endif() - - # Define operators that don't need pybind here. - foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" -"tensor_array_read_write_op" "tensorrt_engine_op") - if ("${TARGET}" STREQUAL "${manual_pybind_op}") - set(pybind_flag 1) - endif() - endforeach() - - # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. - # Note that it's enough to just adding one operator to pybind in a *_op.cc file. - # And for detail pybind information, please see generated paddle/pybind/pybind.h. - file(READ ${TARGET}.cc TARGET_CONTENT) - string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}") - string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}") - if (one_register STREQUAL "") - string(REPLACE "_op" "" TARGET "${TARGET}") - else () - string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}") - string(REPLACE "," "" TARGET "${TARGET}") - endif() - - # pybind USE_NO_KERNEL_OP - # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel - string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}") - string(REPLACE "_op" "" TARGET "${TARGET}") - if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") - file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n") - set(pybind_flag 1) - endif() - - # pybind USE_CPU_ONLY_OP - list(LENGTH cu_srcs cu_srcs_len) - list(LENGTH cu_cc_srcs cu_cc_srcs_len) - list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) - list(LENGTH hip_cu_srcs hip_cu_srcs_len) - list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND - ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0) - file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") - set(pybind_flag 1) - endif() - - # pybind USE_OP_DEVICE_KERNEL for CUDNN - list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len) - if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") - endif() - - # pybind USE_OP_DEVICE_KERNEL for MIOPEN - if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n") - endif() - - # pybind USE_OP_DEVICE_KERNEL for MKLDNN - if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) - # Append first implemented MKLDNN activation operator - if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") - else() - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") - endif() - endif() - - # pybind USE_OP - if (${pybind_flag} EQUAL 0) - # NOTE(*): activation use macro to regist the kernels, set use_op manually. - if(${TARGET} STREQUAL "activation") - file(APPEND ${pybind_file} "USE_OP(relu);\n") - elseif(${TARGET} STREQUAL "fake_dequantize") - file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") - elseif(${TARGET} STREQUAL "fake_quantize") - file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n") - elseif(${TARGET} STREQUAL "tensorrt_engine_op") - message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") - elseif(${TARGET} STREQUAL "fc") - # HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition - file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") - else() - file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") - endif() - endif() -endfunction() +# clean cache and pybind_file content first when rebuild +unset(GLOB_OP_LIB CACHE) +unset(OP_LIBRARY CACHE) +set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h CACHE INTERNAL "pybind.h file") +file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operator/CMakeLists.txt. DO NOT EDIT!\n\n") add_subdirectory(math) -if (NOT WIN32) -add_subdirectory(nccl) -if(WITH_GPU) - op_library(nccl_op DEPS nccl_common) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") -else() - set(DEPS_OPS ${DEPS_OPS} nccl_op) -endif() -endif() # NOT WIN32 +add_subdirectory(controlflow) +add_subdirectory(csp) +add_subdirectory(detection) +add_subdirectory(elementwise) +add_subdirectory(fused) +add_subdirectory(metrics) +add_subdirectory(optimizers) +add_subdirectory(reduce_ops) +add_subdirectory(sequence_ops) -set(DISTRIBUTE_DEPS "") if(WITH_DISTRIBUTE) add_subdirectory(distributed) - set(DISTRIBUTE_DEPS "") - if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) - else() - set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) - if(WITH_BRPC_RDMA) - find_library(IBVERBS_LIBRARY NAMES ibverbs) - ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) - SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY}) - - - find_library(RDMACM_LIBRARY NAMES rdmacm) - ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL) - SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY}) - - set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} ibverbs rdmacm) - endif() - endif() - - set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") - foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op") - op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS}) - set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - endforeach() + add_subdirectory(distributed_ops) +endif() - #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op - # listen_and_serv_op sum_op executor SERIAL) - if(WITH_GPU AND NOT WIN32) - set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op ${DISTRIBUTE_DEPS} executor SERIAL) - if(WITH_GRPC) - op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc) - else() - op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_brpc) - endif() - set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - else() - set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) - endif() # WITH_GPU AND NOT WIN32 -else() - set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op) +if (NOT WIN32) + add_subdirectory(reader) endif() -op_library(cross_entropy_op DEPS cross_entropy) -if(WITH_GPU) - op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax cub) - op_library(sequence_softmax_op DEPS cub) -else() - op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) +if (NOT WIN32) + add_subdirectory(nccl) endif() -op_library(softmax_op DEPS softmax) if (WITH_GPU AND TENSORRT_FOUND) - op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n") - nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc - DEPS tensorrt_engine_op - analysis) -else() - set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) + add_subdirectory(tensorrt) endif() -op_library(hash_op DEPS xxhash) -op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) -op_library(sum_op DEPS selected_rows_functor) -op_library(sgd_op DEPS selected_rows_functor) -op_library(print_op DEPS lod_tensor) -op_library(adagrad_op DEPS selected_rows_functor) -op_library(maxout_op DEPS maxouting) -op_library(unpool_op DEPS unpooling) -op_library(pool_op DEPS pooling) -op_library(pool_with_index_op DEPS pooling) -op_library(lod_rank_table_op DEPS lod_rank_table) -op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) -op_library(array_to_lod_tensor_op DEPS lod_rank_table_op) -op_library(max_sequence_len_op DEPS lod_rank_table) -op_library(sequence_conv_op DEPS context_project) -op_library(sequence_pool_op DEPS sequence_pooling) -if (NOT WIN32) - op_library(lstm_op DEPS sequence2batch lstm_compute) - op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) - op_library(lstmp_op DEPS sequence2batch lstm_compute) - op_library(gru_op DEPS sequence2batch gru_compute) -endif(NOT WIN32) -op_library(recurrent_op DEPS executor) -op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) -op_library(cos_sim_op DEPS cos_sim_functor) -op_library(parallel_do_op DEPS executor) -op_library(unsqueeze_op DEPS reshape_op) -op_library(squeeze_op DEPS reshape_op) -op_library(flatten_op DEPS reshape_op) -op_library(sequence_pad_op DEPS sequence_padding) -op_library(unstack_op DEPS stack_op) -op_library(fake_quantize_op DEPS memory) -if (NOT WIN32) -op_library(crf_decoding_op DEPS jit_kernel) -op_library(fusion_lstm_op DEPS jit_kernel) -endif(NOT WIN32) + +register_operators(EXCLUDES warpctc_op) + +# warpctc_cudnn need cudnn 7 above if (WITH_GPU) - op_library(conv_op DEPS vol2col depthwise_conv im2col) - op_library(layer_norm_op DEPS cub) - op_library(reduce_mean_op DEPS cub) - op_library(affine_channel_op DEPS cub) + if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7) + op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) + else() + op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) + endif() else() - op_library(conv_op DEPS vol2col im2col) + op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() -op_library(conv_transpose_op DEPS vol2col im2col) - -# FIXME(typhoonzero): save/load depends lodtensor serialization functions -op_library(save_op DEPS lod_tensor) -op_library(load_op DEPS lod_tensor) -op_library(save_combine_op DEPS lod_tensor) -op_library(load_combine_op DEPS lod_tensor) -op_library(concat_op DEPS concat_and_split) -op_library(tensor_array_to_tensor_op DEPS concat_op) - -list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) - -foreach(src ${GENERAL_OPS}) - op_library(${src}) -endforeach() - -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") +set(COMMON_OP_DEPS "") +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor dynload_warpctc sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) if (NOT WIN32) -add_subdirectory(reader) -endif(NOT WIN32) -foreach(src ${READER_LIBRARY}) - set(OP_LIBRARY ${src} ${OP_LIBRARY}) -endforeach() + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) +endif() +if (WITH_GPU) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv cub) +endif() -add_subdirectory(detection) -foreach(src ${DETECTION_LIBRARY}) - set(OP_LIBRARY ${src} ${OP_LIBRARY}) -endforeach() +# FIXME(typhoonzero): operator deps may not needed. +# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) +# op_library(array_to_lod_tensor_op DEPS lod_rank_table_op) +# op_library(unsqueeze_op DEPS reshape_op) +# op_library(squeeze_op DEPS reshape_op) +# op_library(flatten_op DEPS reshape_op) +# op_library(unstack_op DEPS stack_op) +# op_library(tensor_array_to_tensor_op DEPS concat_op) -set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") -set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency") +set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS}) +set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) @@ -362,18 +76,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) -if(NOT WIN32) - nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) -endif() nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) -if(WITH_GPU) - foreach(CUDA_KERNEL_FILE ${PART_CUDA_KERNEL_FILES}) - file(READ ${CUDA_KERNEL_FILE} TARGET_CONTENT) - string(REGEX MATCH "REGISTER_OP_CUDA_KERNEL\\(\\n?([^,]+),.*" MATCHED ${TARGET_CONTENT}) - if (MATCHED) - string(STRIP ${CMAKE_MATCH_1} MATCHED) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MATCHED}, CUDA);\n") - endif() - endforeach() -endif() +set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b1c2ee22951a3881b4ce5b82f9ff7eb01fde6e9e --- /dev/null +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -0,0 +1,4 @@ +include(operators) +register_operators() + +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc similarity index 98% rename from paddle/fluid/operators/compare_op.cc rename to paddle/fluid/operators/controlflow/compare_op.cc index f40b1ba338d429c248103eeb930ac7e1bb690218..488ca7fe95f5119c59b861011993a379d08008ba 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/compare_op.h" +#include "paddle/fluid/operators/controlflow/compare_op.h" #include #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu similarity index 94% rename from paddle/fluid/operators/compare_op.cu rename to paddle/fluid/operators/controlflow/compare_op.cu index 1bf85c64fb5b4d79c62118959fd72b13ed1c63ed..b1f306358359764b919f9e570cf44f9733a7d178 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/compare_op.h" +#include "paddle/fluid/operators/controlflow/compare_op.h" REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/controlflow/compare_op.h similarity index 97% rename from paddle/fluid/operators/compare_op.h rename to paddle/fluid/operators/controlflow/compare_op.h index 1cbabdaf6767815c1fedba0eabec9b5de678e047..b7529e4ae632d31524846d9d5aa4b1883f4509a1 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/controlflow/compare_op.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/transform.h" namespace paddle { diff --git a/paddle/fluid/operators/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc similarity index 100% rename from paddle/fluid/operators/conditional_block_op.cc rename to paddle/fluid/operators/controlflow/conditional_block_op.cc diff --git a/paddle/fluid/operators/feed_op.cc b/paddle/fluid/operators/controlflow/feed_op.cc similarity index 100% rename from paddle/fluid/operators/feed_op.cc rename to paddle/fluid/operators/controlflow/feed_op.cc diff --git a/paddle/fluid/operators/fetch_op.cc b/paddle/fluid/operators/controlflow/fetch_op.cc similarity index 100% rename from paddle/fluid/operators/fetch_op.cc rename to paddle/fluid/operators/controlflow/fetch_op.cc diff --git a/paddle/fluid/operators/get_places_op.cc b/paddle/fluid/operators/controlflow/get_places_op.cc similarity index 100% rename from paddle/fluid/operators/get_places_op.cc rename to paddle/fluid/operators/controlflow/get_places_op.cc diff --git a/paddle/fluid/operators/logical_op.cc b/paddle/fluid/operators/controlflow/logical_op.cc similarity index 99% rename from paddle/fluid/operators/logical_op.cc rename to paddle/fluid/operators/controlflow/logical_op.cc index 26970db8d2af62bb06fce4eb1a1f21fd41617bd1..6446cab5ec5f889dccaef90484476e55c4852dee 100644 --- a/paddle/fluid/operators/logical_op.cc +++ b/paddle/fluid/operators/controlflow/logical_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/logical_op.h" +#include "paddle/fluid/operators/controlflow/logical_op.h" #include #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/logical_op.cu b/paddle/fluid/operators/controlflow/logical_op.cu similarity index 94% rename from paddle/fluid/operators/logical_op.cu rename to paddle/fluid/operators/controlflow/logical_op.cu index 7ffe4dfc268b1ad3894dd54cb13c2f424818aa05..7ca54b488bfbb260c422941b82145f092a150be7 100644 --- a/paddle/fluid/operators/logical_op.cu +++ b/paddle/fluid/operators/controlflow/logical_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/logical_op.h" +#include "paddle/fluid/operators/controlflow/logical_op.h" REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CUDA, paddle::operators::LogicalAndFunctor); diff --git a/paddle/fluid/operators/logical_op.h b/paddle/fluid/operators/controlflow/logical_op.h similarity index 100% rename from paddle/fluid/operators/logical_op.h rename to paddle/fluid/operators/controlflow/logical_op.h diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/controlflow/parallel_do_op.cc similarity index 100% rename from paddle/fluid/operators/parallel_do_op.cc rename to paddle/fluid/operators/controlflow/parallel_do_op.cc diff --git a/paddle/fluid/operators/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc similarity index 100% rename from paddle/fluid/operators/tensor_array_read_write_op.cc rename to paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc similarity index 100% rename from paddle/fluid/operators/while_op.cc rename to paddle/fluid/operators/controlflow/while_op.cc diff --git a/paddle/fluid/operators/csp/CMakeLists.txt b/paddle/fluid/operators/csp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d468316e8eacb73c4a4ce81c784880bb5e46c2d --- /dev/null +++ b/paddle/fluid/operators/csp/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/go_op.cc b/paddle/fluid/operators/csp/go_op.cc similarity index 100% rename from paddle/fluid/operators/go_op.cc rename to paddle/fluid/operators/csp/go_op.cc diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index e5c3f0eeb385e1a15fdbb12a989603996420efe3..58f6f48467310ffb2429ad440f58fcd823edf079 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -40,4 +40,8 @@ endif() detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu) #Export local libraries to parent -set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) +# set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) + +foreach(src ${LOCAL_DETECTION_LIBS}) + set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs") +endforeach() diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..28bb90af5675b2fe14813675ad001c0cf1d71e12 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -0,0 +1,40 @@ +include(operators) + +set(DISTRIBUTE_DEPS "") +if(WITH_GRPC) + set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) +else() + set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + if(WITH_BRPC_RDMA) + find_library(IBVERBS_LIBRARY NAMES ibverbs) + ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY}) + + + find_library(RDMACM_LIBRARY NAMES rdmacm) + ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY}) + + set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} ibverbs rdmacm) + endif() +endif() + +set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + + +file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") +list(REMOVE_DUPLICATES OPS) + +foreach(src ${OPS}) + set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +endforeach() + +register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS}) + +if(WITH_GPU AND NOT WIN32) + set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common) + op_library(gen_nccl_id_op ${DISTRIBUTE_DEPS} nccl_common) +endif() + +set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE) +set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency") diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc similarity index 98% rename from paddle/fluid/operators/checkpoint_notify_op.cc rename to paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc index defa287bdb913e406aa7e2a383cefc3cca8c4d94..ed4dced51356515d5910e2962c9ee91a1997dbf0 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/string/printf.h" namespace paddle { diff --git a/paddle/fluid/operators/fake_init_op.cc b/paddle/fluid/operators/distributed_ops/fake_init_op.cc similarity index 100% rename from paddle/fluid/operators/fake_init_op.cc rename to paddle/fluid/operators/distributed_ops/fake_init_op.cc diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc similarity index 100% rename from paddle/fluid/operators/fetch_barrier_op.cc rename to paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc similarity index 100% rename from paddle/fluid/operators/gen_nccl_id_op.cc rename to paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc similarity index 99% rename from paddle/fluid/operators/listen_and_serv_op.cc rename to paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index e3d09e2d14817fe0f2ccda18ed90c9436b399ae3..9f0c7db0e1133f6d73e73a9d162a945ba4c17dc6 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include "paddle/fluid/operators/listen_and_serv_op.h" +#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send"); DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get"); diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h similarity index 100% rename from paddle/fluid/operators/listen_and_serv_op.h rename to paddle/fluid/operators/distributed_ops/listen_and_serv_op.h diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc similarity index 98% rename from paddle/fluid/operators/merge_ids_op.cc rename to paddle/fluid/operators/distributed_ops/merge_ids_op.cc index 6e0e13698097ade36449f2e8ff6ab981a1b24311..252a63cb605f65e8572281a05e884fb8b020a820 100644 --- a/paddle/fluid/operators/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/merge_ids_op.h" +#include "paddle/fluid/operators/distributed_ops/merge_ids_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/distributed_ops/merge_ids_op.h similarity index 100% rename from paddle/fluid/operators/merge_ids_op.h rename to paddle/fluid/operators/distributed_ops/merge_ids_op.h diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/distributed_ops/prefetch_op.cc similarity index 98% rename from paddle/fluid/operators/prefetch_op.cc rename to paddle/fluid/operators/distributed_ops/prefetch_op.cc index 55853d25460bf6e3d07c829d686e71cc9367118c..faa67a28d86235625a87b8bd7b87685e09c75f0b 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/distributed_ops/prefetch_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc similarity index 100% rename from paddle/fluid/operators/recv_op.cc rename to paddle/fluid/operators/distributed_ops/recv_op.cc diff --git a/paddle/fluid/operators/ref_by_trainer_id_op.cc b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc similarity index 97% rename from paddle/fluid/operators/ref_by_trainer_id_op.cc rename to paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc index 6cb651af6dc3d8e301365968787c199acc4c60ee..98b0af7688b928f21573247b327bee1d22a73f17 100644 --- a/paddle/fluid/operators/ref_by_trainer_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/ref_by_trainer_id_op.h" +#include "paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h" #include namespace paddle { diff --git a/paddle/fluid/operators/ref_by_trainer_id_op.cu.cc b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cu.cc similarity index 94% rename from paddle/fluid/operators/ref_by_trainer_id_op.cu.cc rename to paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cu.cc index b98e2b5c9c7341f2a424fb4b32ff1e8bc45a056c..168cd51355de56c2e2a83ba73d7eb14f6ba6e533 100644 --- a/paddle/fluid/operators/ref_by_trainer_id_op.cu.cc +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/ref_by_trainer_id_op.h" +#include "paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h" REGISTER_OP_CUDA_KERNEL( ref_by_trainer_id, diff --git a/paddle/fluid/operators/ref_by_trainer_id_op.h b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h similarity index 100% rename from paddle/fluid/operators/ref_by_trainer_id_op.h rename to paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/distributed_ops/send_barrier_op.cc similarity index 100% rename from paddle/fluid/operators/send_barrier_op.cc rename to paddle/fluid/operators/distributed_ops/send_barrier_op.cc diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc similarity index 98% rename from paddle/fluid/operators/send_op.cc rename to paddle/fluid/operators/distributed_ops/send_op.cc index 0ad43d56d3cd7500290dc1e386a2dbaf4453a191..be53a1a32b59d7c0235382f5db18d2203b4a035a 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/distributed_ops/send_recv_op_test.cc similarity index 99% rename from paddle/fluid/operators/send_recv_op_test.cc rename to paddle/fluid/operators/distributed_ops/send_recv_op_test.cc index d79b16e3cca714d44c88834082cea9367480da9a..bf798a8251fcb4148db486f26d32525b59299c81 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/distributed_ops/send_recv_op_test.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/listen_and_serv_op.h" +#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/string/printf.h" diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/distributed_ops/send_recv_util.h similarity index 100% rename from paddle/fluid/operators/send_recv_util.h rename to paddle/fluid/operators/distributed_ops/send_recv_util.h diff --git a/paddle/fluid/operators/split_byref_op.cc b/paddle/fluid/operators/distributed_ops/split_byref_op.cc similarity index 98% rename from paddle/fluid/operators/split_byref_op.cc rename to paddle/fluid/operators/distributed_ops/split_byref_op.cc index bc998e1abbd7131a7497288cc9d66315a6fedc85..d65e7ffe5a492fe5df038bb6bd469e09de6f95ca 100644 --- a/paddle/fluid/operators/split_byref_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_byref_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/split_byref_op.h" +#include "paddle/fluid/operators/distributed_ops/split_byref_op.h" #include "paddle/fluid/operators/split_op.h" namespace paddle { diff --git a/paddle/fluid/operators/split_byref_op.cu.cc b/paddle/fluid/operators/distributed_ops/split_byref_op.cu.cc similarity index 91% rename from paddle/fluid/operators/split_byref_op.cu.cc rename to paddle/fluid/operators/distributed_ops/split_byref_op.cu.cc index 5ee6186f3541b7dcb845ce0c6d28081685925da0..056659c3ea61f6233a6dda56ca1e272e72770d4a 100644 --- a/paddle/fluid/operators/split_byref_op.cu.cc +++ b/paddle/fluid/operators/distributed_ops/split_byref_op.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/split_byref_op.h" +#include "paddle/fluid/operators/distributed_ops/split_byref_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( split_byref, diff --git a/paddle/fluid/operators/split_byref_op.h b/paddle/fluid/operators/distributed_ops/split_byref_op.h similarity index 100% rename from paddle/fluid/operators/split_byref_op.h rename to paddle/fluid/operators/distributed_ops/split_byref_op.h diff --git a/paddle/fluid/operators/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc similarity index 98% rename from paddle/fluid/operators/split_ids_op.cc rename to paddle/fluid/operators/distributed_ops/split_ids_op.cc index 01d432e13068f7b718d08dc15d8cc99a7fbb0afe..f61d387fbef636298c412c227bf7a56a04f69c63 100644 --- a/paddle/fluid/operators/split_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/split_ids_op.h" +#include "paddle/fluid/operators/distributed_ops/split_ids_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/distributed_ops/split_ids_op.h similarity index 100% rename from paddle/fluid/operators/split_ids_op.h rename to paddle/fluid/operators/distributed_ops/split_ids_op.h diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc similarity index 96% rename from paddle/fluid/operators/test_send_nccl_id.cc rename to paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc index b5426e17aac19dc07ee56545fac8472d9ef0d93c..a73cb08eca272b044501d48e7b8c5b7dc8553a50 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc @@ -22,14 +22,14 @@ limitations under the License. */ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include "paddle/fluid/operators/listen_and_serv_op.h" +#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/string/printf.h" #ifdef PADDLE_WITH_GRPC -#include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #endif USE_NO_KERNEL_OP(listen_and_serv); diff --git a/paddle/fluid/operators/elementwise/CMakeLists.txt b/paddle/fluid/operators/elementwise/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d468316e8eacb73c4a4ce81c784880bb5e46c2d --- /dev/null +++ b/paddle/fluid/operators/elementwise/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_mkldnn_op.cc similarity index 97% rename from paddle/fluid/operators/elementwise_add_mkldnn_op.cc rename to paddle/fluid/operators/elementwise/elementwise_add_mkldnn_op.cc index 9ad82aec8182d6ba06b67391d71317a3d0df1833..6a6741d8fc54d22addca91b75dfabf5950c1a35a 100644 --- a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_mkldnn_op.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/elementwise_add_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/mkldnn_helper.h" diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc similarity index 92% rename from paddle/fluid/operators/elementwise_add_op.cc rename to paddle/fluid/operators/elementwise/elementwise_add_op.cc index 3c97ac995c649ecd0d196a584240e1e7ac04f08e..7e789cd8d9143164c2346b067855eb904e00075f 100644 --- a/paddle/fluid/operators/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_add_op.h" -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace ops = paddle::operators; REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add); REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out", diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu similarity index 95% rename from paddle/fluid/operators/elementwise_add_op.cu rename to paddle/fluid/operators/elementwise/elementwise_add_op.cu index f9f5c66d34fa1d73db00173e493f9953b8579518..2fb7eeb4b9e3119a6eea3e69a2a6002a80f6c0f3 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h similarity index 97% rename from paddle/fluid/operators/elementwise_add_op.h rename to paddle/fluid/operators/elementwise/elementwise_add_op.h index 9edbdbefe76600dc4bf937d95e70d11450206cd4..69f640ab6649df673f07ac0cef81bf80d16eb98d 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc similarity index 91% rename from paddle/fluid/operators/elementwise_div_op.cc rename to paddle/fluid/operators/elementwise/elementwise_div_op.cc index 84c8a65e5f859d276ae6d5f1a3f25c9d713a7a61..85612ba47448a7b0d712e9314e3980019c96e9c3 100644 --- a/paddle/fluid/operators/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_div_op.h" -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace ops = paddle::operators; REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y"); diff --git a/paddle/fluid/operators/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu similarity index 95% rename from paddle/fluid/operators/elementwise_div_op.cu rename to paddle/fluid/operators/elementwise/elementwise_div_op.cu index 588d1f7420241ba1697e5141e4e4a2870f2dc87c..c5a1a7e08d89f3ef205af4c37246f8fa288189f3 100644 --- a/paddle/fluid/operators/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_div_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h similarity index 94% rename from paddle/fluid/operators/elementwise_div_op.h rename to paddle/fluid/operators/elementwise/elementwise_div_op.h index cdb1264d298ef48d6b3da39d63ff1d09e1561aa4..8a07339077aeaa4403ffd1e1e30e0d58a9cc30e7 100644 --- a/paddle/fluid/operators/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc similarity index 91% rename from paddle/fluid/operators/elementwise_max_op.cc rename to paddle/fluid/operators/elementwise/elementwise_max_op.cc index 411671335a19ae2283ca9db8b8f6bcbb6a6b630a..ea0dcd736e5700fb0f341938ac3e3e3b178f29c1 100644 --- a/paddle/fluid/operators/elementwise_max_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_max_op.h" -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_max_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace ops = paddle::operators; REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)"); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu similarity index 95% rename from paddle/fluid/operators/elementwise_max_op.cu rename to paddle/fluid/operators/elementwise/elementwise_max_op.cu index 32c99835d66d8b11b72af162230aa383c7e4a57c..a90dcd3ecf0da114110db5946e111a8b3a925e42 100644 --- a/paddle/fluid/operators/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_max_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_max_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h similarity index 94% rename from paddle/fluid/operators/elementwise_max_op.h rename to paddle/fluid/operators/elementwise/elementwise_max_op.h index 367489dd563f7d8bdf430517cadf49d4ef2a0105..3ee0c32e0d5d5df02d5d157416918fb4fb3aca92 100644 --- a/paddle/fluid/operators/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc similarity index 91% rename from paddle/fluid/operators/elementwise_min_op.cc rename to paddle/fluid/operators/elementwise/elementwise_min_op.cc index 816192083d2275b26e6dd9afc76f2c021a01cf73..b263b9addd40cfd329d2cc8588c278df2cb008e9 100644 --- a/paddle/fluid/operators/elementwise_min_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_min_op.h" -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_min_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace ops = paddle::operators; REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)"); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu similarity index 95% rename from paddle/fluid/operators/elementwise_min_op.cu rename to paddle/fluid/operators/elementwise/elementwise_min_op.cu index a237c9c503ec998fd74fec50a1d7949279bb38f0..ab77709c28c15a925bd3deac07c43e12b12cb781 100644 --- a/paddle/fluid/operators/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_min_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_min_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h similarity index 94% rename from paddle/fluid/operators/elementwise_min_op.h rename to paddle/fluid/operators/elementwise/elementwise_min_op.h index 1bd0a6279766c8eba92d1e3a76191c59410286b2..d04e372faaa4e6296e982afe6155cdde2fec4f81 100644 --- a/paddle/fluid/operators/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc similarity index 95% rename from paddle/fluid/operators/elementwise_mul_op.cc rename to paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 86a8459a79135d1fbcba6886172acc5a2abdb88b..d5e3300ac954aebf34a9c65fbca8de8fa2685932 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_mul_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu similarity index 95% rename from paddle/fluid/operators/elementwise_mul_op.cu rename to paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 2fb1b4bee689c059625e3dbd59f80c541ace83a0..4d16bc38e1d8e4cbbe3afbe08f233e14329e0f2e 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_mul_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h similarity index 96% rename from paddle/fluid/operators/elementwise_mul_op.h rename to paddle/fluid/operators/elementwise/elementwise_mul_op.h index 29e4ab7db1377b6aa80e94a26ab3cb8669f9154a..dc25bc57103286ce183a4649964fd96c62169b7f 100644 --- a/paddle/fluid/operators/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h similarity index 100% rename from paddle/fluid/operators/elementwise_op.h rename to paddle/fluid/operators/elementwise/elementwise_op.h diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h similarity index 100% rename from paddle/fluid/operators/elementwise_op_function.h rename to paddle/fluid/operators/elementwise/elementwise_op_function.h diff --git a/paddle/fluid/operators/elementwise_pow_op.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc similarity index 90% rename from paddle/fluid/operators/elementwise_pow_op.cc rename to paddle/fluid/operators/elementwise/elementwise_pow_op.cc index 5fd6bde9ba0930e29f2161f1ff23ff9f5e7dc85d..6335e67a8a48c8702f0cb14ce947275d47e01d17 100644 --- a/paddle/fluid/operators/elementwise_pow_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_pow_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" #include -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu similarity index 92% rename from paddle/fluid/operators/elementwise_pow_op.cu rename to paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 1f19ebd470973137b465381e498ab07a36323c14..6ee0779f23bc2c734aa1d439abb12f366227e686 100644 --- a/paddle/fluid/operators/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_pow_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_pow_op.h b/paddle/fluid/operators/elementwise/elementwise_pow_op.h similarity index 95% rename from paddle/fluid/operators/elementwise_pow_op.h rename to paddle/fluid/operators/elementwise/elementwise_pow_op.h index 8c1c5f9f98018d8d4368a9333e2004141615775d..dc584b4c32fc3063da0c6de50577d28afcb63b83 100644 --- a/paddle/fluid/operators/elementwise_pow_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc similarity index 92% rename from paddle/fluid/operators/elementwise_sub_op.cc rename to paddle/fluid/operators/elementwise/elementwise_sub_op.cc index b7224261e6a7ca82dff92a25f5fe8818c08e676d..efc66374c812cbd07adef6ac25c9616b880ec383 100644 --- a/paddle/fluid/operators/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise_sub_op.h" -#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" namespace ops = paddle::operators; REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub); REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_sub, "Sub", "Out = X - Y", "Out", diff --git a/paddle/fluid/operators/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu similarity index 95% rename from paddle/fluid/operators/elementwise_sub_op.cu rename to paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 8709f686f9af1bf4dacbc2dfc3e2d5dcc1c59b9a..8d9bf7c4d81d49d83b5d1cf0369be5c9957242b4 100644 --- a/paddle/fluid/operators/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/elementwise_sub_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h similarity index 94% rename from paddle/fluid/operators/elementwise_sub_op.h rename to paddle/fluid/operators/elementwise/elementwise_sub_op.h index 7204c43464e0b81126148b86f64a36b0e299368b..770323fe5a8fe7c1051b418b2541ab4c669635b4 100644 --- a/paddle/fluid/operators/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 5ad0ec251328cc1ba580026bb47bf05316e7dc77..40f7c1c54c861abebc84428f55e2769ac8969f0f 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel { out_shape[i] = x_dims[i] * expand_times[i]; } + // set the first dim to -1 in compile time + if (!ctx->IsRuntime()) { + out_shape[0] = x_dims[0]; + } + ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); if (out_shape[0] == x_dims[0]) { ctx->ShareLoD("X", "Out"); @@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel { ctx->Attrs().Get>("expand_times"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - for (size_t i = 0; i < expand_times.size(); ++i) { + size_t start_pos = 0u; + if (!ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + x_dims[0], out_dims[0], + "The first dimension size of Input(Out@GRAD) should be " + "equal to the crroresponding dimension size of Input(X)"); + start_pos = 1u; + } + + for (size_t i = start_pos; i < expand_times.size(); ++i) { PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i], "Each dimension size of Input(Out@GRAD) should be " "equal to multiplication of crroresponding dimension " diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d468316e8eacb73c4a4ce81c784880bb5e46c2d --- /dev/null +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc similarity index 99% rename from paddle/fluid/operators/fused_elemwise_activation_op.cc rename to paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index d88ef15949da3809bffe41e4bf303d1fee568675..3771aac0dfd98a52dcd8b789e5a6114e977e22f8 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused_elemwise_activation_op.h" +#include "paddle/fluid/operators/fused/fused_elemwise_activation_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cu b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cu similarity index 94% rename from paddle/fluid/operators/fused_elemwise_activation_op.cu rename to paddle/fluid/operators/fused/fused_elemwise_activation_op.cu index e1d2b16b4b5e3a480777f834c2cbeb6d00a755e4..e10693bae1859307c9cf266965d4ce20e6de1bf9 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.cu +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused_elemwise_activation_op.h" +#include "paddle/fluid/operators/fused/fused_elemwise_activation_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h similarity index 99% rename from paddle/fluid/operators/fused_elemwise_activation_op.h rename to paddle/fluid/operators/fused/fused_elemwise_activation_op.h index 5ae9aea959c268985c17643f2f47199c852c2bcb..01dc2dbfd61cc88f72174233382aa49f61c9b60f 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.h +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/fluid/operators/math/functors.h" diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc similarity index 99% rename from paddle/fluid/operators/fused_embedding_fc_lstm_op.cc rename to paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index fdc9cb4888b3468b85abfa0c693ed8ac5b0d450b..6d463538d232e1a38f845e7abc3786568ca3bb21 100644 --- a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h" +#include "paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.h b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h similarity index 100% rename from paddle/fluid/operators/fused_embedding_fc_lstm_op.h rename to paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc similarity index 99% rename from paddle/fluid/operators/fusion_gru_op.cc rename to paddle/fluid/operators/fused/fusion_gru_op.cc index 120b2ab440156f6020fd6005dd64a48e9a6918ec..7e34d1019c9e6577b50ff8c2fa3d767124b5ff3b 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fusion_gru_op.h" +#include "paddle/fluid/operators/fused/fusion_gru_op.h" #include // for memcpy #include #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/fusion_gru_op.h b/paddle/fluid/operators/fused/fusion_gru_op.h similarity index 100% rename from paddle/fluid/operators/fusion_gru_op.h rename to paddle/fluid/operators/fused/fusion_gru_op.h diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc similarity index 99% rename from paddle/fluid/operators/fusion_lstm_op.cc rename to paddle/fluid/operators/fused/fusion_lstm_op.cc index 067e6a3e7cccc1f15ebdd984f3a2441339a989ab..0959539068eef5b550a8e3997d3f11ea67ae0707 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fusion_lstm_op.h" +#include "paddle/fluid/operators/fused/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc_compute.h" diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fused/fusion_lstm_op.h similarity index 100% rename from paddle/fluid/operators/fusion_lstm_op.h rename to paddle/fluid/operators/fused/fusion_lstm_op.h diff --git a/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc similarity index 99% rename from paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc rename to paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index b0910dc19edb246d9acfe3bdb15071c64cbdaba7..40bba09f3ef71021b7daff83b9d63005f7580395 100644 --- a/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h" +#include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h" #include // for min, max #include #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h similarity index 100% rename from paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h rename to paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h diff --git a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc similarity index 99% rename from paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc rename to paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 8d2f055d53a0c5bbef624ff3b01b01724d0b3a21..288b56fc2485138b20c5b53af3e950f1c1886ba5 100644 --- a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h" +#include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h" #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" diff --git a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h similarity index 100% rename from paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h rename to paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 2e54bb497dec11eaeda03a1aa6acfd4cc261dbfe..7bf79b08956885259e5ac3801274a1a675e6d975 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index cc3cc9787a3926eea2f9a1620eead9823a7d77c5..4cd014cbadb888a2afe118785336e673e2b5eafb 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -41,6 +41,7 @@ math_library(cross_entropy) math_library(cos_sim_functor) math_library(depthwise_conv) math_library(im2col) +math_library(sampler) if (NOT WIN32) # windows do not support avx functions yet. math_library(gru_compute DEPS activation_functions math_function) diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 0aed253c80fc28560716cbcfa70f74ef9c84f9b6..7d81aee596934308763002d440f52400f45b5f20 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -33,11 +33,11 @@ namespace math { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 -#define AVX_FLOAT_BLOCK 8 +#define YMM_FLOAT_BLOCK 8 #define AVX_DOUBLE_BLOCK 4 -#define AVX2_FLOAT_BLOCK 8 +#define YMM_FLOAT_BLOCK 8 #define AVX2_DOUBLE_BLOCK 4 -#define AVX512_FLOAT_BLOCK 16 +#define ZMM_FLOAT_BLOCK 16 #define AVX512_DOUBLE_BLOCK 8 template @@ -88,7 +88,7 @@ template <> inline void vec_scal(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_scal(n, a, x, y); return; @@ -142,7 +142,7 @@ template <> inline void vec_bias_sub(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_bias_sub(n, a, x, y); return; @@ -200,7 +200,7 @@ inline void vec_cross(const int n, const float* x, const float* y, const float* z, float* out) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_cross(n, x, y, z, out); return; @@ -257,7 +257,7 @@ template <> inline void vec_add_bias(const int n, const float a, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_add_bias(n, a, x, y); return; @@ -326,7 +326,7 @@ template <> inline void vec_sigmoid(const int n, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { vec_sigmoid(n, x, y); return; @@ -415,7 +415,7 @@ template <> inline void vec_relu(const int n, const float* x, float* y) { #ifdef __AVX__ - constexpr int block = AVX_FLOAT_BLOCK; + constexpr int block = YMM_FLOAT_BLOCK; if (n < block * 4) { vec_relu(n, x, y); return; diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index e46f60f764ab9f1c292db339a5b38b976de5a11a..e3b600d4427672faa477341e207a5eab2bcf383d 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -41,7 +41,7 @@ void VXXJitCode::generate() { } else if (scalar_index_ == 2) { vbroadcastss(ymm_src2, ptr[param2]); } - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { if (scalar_index_ != 1) { vmovups(ymm_src1, ptr[param1 + offset]); } @@ -57,9 +57,9 @@ void VXXJitCode::generate() { vmaxps(ymm_dst, ymm_zero, ymm_dst); } vmovups(ptr[param3 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; + offset += sizeof(float) * YMM_FLOAT_BLOCK; } - int rest = num_ % AVX_FLOAT_BLOCK; + int rest = num_ % YMM_FLOAT_BLOCK; if (rest >= 4) { if (scalar_index_ != 1) { vmovups(xmm_src1, ptr[param1 + offset]); @@ -118,18 +118,237 @@ void VXXJitCode::generate() { ret(); } -bool ReluJitCode::init(int d) { return MayIUse(avx); } +#define ALIGN32 __attribute__((aligned(32))) +#define EXP_HIG 88.3762626647949f +#define EXP_LOW -88.3762626647949f +#define CEPHES_LOG2EF 1.44269504088896341 +#define CEPHES_EXP_C1 0.693359375 +#define CEPHES_EXP_C2 -2.12194440e-4 +#define CEPHES_EXP_P0 1.9875691500E-4 +#define CEPHES_EXP_P1 1.3981999507E-3 +#define CEPHES_EXP_P2 8.3334519073E-3 +#define CEPHES_EXP_P3 4.1665795894E-2 +#define CEPHES_EXP_P4 1.6666665459E-1 +#define CEPHES_EXP_P5 5.0000001201E-1 -void ReluJitCode::generate() { - int offset = 0; +#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val + +#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) + +static const float exp_float_consts[] ALIGN32 = { + REPEAT_8TIMES(1.f), + REPEAT_8TIMES(2.f), + REPEAT_8TIMES(0.5f), + REPEAT_8TIMES(EXP_HIG), + REPEAT_8TIMES(EXP_LOW), + REPEAT_8TIMES(CEPHES_LOG2EF), + REPEAT_8TIMES(CEPHES_EXP_C1), + REPEAT_8TIMES(CEPHES_EXP_C2), + REPEAT_8TIMES(CEPHES_EXP_P0), + REPEAT_8TIMES(CEPHES_EXP_P1), + REPEAT_8TIMES(CEPHES_EXP_P2), + REPEAT_8TIMES(CEPHES_EXP_P3), + REPEAT_8TIMES(CEPHES_EXP_P4), + REPEAT_8TIMES(CEPHES_EXP_P5), + REPEAT_8TIMES(EXP_MAX_INPUT), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; + +static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; +static int g_tmp_mem[16] ALIGN32 = {0}; + +bool VActJitCode::init(int d, operand_type type) { + bool ok = MayIUse(avx); + if (type == operand_type::relu) { + return ok; + } else if (type == operand_type::exp) { + // exp is slower than mkl when d >= 256 + return ok && d % 8 == 0 && d < 256; + } else { + // TODO(TJ): support more + return ok && d % 8 == 0; + } +} + +void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) { + vmaxps(ymm_dst, ymm_zero, ymm_src); +} + +void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { + assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore + // check all idx can not equal + ymm_t ymm_fx = ymm_t(fx_idx); + ymm_t ymm_fy = ymm_t(fy_idx); + ymm_t ymm_mask = ymm_t(mask_idx); + ymm_t ymm_tmp = ymm_t(tmp_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); + vminps(ymm_src, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); + vmaxps(ymm_src, ymm_src, ymm_tmp); + // express exp(x) as exp(g + n*log(2)) + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); + vmulps(ymm_fx, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); + vaddps(ymm_fx, ymm_fx, ymm_tmp); + vroundps(ymm_fy, ymm_fx, 0x01); + // if greater, substract 1 + vcmpgtps(ymm_mask, ymm_fy, ymm_fx); + vmovaps(ymm_tmp, ptr[reg_ptr_global]); + vandps(ymm_mask, ymm_mask, ymm_tmp); + vsubps(ymm_fx, ymm_fy, ymm_mask); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); + vmulps(ymm_fy, ymm_fx, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); + ymm_t ymm_z = ymm_t(ymm_mask.getIdx()); + vmulps(ymm_z, ymm_fx, ymm_tmp); + vsubps(ymm_src, ymm_src, ymm_fy); + vsubps(ymm_src, ymm_src, ymm_z); + vmulps(ymm_z, ymm_src, ymm_src); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); + vmulps(ymm_dst, ymm_src, ymm_tmp); + for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; + i += (YMM_FLOAT_BLOCK * sizeof(float))) { + vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4 + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmulps(ymm_dst, ymm_dst, ymm_src); + } + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmulps(ymm_dst, ymm_dst, ymm_z); + vaddps(ymm_dst, ymm_dst, ymm_src); + vmovaps(ymm_tmp, ptr[reg_ptr_global]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + // build 2^n + ymm_t ymm_int = ymm_fx; + vcvttps2dq(ymm_int, ymm_fx); + mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); + vmovdqa(ymm_tmp, ptr[reg_ptr_global]); + if (MayIUse(avx2)) { + vpaddd(ymm_int, ymm_int, ymm_tmp); + vpslld(ymm_int, ymm_int, 23); + } else if (MayIUse(avx)) { + xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); + xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx()); + reg64_t reg_ptr_tmp = reg_ptr_global; + mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); + vmovdqa(ptr[reg_ptr_tmp], ymm_int); + vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp); + vpaddd(xtmp1, xtmp1, xtmp2); + vpslld(xtmp1, xtmp1, 23); + vmovdqa(ptr[reg_ptr_tmp], xtmp1); + // next 128bits + vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]); + vmovdqa(xtmp2, + ptr[reg_ptr_tmp + + (YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); + vpaddd(xtmp1, xtmp1, xtmp2); + vpslld(xtmp1, xtmp1, 23); + vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1); + // load out + vmovdqa(ymm_int, ptr[reg_ptr_tmp]); + } + vmulps(ymm_dst, ymm_dst, ymm_int); + pop(reg_ptr_global); +} + +void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { + // y = 1 / (1 + e^-x) + ymm_t ymm_tmp = ymm_t(tmp_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); + vminps(ymm_src, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); + vmaxps(ymm_src, ymm_src, ymm_tmp); + vxorps(ymm_tmp, ymm_tmp, ymm_tmp); + vsubps(ymm_src, ymm_tmp, ymm_src); + exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vdivps(ymm_dst, ymm_tmp, ymm_dst); + pop(reg_ptr_global); +} + +void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { + // y = 2 / (1 + e^(-2x)) - 1 + ymm_t ymm_tmp = ymm_t(tmp_idx); + ymm_t ymm_zero = ymm_t(mask_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vxorps(ymm_zero, ymm_zero, ymm_zero); - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + vsubps(ymm_tmp, ymm_zero, ymm_tmp); + vmulps(ymm_src, ymm_src, ymm_tmp); + exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); + vdivps(ymm_dst, ymm_tmp, ymm_dst); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vsubps(ymm_dst, ymm_dst, ymm_tmp); + pop(reg_ptr_global); +} + +void VActJitCode::generate() { + xmm_t xmm_zero = xmm_t(2); + ymm_t ymm_zero = ymm_t(2); + if (type_ == operand_type::relu) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } + int offset = 0; + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { vmovups(ymm_src, ptr[param1 + offset]); - vmaxps(ymm_dst, ymm_zero, ymm_src); + switch (type_) { + case operand_type::relu: + relu_ymm(ymm_dst, ymm_src, ymm_zero); + break; + case operand_type::exp: + exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + break; + case operand_type::sigmoid: + sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + break; + case operand_type::tanh: + tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + break; + case operand_type::identity: + break; + default: + break; + } vmovups(ptr[param2 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + if (type_ != operand_type::relu) { + // TODO(TJ): remove me + ret(); + return; } - int rest = num_ % AVX_FLOAT_BLOCK; + int rest = num_ % YMM_FLOAT_BLOCK; if (rest >= 4) { vmovups(xmm_src, ptr[param1 + offset]); vmaxps(xmm_dst, xmm_zero, xmm_src); @@ -151,6 +370,7 @@ void ReluJitCode::generate() { } ret(); } + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 3c242870a24c5bb29d34d4b99406c5df8cec6763..71205b211b7f571f8081640ef60222de051ff49d 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; using Label = Xbyak::Label; -typedef enum { mul = 0, add } operand_type; +typedef enum { + mul = 0, + add, + sub, + relu, + exp, + sigmoid, + tanh, + identity +} operand_type; // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VXXJitCode : public JitCode { @@ -85,26 +94,65 @@ class VXXJitCode : public JitCode { ymm_t ymm_zero = ymm_t(3); }; -class ReluJitCode : public JitCode { +class VActJitCode : public JitCode { public: - DECLARE_JIT_CODE(ReluJitCode); - explicit ReluJitCode(int d, size_t code_size = 256 * 1024, + const char* name() const override { + std::string base = "VActJitCode"; + switch (type_) { + case operand_type::relu: + base += "_Relu"; + break; + case operand_type::exp: + base += "_Exp"; + break; + case operand_type::sigmoid: + base += "_Sigmoid"; + break; + case operand_type::tanh: + base += "_Tanh"; + break; + case operand_type::identity: + base += "_Identity"; + break; + default: + break; + } + return base.c_str(); + } + + explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} - static bool init(int d); + : JitCode(code_size, code_ptr), num_(d), type_(type) {} + static bool init(int d, operand_type type); void generate() override; - private: + protected: + // compute relu with ymm + void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, + const Xbyak::Ymm& zero); + + // compute exp with ymm + void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + + // compute sigmoid with ymm + void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + + // compute tanh with ymm + void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + + protected: int num_; + operand_type type_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; - xmm_t xmm_zero = xmm_t(0); - xmm_t xmm_src = xmm_t(1); - xmm_t xmm_dst = xmm_t(1); + xmm_t xmm_src = xmm_t(0); + ymm_t ymm_src = ymm_t(0); - ymm_t ymm_zero = ymm_t(0); - ymm_t ymm_src = ymm_t(1); + xmm_t xmm_dst = xmm_t(1); ymm_t ymm_dst = ymm_t(1); }; diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index cd3a45e66773c89e45e80ab77ebd925abd6cbe53..4d8d3cd79a16a3ea61c4f63da3493e105847d30b 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -29,9 +29,9 @@ namespace jitkernel { #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 -#define AVX_FLOAT_BLOCK 8 -#define AVX2_FLOAT_BLOCK 8 -#define AVX512_FLOAT_BLOCK 16 +#define XMM_FLOAT_BLOCK 4 +#define YMM_FLOAT_BLOCK 8 +#define ZMM_FLOAT_BLOCK 16 typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; @@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel { template class VActKernel : public Kernel { public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; + void (*Compute)(const T *, T *, int); }; template -class VReluKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; - void (*Compute)(const T *, T *, int); -}; +class VReluKernel : public VActKernel {}; template -class VIdentityKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; -}; +class VIdentityKernel : public VActKernel {}; template -class VExpKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; -}; +class VExpKernel : public VActKernel {}; template -class VSigmoidKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; -}; +class VSigmoidKernel : public VActKernel {}; template -class VTanhKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; -}; +class VTanhKernel : public VActKernel {}; template class LSTMKernel : public Kernel { diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index cf46a210afbd4903dc3841f27765c390f721c763..36a50f20434f313e93bfa3dd2c9d46963024caf7 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -25,10 +25,6 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/mklml.h" #endif -#ifdef __AVX__ -#include -#endif - namespace paddle { namespace operators { namespace math { @@ -128,23 +124,16 @@ void VScalMKL(const double* a, const double* x, double* y, int n) { #endif -#define DECLARE_STATIC_FUNC \ - static inline std::string name(int d) { \ - PADDLE_THROW("DType should be either float or double"); \ - } \ - static inline bool useJIT(int d) { return false; } \ - static inline bool useMKL(int d) { return false; } - /* VMUL JitKernel */ template class VMulKernelImpl : public VMulKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VMulKernelImpl(int d) : VMulKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { // roughly estimate the size of code - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -191,11 +180,11 @@ bool VMulKernelImpl::useMKL(int d) { template class VAddKernelImpl : public VAddKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VAddKernelImpl(int d) : VAddKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -241,11 +230,11 @@ bool VAddKernelImpl::useMKL(int d) { template class VAddReluKernelImpl : public VAddReluKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VAddReluKernelImpl(int d) : VAddReluKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, sz > 4096 ? sz : 4096)); this->Compute = @@ -273,11 +262,11 @@ bool VAddReluKernelImpl::useJIT(int d) { template class VScalKernelImpl : public VScalKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VScalKernelImpl(int d) : VScalKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -322,11 +311,11 @@ bool VScalKernelImpl::useMKL(int d) { template class VAddBiasKernelImpl : public VAddBiasKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, sz > 4096 ? sz : 4096)); this->Compute = @@ -355,15 +344,15 @@ bool VAddBiasKernelImpl::useJIT(int d) { template class VReluKernelImpl : public VReluKernel { public: - DECLARE_STATIC_FUNC; + JITKERNEL_DECLARE_STATIC_FUNC; explicit VReluKernelImpl(int d) : VReluKernel() { - this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { - size_t sz = 96 /*init*/ + - d / AVX_FLOAT_BLOCK * 4 /* instructions*/ * - 8 /*everage byte for each instruction*/; - jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096)); + size_t sz = 96 /* init size */ + + d / YMM_FLOAT_BLOCK * 4 /* instructions */ * + 8 /* average bytes for each instruction */; + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; } @@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel { this->Compute = VReluRefer; } - void ComputeDeprecated(const T* x, T* y) const override { - VReluRefer(x, y, this->num_); - } #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VReluKernelImpl::useJIT(int d) { - return gen::ReluJitCode::init(d); + return gen::VActJitCode::init(d, gen::operand_type::relu); } #endif -#undef DECLARE_STATIC_FUNC +template +inline void VIdentityRefer(const T* x, T* y, int n) {} + +/* An empty JitKernel */ +template +class VIdentityKernelImpl : public VIdentityKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VIdentityKernelImpl(int d) : VIdentityKernel() { + this->Compute = VIdentityRefer; + } +}; REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vadd, VAddKernel); @@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); REGISTER_JITKERNEL(vrelu, VReluKernel); - -/* An empty JitKernel */ -template -class VIdentityKernelImpl : public VIdentityKernel { - public: - explicit VIdentityKernelImpl(int d) : VIdentityKernel() { this->num_ = d; } - void ComputeDeprecated(const T* x, T* y) const override {} -}; - -REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); +REGISTER_JITKERNEL(videntity, VIdentityKernel); } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc index a4861c347e44ad86a066861d3375b556302a84bc..4d26b81948238f18b097f535534fcfe9049b93c3 100644 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc @@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ - this->end_ = this->num_ / AVX_FLOAT_BLOCK; \ - this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \ + this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ + this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ } \ template <> \ void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ - INIT_ALPHA(AVX_FLOAT_BLOCK) \ + INIT_ALPHA(YMM_FLOAT_BLOCK) \ /* Use the column-major strategy to get the location of maximum score.*/ \ int seq_offset = 0; \ constexpr int state_trans_base_idx = 2; \ @@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { max_score = _mm256_max_ps(max_score, score_v); \ trans_offset += this->num_; \ } \ - UPDATE_ALPHA(AVX_FLOAT_BLOCK) \ + UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ } \ seq_offset += this->num_; \ } \ @@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { CRFDecodeKernelImpl::CRFDecodeKernelImpl(int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ - this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ - this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ + this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ + this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ } \ template <> \ void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ - INIT_ALPHA(AVX2_FLOAT_BLOCK) \ + INIT_ALPHA(YMM_FLOAT_BLOCK) \ /* Use the column-major strategy to get the location of maximum score.*/ \ int seq_offset = 0; \ constexpr int state_trans_base_idx = 2; \ @@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { max_score = _mm256_max_ps(max_score, score_v); \ trans_offset += this->num_; \ } \ - UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \ + UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ } \ seq_offset += this->num_; \ } \ @@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ - this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \ - this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \ + this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \ + this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \ } \ template <> \ void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ - INIT_ALPHA(AVX512_FLOAT_BLOCK) \ + INIT_ALPHA(ZMM_FLOAT_BLOCK) \ /* Use the column-major strategy to get the location of maximum score.*/ \ int seq_offset = 0; \ constexpr int state_trans_base_idx = 2; \ @@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { this->num_ + j_offset), \ max_j); \ /* Calculate the offset of next step*/ \ - j_offset += AVX512_FLOAT_BLOCK; \ + j_offset += ZMM_FLOAT_BLOCK; \ if (j == this->end_ - 1) { \ if (this->rest_ > 0) { \ j_offset += last_offset; \ diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 2ac9e1092362f60ea3d89da0c971a365b45f39ea..f26815300de31c47a7ea341307b0051dee99e63b 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -16,6 +16,11 @@ limitations under the License. */ #include // for exp #include #include "paddle/fluid/operators/math/jit_kernel_macro.h" + +#ifdef PADDLE_WITH_XBYAK +#include "paddle/fluid/operators/math/jit_code.h" +#endif + #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" #endif @@ -30,38 +35,238 @@ namespace math { namespace jitkernel { namespace jit = platform::jit; +// TODO(TJ): move refer codes to one file +// Refer code only focus on correctness +template +void VExpRefer(const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +} + +template +void VSigmoidRefer(const T* x, T* y, int n) { + // y = 1 / (1 + e^-x) + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(1) / (static_cast(1) + std::exp(-tmp)); + } +} + +template +void VTanhRefer(const T* x, T* y, int n) { + // y = 2 * sigmoid(2x) - 1 + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * x[i]; + } + VSigmoidRefer(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * y[i] - static_cast(1); + } +} + +#ifdef PADDLE_WITH_MKLML +// try to use MKL to speedup +template +void VExpMKL(const T* x, T* y, int n); + +template <> +void VExpMKL(const float* x, float* y, int n) { + platform::dynload::vsExp(n, x, y); +} + +template <> +void VExpMKL(const double* x, double* y, int n) { + platform::dynload::vdExp(n, x, y); +} + +template +void VSigmoidMKL(const T* x, T* y, int n) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(0) - y[i]; + } + VExpMKL(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } +} + +template +void VTanhMKL(const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * x[i]; + } + VSigmoidMKL(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * y[i] - static_cast(1); + } +} +#endif + /* VExp JitKernel */ -template +template class VExpKernelImpl : public VExpKernel { public: - explicit VExpKernelImpl(int d) : VExpKernel() { this->num_ = d; } - void ComputeDeprecated(const T* x, T* y) const override { - for (int i = 0; i < this->num_; ++i) { - y[i] = std::exp(x[i]); + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VExpKernelImpl(int d) : VExpKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8; + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp, + sz > 4096 ? sz : 4096)); + this->Compute = jitcode_->getCode(); + return; + } +#endif +#ifdef PADDLE_WITH_MKLML + if (useMKL(d)) { + this->Compute = VExpMKL; + return; } +#endif + this->Compute = VExpRefer; } + +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +#endif }; +#ifdef PADDLE_WITH_XBYAK +template <> +bool VExpKernelImpl::useJIT(int d) { + return gen::VActJitCode::init(d, gen::operand_type::exp); +} +#endif + #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - platform::dynload::vsExp(this->num_, x, y); \ +template <> +bool VExpKernelImpl::useMKL(int d) { + return d > 512; +} + +template <> +bool VExpKernelImpl::useMKL(int d) { + return true; +} + +#endif + +/* VSigmoid JitKernel */ +template +class VSigmoidKernelImpl : public VSigmoidKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8; + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid, + sz > 4096 ? sz : 4096)); + this->Compute = jitcode_->getCode(); + return; + } +#endif + +#ifdef PADDLE_WITH_MKLML + // strictly it's a better impl with MKL, then is refer + if (useMKL(d)) { + this->Compute = VSigmoidMKL; + return; + } +#endif + this->Compute = VSigmoidRefer; } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated( \ - const double* x, double* y) const { \ - platform::dynload::vdExp(this->num_, x, y); \ +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +#endif +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VSigmoidKernelImpl::useJIT(int d) { + return gen::VActJitCode::init(d, gen::operand_type::sigmoid); +} +#endif + +#ifdef PADDLE_WITH_MKLML +template <> +bool VSigmoidKernelImpl::useMKL(int d) { + return d > 512; +} + +template <> +bool VSigmoidKernelImpl::useMKL(int d) { + return true; +} +#endif + +/* VTanh JitKernel */ +template +class VTanhKernelImpl : public VTanhKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VTanhKernelImpl(int d) : VTanhKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8; + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh, + sz > 4096 ? sz : 4096)); + this->Compute = jitcode_->getCode(); + return; + } +#endif + +#ifdef PADDLE_WITH_MKLML + // strictly it's a better impl with MKL, then is refer + if (useMKL(d)) { + this->Compute = VTanhMKL; + return; + } +#endif + this->Compute = VTanhRefer; } -FOR_EACH_ISA(MKL_FLOAT, kLT8); -FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); -FOR_EACH_ISA(MKL_FLOAT, kGT16); -FOR_EACH_ISA_BLOCK(MKL_DOUBLE); + +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +#endif +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VTanhKernelImpl::useJIT(int d) { + return gen::VActJitCode::init(d, gen::operand_type::tanh); +} +#endif + +#ifdef PADDLE_WITH_MKLML +template <> +bool VTanhKernelImpl::useMKL(int d) { + return d > 512; +} + +template <> +bool VTanhKernelImpl::useMKL(int d) { + return true; +} #endif +REGISTER_JITKERNEL(vexp, VExpKernel); +REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); +REGISTER_JITKERNEL(vtanh, VTanhKernel); + namespace detail { #ifdef __AVX__ @@ -210,334 +415,6 @@ __m256 ExpAVX2(__m256 x) { #endif } // namespace detail - -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - _mm256_storeu_ps(y, expisa(tmp)); \ - } - -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VExpKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = expisa(tmp0); \ - tmp1 = expisa(tmp1); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx, detail::ExpAVX); -INTRI16_FLOAT(jit::avx, detail::ExpAVX); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); -#endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); -#endif -// TODO(TJ): eq16 test and complete avx512 - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef MKL_FLOAT -#undef MKL_DOUBLE - -REGISTER_JITKERNEL_DEPRECATED(vexp, VExpKernel); - -/* VSigmoid JitKernel */ -template -class VSigmoidKernelImpl : public VSigmoidKernel { - public: - explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { - this->num_ = d; - vexp_ = KernelPool::Instance().template Get>(d); - } - void ComputeDeprecated(const T* x, T* y) const override { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < this->num_; ++i) { - y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = static_cast(0) - y[i]; - } - vexp_->ComputeDeprecated(y, y); - for (int i = 0; i < this->num_; ++i) { - y[i] = static_cast(1) / (static_cast(1) + y[i]); - } - } - - private: - std::shared_ptr> vexp_; -}; - -#define INTRI_SIGMOID(tmp, min, max, expisa) \ - tmp = _mm256_max_ps(tmp, min); \ - tmp = _mm256_min_ps(tmp, max); \ - tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \ - tmp = expisa(tmp); \ - tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ - tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) - -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - /* TODO(TJ): try to use static const*/ \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y, tmp); \ - } - -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_SIGMOID(tmp0, min, max, expisa); \ - INTRI_SIGMOID(tmp1, min, max, expisa); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ - } - -#define INTRI_GT8LT16_FLOAT(isa, expisa) \ - template <> \ - VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ - : VSigmoidKernel() { \ - this->num_ = d; \ - this->end_ = AVX_FLOAT_BLOCK; \ - this->rest_ = d - this->end_; \ - vexp_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - } \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y, tmp); \ - const float min_ = SIGMOID_THRESHOLD_MIN; \ - const float max_ = SIGMOID_THRESHOLD_MAX; \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ - y[i] = 0.f - y[i]; \ - } \ - vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = 1.f / (1.f + y[i]); \ - } \ - } - -#define INTRI_GT16_FLOAT(isa, expisa) \ - template <> \ - VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ - : VSigmoidKernel() { \ - this->num_ = d; \ - this->rest_ = d % AVX_FLOAT_BLOCK; \ - this->end_ = d - this->rest_; \ - vexp_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - } \ - template <> \ - void VSigmoidKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - const float min_ = SIGMOID_THRESHOLD_MIN; \ - const float max_ = SIGMOID_THRESHOLD_MAX; \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ - y[i] = 0.f - y[i]; \ - } \ - vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \ - for (int i = this->end_; i < this->num_; ++i) { \ - y[i] = 1.f / (1.f + y[i]); \ - } \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx, detail::ExpAVX); -INTRI16_FLOAT(jit::avx, detail::ExpAVX); -INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); -INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); -// maybe use avx at gt8lt16 and gt16 -#endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); -// maybe use avx2 at gt8lt16 and gt16 -#endif - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_GT8LT16_FLOAT -#undef INTRI_GT16_FLOAT -#undef INTRI_VSIGMOID - -REGISTER_JITKERNEL_DEPRECATED(vsigmoid, VSigmoidKernel); - -/* VTanh JitKernel */ -template -class VTanhKernelImpl : public VTanhKernel { - public: - explicit VTanhKernelImpl(int d) : VTanhKernel() { - this->num_ = d; - vscal_ = KernelPool::Instance().template Get>(d); - vsigmoid_ = KernelPool::Instance().template Get>(d); - vaddbias_ = KernelPool::Instance().template Get>(d); - } - void ComputeDeprecated(const T* x, T* y) const override { - const T a = static_cast(2), b = static_cast(-1); - vscal_->Compute(&a, x, y, this->num_); - vsigmoid_->ComputeDeprecated(y, y); - vscal_->Compute(&a, y, y, this->num_); - vaddbias_->Compute(&b, y, y, this->num_); - } - - private: - std::shared_ptr> vscal_; - std::shared_ptr> vsigmoid_; - std::shared_ptr> vaddbias_; -}; - -#define INTRI_VTANH(tmp, expisa) \ - tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \ - tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \ - tmp = expisa(tmp); \ - tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ - tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ - tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) - -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VTanhKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp, expisa); \ - _mm256_storeu_ps(y, tmp); \ - } - -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VTanhKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_VTANH(tmp0, expisa); \ - INTRI_VTANH(tmp1, expisa); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ - } - -#define INTRI_GT8LT16_FLOAT(isa, expisa) \ - template <> \ - VTanhKernelImpl::VTanhKernelImpl(int d) \ - : VTanhKernel() { \ - this->num_ = d; \ - this->end_ = AVX_FLOAT_BLOCK; \ - this->rest_ = d - this->end_; \ - vscal_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - vsigmoid_ = KernelPool::Instance().template Get>( \ - this->rest_); \ - vaddbias_ = KernelPool::Instance().template Get>( \ - this->rest_); \ - } \ - template <> \ - void VTanhKernelImpl::ComputeDeprecated( \ - const float* x, float* y) const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp, expisa); \ - _mm256_storeu_ps(y, tmp); \ - x += AVX_FLOAT_BLOCK; \ - y += AVX_FLOAT_BLOCK; \ - const float a = 2.f, b = -1.f; \ - vscal_->Compute(&a, x, y, this->num_); \ - vsigmoid_->ComputeDeprecated(y, y); \ - vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(&b, y, y, this->num_); \ - } - -#define INTRI_GT16_FLOAT(isa, expisa) \ - template <> \ - VTanhKernelImpl::VTanhKernelImpl(int d) \ - : VTanhKernel() { \ - this->num_ = d; \ - this->rest_ = d % AVX_FLOAT_BLOCK; \ - this->end_ = d - this->rest_; \ - vscal_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - vsigmoid_ = KernelPool::Instance().template Get>( \ - this->rest_); \ - vaddbias_ = KernelPool::Instance().template Get>( \ - this->rest_); \ - } \ - template <> \ - void VTanhKernelImpl::ComputeDeprecated(const float* x, \ - float* y) const { \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_VTANH(tmp, expisa); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - x += this->end_; \ - y += this->end_; \ - const float a = 2.f, b = -1.f; \ - vscal_->Compute(&a, x, y, this->num_); \ - vsigmoid_->ComputeDeprecated(y, y); \ - vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(&b, y, y, this->num_); \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx, detail::ExpAVX); -INTRI16_FLOAT(jit::avx, detail::ExpAVX); -INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); -INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); -// maybe use avx at gt8lt16 and gt16 -#endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); -INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); -// maybe use avx at gt8lt16 and gt16 -#endif - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_GT8LT16_FLOAT -#undef INTRI_GT16_FLOAT -#undef INTRI_VTANH - -REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel); - -#undef JITKERNEL_NEW_ACT_IMPL - } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h index a8169ea48ae3eee5a8cba291be4496c4c6074221..8acf60cfbfd3d47ad52862241b7635aba6982ebf 100644 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -15,12 +15,20 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { namespace math { namespace jitkernel { +#define JITKERNEL_DECLARE_STATIC_FUNC \ + static inline std::string name(int d) { \ + PADDLE_THROW("DType should be either float or double"); \ + } \ + static inline bool useJIT(int d) { return false; } \ + static inline bool useMKL(int d) { return false; } + #define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \ template <> \ std::string ker_class##Impl::name(int d) { \ @@ -86,17 +94,17 @@ namespace jitkernel { namespace jit = platform::jit; // TODO(TJ): below defines are deprecated, would be remove recently -#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ - if (d < AVX_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kLT8); \ - } else if (d == AVX_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ8); \ - } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kGT8LT16); \ - } else if (d == AVX512_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ16); \ - } else { \ - macro_(ker, dtype, isa, kGT16); \ +#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ + if (d < YMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kLT8); \ + } else if (d == YMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kEQ8); \ + } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kGT8LT16); \ + } else if (d == ZMM_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kEQ16); \ + } else { \ + macro_(ker, dtype, isa, kGT16); \ } #define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index 926221f0a75c461e275a72f16b4339ae28a8e988..e79b0400ab75d1488a26450bd8cde4a0979fc995 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel { void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, T* checked) const override { // gates: W_ch, W_ih, W_fh, W_oh - act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_); + act_gate_d3_->Compute(gates + d_, gates + d_, d3_); /* C_t = C_t-1 * fgated + cand_gated * igated */ - act_cand_d_->ComputeDeprecated(gates, gates); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { /* C_t = igated * cgated*/ - act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); - act_cand_d_->ComputeDeprecated(gates, gates); + act_gate_d_->Compute(gates + d_, gates + d_, d_); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, ct, d_); /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_gate_d_->Compute(gates + d3_, gates + d3_, d_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } @@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel { vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); - act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_); + act_gate_d2_->Compute(gates + d_, gates + d_, d2_); /* C_t = C_t-1 * fgated + cand_gated * igated*/ - act_cand_d_->ComputeDeprecated(gates, gates); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* get ogated*/ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); - act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); + act_gate_d_->Compute(gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { /* C_t = igated * cgated*/ - act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); - act_cand_d_->ComputeDeprecated(gates, gates); + act_gate_d_->Compute(gates + d_, gates + d_, d_); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, ct, d_); /* get outgated, put W_oc * C_t on igated */ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_gate_d_->Compute(gates + d3_, gates + d3_, d_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } @@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel { } void ComputeH1(T* gates, T* ht) const override { - act_gate_d_->ComputeDeprecated(gates, gates); - act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_); + act_gate_d_->Compute(gates, gates, d_); + act_state_d_->Compute(gates + d2_, gates + d2_, d_); vmul_d_->Compute(gates, gates + d2_, ht, d_); } void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { // W: {W_update, W_reset; W_state} - act_gate_d2_->ComputeDeprecated(gates, gates); + act_gate_d2_->Compute(gates, gates, d2_); vmul_d_->Compute(ht_1, gates + d_, ht, d_); } void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { T* y = gates + d2_; - act_state_d_->ComputeDeprecated(y, y); + act_state_d_->Compute(y, y, d_); // out = zt*ht~ + (1-zt)*ht_1 for (int i = 0; i < d_; ++i) { ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 5e1f91ffae03796be2817d0461900c2512938c77..5a6f87fe1f7d10d65d03d78c168d61719cec772e 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -181,7 +181,8 @@ TEST(JitKernel, vexp) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->ComputeDeprecated(x_data, ztgt_data); + // ker->Compute(x_data, ztgt_data); + ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -222,7 +223,7 @@ void vsigmoid_better( y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = 0.f - y[i]; } - vexp->ComputeDeprecated(y, y); + vexp->Compute(y, y, n); for (int i = 0; i < n; ++i) { y[i] = 1.f / (1.f + y[i]); } @@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->ComputeDeprecated(x_data, ztgt_data); + ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -287,7 +288,7 @@ void vtanh_better( const int n, const float* x, float* y) { const float a = 2.f, b = -1.f; vscal->Compute(&a, x, y, n); - vsigmoid->ComputeDeprecated(y, y); + vsigmoid->Compute(y, y, n); vscal->Compute(&a, y, y, n); vaddbias->Compute(&b, y, y, n); } @@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->ComputeDeprecated(x_data, ztgt_data); + ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -344,8 +345,8 @@ void lstm_ctht_ref( const std::shared_ptr< const paddle::operators::math::jitkernel::VExpKernel>& vexp_1, const int d, float* gates, const float* ct_1, float* ct, float* ht) { - vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); - vtanh_d->ComputeDeprecated(gates, gates); + vsigmoid_3d->Compute(gates + d, gates + d, 3 * d); + vtanh_d->Compute(gates, gates, d); const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; const float min = SIGMOID_THRESHOLD_MIN; const float max = SIGMOID_THRESHOLD_MAX; @@ -355,7 +356,7 @@ void lstm_ctht_ref( // H_t = act_cell(C_t) * ogated float tmp = ct[k] * 2; tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vexp_1->ComputeDeprecated(&tmp, &tmp); + vexp_1->Compute(&tmp, &tmp, 1); tmp = 2.f / (1.f + tmp) - 1.f; ht[k] = tmp * o[k]; } @@ -373,13 +374,13 @@ void lstm_ctht_better( const paddle::operators::math::jitkernel::VAddKernel>& vadd_d, const int d, float* gates, const float* ct_1, float* ct, float* ht) { int d2 = d * 2; - vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); - vtanh_d->ComputeDeprecated(gates, gates); + vsigmoid_3d->Compute(gates + d, gates + d, 3 * d); + vtanh_d->Compute(gates, gates, d); vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vadd_d->Compute(gates + d, gates + d2, ct, d); /* H_t = act_cell(C_t) * ogated */ - vtanh_d->ComputeDeprecated(ct, gates + d2); + vtanh_d->Compute(ct, gates + d2, d); vmul_d->Compute(gates + d2, gates + d * 3, ht, d); } @@ -736,7 +737,7 @@ void vaddrelu_better( const paddle::operators::math::jitkernel::VReluKernel>& vrelu, const float* x, const float* y, float* z, int d) { vadd->Compute(x, y, z, d); - vrelu->ComputeDeprecated(z, z); + vrelu->Compute(z, z, d); } TEST(JitKernel, vaddrelu) { diff --git a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc index 3066dc0ba284611af89c4927f45089a570ab88bc..690d6f6baafb33d50c8f2d3606d903634d622d16 100644 --- a/paddle/fluid/operators/math/sampler.cc +++ b/paddle/fluid/operators/math/sampler.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,52 +13,46 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/sampler.h" +#include +#include +#include +#include namespace paddle { -namespace random { +namespace operators { +namespace math { Sampler::~Sampler() {} -UniformSampler::UniformSampler(int64 range) - : Sampler(range), inv_range_(1.0 / range) { - random_engine_ = std::make_shared(seed_); +UniformSampler::UniformSampler(int64_t range, unsigned int seed) + : Sampler(range, seed), inv_range_(1.0 / (range + 1)) { + random_engine_ = std::make_shared(seed_); dist_ = std::make_shared>(0, range); } -UniformSampler::UniformSampler(int64 range, unsigned int seed) - : Sampler(range, seed), inv_range_(1.0 / range) { - random_engine_ = std::make_shared(seed_); - dist_ = std::make_shared>(0, range); -} - -int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); } +int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); } -float UniformSampler::Probability(int64 value) const { return inv_range_; } +float UniformSampler::Probability(int64_t value) const { return inv_range_; } -LogUniformSampler::LogUniformSampler(int64 range) - : Sampler(range), log_range_(log(range + 1)) { - random_engine_ = std::make_shared(seed_); - dist_ = std::make_shared>(0, 1); -} - -LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed) +LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed) : Sampler(range, seed), log_range_(log(range + 1)) { - random_engine_ = std::make_shared(seed_); + random_engine_ = std::make_shared(seed_); dist_ = std::make_shared>(0, 1); } -int64 LogUniformSampler::Sample() const { + +int64_t LogUniformSampler::Sample() const { // Got Log Uniform distribution from uniform distribution by // inverse_transform_sampling method // More details: // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ - const int64 value = - static_cast(exp((*dist_)(*random_engine_) * log_range_)) - 1; + const int64_t value = + static_cast(exp((*dist_)(*random_engine_) * log_range_)) - 1; // Mathematically, value should be <= range_, but might not be due to some // floating point roundoff, so we mod by range_. return value % range_; } -float LogUniformSampler::Probability(int64 value) const { +float LogUniformSampler::Probability(int64_t value) const { // Given f(x) = 1/[(x+1) * log_range_] // The value's probability is integral of f(x) from value to (value + 1) // More details: @@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const { return (log((value + 2.0) / (value + 1.0))) / log_range_; } -} // namespace random +CustomSampler::CustomSampler(int64_t range, const float* probabilities, + unsigned int seed) + : Sampler(range, seed) { + random_engine_ = std::make_shared(seed_); + real_dist_ = std::make_shared>(0, 1); + int_dist_ = std::make_shared>(0, range); + alias_probs_ = std::make_shared>(range + 1); + alias_ = std::make_shared>(range + 1); + probs_ = std::make_shared>(range + 1); + + std::queue> bigs; + std::queue> littles; + for (int64_t i = 0; i <= range; ++i) { + (*probs_)[i] = probabilities[i]; + float normal_prob = probabilities[i] * (range + 1); + if (normal_prob - 1.0 > 1e-4) { + bigs.emplace(i, normal_prob); + } else if (1.0 - normal_prob > 1e-4) { + littles.emplace(i, normal_prob); + } else { + (*alias_probs_)[i] = normal_prob; + (*alias_)[i] = -1; + } + } + + while ((!littles.empty()) && (!bigs.empty())) { + auto big = bigs.front(); + auto little = littles.front(); + bigs.pop(); + littles.pop(); + (*alias_probs_)[little.first] = little.second; + (*alias_)[little.first] = big.first; + auto big_left = big.second - (1 - little.second); + if (big_left - 1.0 > 1e-4) { + bigs.emplace(big.first, big_left); + } else if (1.0 - big_left > 1e-4) { + littles.emplace(big.first, big_left); + } else { + (*alias_probs_)[big.first] = big_left; + (*alias_)[big.first] = -1; + } + } + + if (!littles.empty()) { // littles.second is close to 1.0 + auto little = littles.front(); + (*alias_probs_)[little.first] = 1.0; + (*alias_)[little.first] = -1; + } + + if (!bigs.empty()) { // bigs.second is close to 1.0 + auto big = bigs.front(); + (*alias_probs_)[big.first] = 1.0; + (*alias_)[big.first] = -1; + } +} + +int64_t CustomSampler::Sample() const { + auto index = (*int_dist_)(*random_engine_); + auto p = (*real_dist_)(*random_engine_); + if (p > (*alias_probs_)[index]) { + return (*alias_)[index]; + } else { + return index; + } +} + +float CustomSampler::Probability(int64_t value) const { + return (*probs_)[value]; +} + +} // namespace math +} // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/sampler.h b/paddle/fluid/operators/math/sampler.h index b82691f269c5d0f267ca98c78646efe9b26f0b34..836cdad51f17e93f811ba14695bbe1a65156c588 100644 --- a/paddle/fluid/operators/math/sampler.h +++ b/paddle/fluid/operators/math/sampler.h @@ -16,6 +16,8 @@ limitations under the License. */ #include #include #include +#include + namespace paddle { namespace operators { namespace math { @@ -27,14 +29,14 @@ namespace math { */ class Sampler { public: - explicit Sampler(int64_t range) : range_(range) { - PADDLE_ENFORCE_GT(range, 0); - std::random_device r; - seed_ = r(); - } - explicit Sampler(int64_t range, unsigned int seed) - : range_(range), seed_(seed) { - PADDLE_ENFORCE_GT(range, 0); + explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) { + // PADDLE_ENFORCE_GT(range, 0, "Range should be greater than 0."); + if (seed == 0) { + std::random_device r; + seed_ = r(); + } else { + seed_ = seed; + } } virtual ~Sampler(); // Sample a single value @@ -42,7 +44,7 @@ class Sampler { // The probability that a single call to Sample() returns the given value. virtual float Probability(int64_t value) const = 0; - int64 range() { return range_; } + int64_t range() { return range_; } protected: const int64_t range_; @@ -56,13 +58,11 @@ class Sampler { */ class UniformSampler : public Sampler { public: - explicit UniformSampler(int64_t range); - - explicit UniformSampler(int64_t range, unsigned int seed); + explicit UniformSampler(int64_t range, unsigned int seed = 0UL); ~UniformSampler() override {} - int64 Sample() const override; + int64_t Sample() const override; float Probability(int64_t value) const override; @@ -79,13 +79,11 @@ class UniformSampler : public Sampler { */ class LogUniformSampler : public Sampler { public: - explicit LogUniformSampler(int64_t range); - - explicit LogUniformSampler(int64_t range, unsigned int seed); + explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL); ~LogUniformSampler() override {} - int64 Sample() const override; + int64_t Sample() const override; float Probability(int64_t value) const override; @@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler { std::shared_ptr> dist_; }; +/** + * Sample integers from [0, range) from custom distribution. + */ +class CustomSampler : public Sampler { + public: + explicit CustomSampler(int64_t range, const float* probabilities, + unsigned int seed = 0UL); + + ~CustomSampler() override {} + + int64_t Sample() const override; + + float Probability(int64_t value) const override; + + private: + std::shared_ptr> alias_probs_; + std::shared_ptr> alias_; + std::shared_ptr> probs_; + std::shared_ptr random_engine_; + std::shared_ptr> real_dist_; + std::shared_ptr> int_dist_; +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/metrics/CMakeLists.txt b/paddle/fluid/operators/metrics/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d468316e8eacb73c4a4ce81c784880bb5e46c2d --- /dev/null +++ b/paddle/fluid/operators/metrics/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc similarity index 98% rename from paddle/fluid/operators/accuracy_op.cc rename to paddle/fluid/operators/metrics/accuracy_op.cc index 42fcace17926641b5caf677eb3c8ba5222e37190..95aa76bc6947c9c39e56d39031c5184dc262acd0 100644 --- a/paddle/fluid/operators/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/accuracy_op.h" +#include "paddle/fluid/operators/metrics/accuracy_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/accuracy_op.cu b/paddle/fluid/operators/metrics/accuracy_op.cu similarity index 98% rename from paddle/fluid/operators/accuracy_op.cu rename to paddle/fluid/operators/metrics/accuracy_op.cu index 23b48c6fdf427348879de07c671c65327d6436d7..b255d2a7c413b4f965f6b874d342dcb93c7b5e66 100644 --- a/paddle/fluid/operators/accuracy_op.cu +++ b/paddle/fluid/operators/metrics/accuracy_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/accuracy_op.h" +#include "paddle/fluid/operators/metrics/accuracy_op.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" diff --git a/paddle/fluid/operators/accuracy_op.h b/paddle/fluid/operators/metrics/accuracy_op.h similarity index 100% rename from paddle/fluid/operators/accuracy_op.h rename to paddle/fluid/operators/metrics/accuracy_op.h diff --git a/paddle/fluid/operators/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc similarity index 98% rename from paddle/fluid/operators/auc_op.cc rename to paddle/fluid/operators/metrics/auc_op.cc index cb98bc514083ad113fdebfbac043a9516fd9435a..335d4fded4a9543dabf984f7ed9c342b46dd04f0 100644 --- a/paddle/fluid/operators/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/auc_op.h" +#include "paddle/fluid/operators/metrics/auc_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/metrics/auc_op.h similarity index 100% rename from paddle/fluid/operators/auc_op.h rename to paddle/fluid/operators/metrics/auc_op.h diff --git a/paddle/fluid/operators/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc similarity index 99% rename from paddle/fluid/operators/precision_recall_op.cc rename to paddle/fluid/operators/metrics/precision_recall_op.cc index e7ce16f33fb5052ffb41fc05bd1538e2f0dc35be..0d733c47dd2fcaad776d8d4e6467ecd1872bce05 100644 --- a/paddle/fluid/operators/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/precision_recall_op.h" +#include "paddle/fluid/operators/metrics/precision_recall_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/precision_recall_op.h b/paddle/fluid/operators/metrics/precision_recall_op.h similarity index 100% rename from paddle/fluid/operators/precision_recall_op.h rename to paddle/fluid/operators/metrics/precision_recall_op.h diff --git a/paddle/fluid/operators/nccl/CMakeLists.txt b/paddle/fluid/operators/nccl/CMakeLists.txt index cdcba8035762d8f442eb8b8ed52a4e3e99ac31b6..9b26e19cc7ed05038e05308f9277b200a885dc10 100644 --- a/paddle/fluid/operators/nccl/CMakeLists.txt +++ b/paddle/fluid/operators/nccl/CMakeLists.txt @@ -1,3 +1,13 @@ if(WITH_GPU AND NOT WIN32) nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator ) endif() + +if(WITH_GPU) + op_library(nccl_op DEPS nccl_common) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") + set(OPERATOR_DEPS ${OPERATOR_DEPS} nccl_common PARENT_SCOPE) +endif() + +if(NOT WIN32) + nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) +endif() diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl/nccl_op.cc similarity index 100% rename from paddle/fluid/operators/nccl_op.cc rename to paddle/fluid/operators/nccl/nccl_op.cc diff --git a/paddle/fluid/operators/nccl_op.cu.cc b/paddle/fluid/operators/nccl/nccl_op.cu.cc similarity index 100% rename from paddle/fluid/operators/nccl_op.cu.cc rename to paddle/fluid/operators/nccl/nccl_op.cu.cc diff --git a/paddle/fluid/operators/nccl_op_test.cu.cc b/paddle/fluid/operators/nccl/nccl_op_test.cu.cc similarity index 100% rename from paddle/fluid/operators/nccl_op_test.cu.cc rename to paddle/fluid/operators/nccl/nccl_op_test.cu.cc diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 877c9a0528441a7d5b1306c3f8f8be1a5aea577a..9b0d45ae5b9d104c8b7bb1529a9baaaf3d6a736d 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("Input"); auto label_dims = ctx->GetInputDim("Label"); + auto w_dims = ctx->GetInputDim("Weight"); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]); int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1; if (ctx->HasInput("Bias")) { @@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { "each sample. And it is a dispensable input. The default value of " "sample is 1.") .AsDispensable(); + + AddInput( + "CustomDistribution", + "(Tensor) It is used in 'CostumDist' sampler. " + "It is a tensor with shape [num_total_classes]." + "The i-th element is the probsbility of the i-th class being sampled.") + .AsDispensable(); AddOutput("Cost", "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); AddOutput("SampleLogits", @@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("num_neg_samples", "The number of negative classes. The default value is 10.") .SetDefault(10); + + AddAttr("sampler", + "(int) Which sampler to be used to sample negative class." + "0: Uniform; 1: LogUniform; 2: CostumDist.") + .SetDefault(0); + + AddAttr("seed", + "(int) The seed used in sampler. If it is 0, " + "the sampler will generate a seed randomly.") + .SetDefault(0); + AddAttr>("custom_neg_classes", "This attribute only be used in unitest. Classes " "in this list wiil be used as negative classes " diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 2c4c97f28bc0b511d6eaa8f79a3a4efc9be8a5da..e9af8ad4ce8501f464202039d99c36984d7feba9 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -19,29 +19,28 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using Sampler = math::Sampler; template using EigenMatrix = framework::EigenMatrix; template -void PrepareSamples(const framework::ExecutionContext& context) { +void PrepareSamples(const framework::ExecutionContext& context, + Sampler* sampler) { auto label = context.Input("Label"); const int64_t* label_data = label->data(); auto label_dims = label->dims(); - int num_total_classes = context.Attr("num_total_classes"); + // int num_total_classes = context.Attr("num_total_classes"); // for unitest std::vector custom_neg_classes = context.Attr>("custom_neg_classes"); - // random machine - std::random_device rd; - std::mt19937 rng(rd()); - std::uniform_int_distribution rand(0, num_total_classes - 1); auto sample_labels = context.Output("SampleLabels"); auto sample_labels_dims = sample_labels->dims(); @@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) { } else { for (; j < sample_labels_dims[1]; ++j) { // TODO(wanghaoshuang): support more distribution sampling - sample_labels_data[index++] = rand(rng); + sample_labels_data[index++] = sampler->Sample(); } } } @@ -72,7 +71,33 @@ template class NCEKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PrepareSamples(context); + int sampler_type = context.Attr("sampler"); + int seed = context.Attr("seed"); + int num_total_classes = context.Attr("num_total_classes"); + int num_neg_samples = context.Attr("num_neg_samples"); + + Sampler* sampler; + switch (sampler_type) { + case 0: { + sampler = new math::UniformSampler(num_total_classes - 1, seed); + break; + } + case 1: { + sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + break; + } + case 2: { + auto custom_dist = context.Input("CustomDistribution"); + const float* custom_dist_data = custom_dist->data(); + PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); + sampler = new math::CustomSampler(num_total_classes - 1, + custom_dist_data, seed); + break; + } + default: { PADDLE_THROW("Unsupported SamplerType."); } + } + + PrepareSamples(context, sampler); auto sample_labels = context.Output("SampleLabels"); const int64_t* sample_labels_data = sample_labels->data(); auto sample_out = context.Output("SampleLogits"); @@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel { } auto out = context.Output("Cost"); T* out_data = out->mutable_data(context.GetPlace()); - int num_neg_samples = context.Attr("num_neg_samples"); - int num_total_classes = context.Attr("num_total_classes"); int64_t num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; } - T b = 1. / num_total_classes * num_neg_samples; + int64_t sampled_labels_num = sample_labels->dims()[1]; + // T b = 1. / num_total_classes * num_neg_samples; // forward bias auto bias = context.Input("Bias"); if (bias != nullptr) { @@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel { } // forward cost for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) { - int64_t j = 0; out_data[i] = 0; T w = sample_weight == nullptr ? 1. : sample_weight_data[i]; - // for true classes - for (; j < num_true_class; ++j) { - T o = sample_out_data[i * sample_out->dims()[1] + j]; - T cost = -log(o / (o + b)); - out_data[i] += w * cost; - } - // for sampled neg classes - for (; j < sample_labels->dims()[1]; ++j) { - T o = sample_out_data[i * sample_out->dims()[1] + j]; - T cost = -log(b / (o + b)); + for (int64_t j = 0; j < sampled_labels_num; ++j) { + int64_t target = sample_labels_data[i * sampled_labels_num + j]; + T o = sample_out_data[i * sampled_labels_num + j]; + float b = sampler->Probability(target) * num_neg_samples; + T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b)); out_data[i] += w * cost; } } + delete sampler; } }; @@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel { if (label != nullptr) { num_true_class = label->dims()[1]; } - T b = 1. / num_total_classes * num_neg_samples; + + int sampler_type = context.Attr("sampler"); + int seed = context.Attr("seed"); + Sampler* sampler; + switch (sampler_type) { + case 0: { + sampler = new math::UniformSampler(num_total_classes - 1, seed); + break; + } + case 1: { + sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + break; + } + case 2: { + auto custom_dist = context.Input("CustomDistribution"); + const float* custom_dist_data = custom_dist->data(); + PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); + sampler = new math::CustomSampler(num_total_classes - 1, + custom_dist_data, seed); + break; + } + default: { PADDLE_THROW("Unsupported SamplerType."); } + } + + // T b = 1. / num_total_classes * num_neg_samples; Tensor sample_grad; // tmp tensor T* sample_grad_data = sample_grad.mutable_data(sample_labels->dims(), context.GetPlace()); // backward cost for (int64_t i = 0; i < sample_labels->numel(); ++i) { + int64_t label_idx = i % sample_labels->dims()[1]; + int64_t sample_idx = i / sample_labels->dims()[1]; + float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples; T o = sample_out_data[i]; - T w = sample_weight == nullptr - ? 1 - : sample_weight_data[i / sample_labels->dims()[1]]; - sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class + T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx]; + sample_grad_data[i] = label_idx < num_true_class ? w * (b / (o + b)) * (o - 1) : w * (o * (1 - o) / (o + b)); - sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]]; + sample_grad_data[i] *= d_out_data[sample_idx]; } // get d_bias auto d_bias = context.Output(framework::GradVarName("Bias")); @@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel { w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; } } + delete sampler; } }; } // namespace operators diff --git a/paddle/fluid/operators/optimizers/CMakeLists.txt b/paddle/fluid/operators/optimizers/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d468316e8eacb73c4a4ce81c784880bb5e46c2d --- /dev/null +++ b/paddle/fluid/operators/optimizers/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc similarity index 98% rename from paddle/fluid/operators/adadelta_op.cc rename to paddle/fluid/operators/optimizers/adadelta_op.cc index 89a7a49e0fa8427826f5d91274912a68f2316b61..9039d02b673b3403c840492c088179b30e23da9c 100644 --- a/paddle/fluid/operators/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/adadelta_op.h" +#include "paddle/fluid/operators/optimizers/adadelta_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/adadelta_op.cu b/paddle/fluid/operators/optimizers/adadelta_op.cu similarity index 93% rename from paddle/fluid/operators/adadelta_op.cu rename to paddle/fluid/operators/optimizers/adadelta_op.cu index fc10c6657476e7f87b2f703a1d0cb88eeebc35cf..3fbfee5df05770a1206ab3170d3baffdd20bc77b 100644 --- a/paddle/fluid/operators/adadelta_op.cu +++ b/paddle/fluid/operators/optimizers/adadelta_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/adadelta_op.h" +#include "paddle/fluid/operators/optimizers/adadelta_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/adadelta_op.h b/paddle/fluid/operators/optimizers/adadelta_op.h similarity index 100% rename from paddle/fluid/operators/adadelta_op.h rename to paddle/fluid/operators/optimizers/adadelta_op.h diff --git a/paddle/fluid/operators/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc similarity index 99% rename from paddle/fluid/operators/adagrad_op.cc rename to paddle/fluid/operators/optimizers/adagrad_op.cc index c88297ff544ddb0e5a97452a8ad2e8f9f77825ba..e8d5a9e2c875570a198629bd745c9d58036746cb 100644 --- a/paddle/fluid/operators/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/adagrad_op.h" +#include "paddle/fluid/operators/optimizers/adagrad_op.h" #include #include diff --git a/paddle/fluid/operators/adagrad_op.cu b/paddle/fluid/operators/optimizers/adagrad_op.cu similarity index 98% rename from paddle/fluid/operators/adagrad_op.cu rename to paddle/fluid/operators/optimizers/adagrad_op.cu index b99b33343d36fbb7f6b1a2928e142ca615b238b3..4efe56855a4bdca41d24f02c29a618a8d4232887 100644 --- a/paddle/fluid/operators/adagrad_op.cu +++ b/paddle/fluid/operators/optimizers/adagrad_op.cu @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/adagrad_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/operators/optimizers/adagrad_op.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { diff --git a/paddle/fluid/operators/adagrad_op.h b/paddle/fluid/operators/optimizers/adagrad_op.h similarity index 100% rename from paddle/fluid/operators/adagrad_op.h rename to paddle/fluid/operators/optimizers/adagrad_op.h diff --git a/paddle/fluid/operators/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc similarity index 99% rename from paddle/fluid/operators/adam_op.cc rename to paddle/fluid/operators/optimizers/adam_op.cc index f3717af630017eba18aa265f3dbb496e18280a57..5710cda39acce53e35dfceec675fcd4979a84e31 100644 --- a/paddle/fluid/operators/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/adam_op.h" +#include "paddle/fluid/operators/optimizers/adam_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu similarity index 93% rename from paddle/fluid/operators/adam_op.cu rename to paddle/fluid/operators/optimizers/adam_op.cu index 77f1991002e6007e8b8dff4746739a90e836145d..e8090ebacfe85153aba9e275c9cd1c55fd7af15e 100644 --- a/paddle/fluid/operators/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/adam_op.h" +#include "paddle/fluid/operators/optimizers/adam_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h similarity index 100% rename from paddle/fluid/operators/adam_op.h rename to paddle/fluid/operators/optimizers/adam_op.h diff --git a/paddle/fluid/operators/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc similarity index 99% rename from paddle/fluid/operators/adamax_op.cc rename to paddle/fluid/operators/optimizers/adamax_op.cc index d4aa4d338a2379adf985ba7f89b528bc402eda06..4b244a76dc0ebee65b7c95db2d2754ebae03bbac 100644 --- a/paddle/fluid/operators/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/adamax_op.h" +#include "paddle/fluid/operators/optimizers/adamax_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/adamax_op.cu b/paddle/fluid/operators/optimizers/adamax_op.cu similarity index 93% rename from paddle/fluid/operators/adamax_op.cu rename to paddle/fluid/operators/optimizers/adamax_op.cu index 05cafd7a8eef79588d1d5724084586cb9b51d3d4..e54adcb142fe0d50dad23fe5df14bd6f28220d8a 100644 --- a/paddle/fluid/operators/adamax_op.cu +++ b/paddle/fluid/operators/optimizers/adamax_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/adamax_op.h" +#include "paddle/fluid/operators/optimizers/adamax_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/adamax_op.h b/paddle/fluid/operators/optimizers/adamax_op.h similarity index 100% rename from paddle/fluid/operators/adamax_op.h rename to paddle/fluid/operators/optimizers/adamax_op.h diff --git a/paddle/fluid/operators/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc similarity index 98% rename from paddle/fluid/operators/decayed_adagrad_op.cc rename to paddle/fluid/operators/optimizers/decayed_adagrad_op.cc index d73ae9e2721b388212cb6efa354eb4b480df9cad..80278441c07203b03dbcff157193ea5976eefbf1 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/decayed_adagrad_op.h" +#include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/decayed_adagrad_op.cu b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu similarity index 92% rename from paddle/fluid/operators/decayed_adagrad_op.cu rename to paddle/fluid/operators/optimizers/decayed_adagrad_op.cu index 7da16acf05eefc21cbe3dd0540dcbf69022431de..84d65e39329659f82099011f9ec60468d5db6328 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.cu +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/decayed_adagrad_op.h" +#include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/decayed_adagrad_op.h b/paddle/fluid/operators/optimizers/decayed_adagrad_op.h similarity index 100% rename from paddle/fluid/operators/decayed_adagrad_op.h rename to paddle/fluid/operators/optimizers/decayed_adagrad_op.h diff --git a/paddle/fluid/operators/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc similarity index 99% rename from paddle/fluid/operators/ftrl_op.cc rename to paddle/fluid/operators/optimizers/ftrl_op.cc index b77e12d6508eb07ae137b313ca91eac951afbcbe..1c9e91d9b610669def6d6d52e4753714745d1c0f 100644 --- a/paddle/fluid/operators/ftrl_op.cc +++ b/paddle/fluid/operators/optimizers/ftrl_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/ftrl_op.h" +#include "paddle/fluid/operators/optimizers/ftrl_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/ftrl_op.cu b/paddle/fluid/operators/optimizers/ftrl_op.cu similarity index 93% rename from paddle/fluid/operators/ftrl_op.cu rename to paddle/fluid/operators/optimizers/ftrl_op.cu index e7371c80da1d1cbb39247b50d8c6537ee8e948f8..f836b75df93861a0fd670f2a0e786e6a797a4661 100644 --- a/paddle/fluid/operators/ftrl_op.cu +++ b/paddle/fluid/operators/optimizers/ftrl_op.cu @@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/ftrl_op.h" +#include "paddle/fluid/operators/optimizers/ftrl_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/ftrl_op.h b/paddle/fluid/operators/optimizers/ftrl_op.h similarity index 100% rename from paddle/fluid/operators/ftrl_op.h rename to paddle/fluid/operators/optimizers/ftrl_op.h diff --git a/paddle/fluid/operators/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc similarity index 96% rename from paddle/fluid/operators/lars_momentum_op.cc rename to paddle/fluid/operators/optimizers/lars_momentum_op.cc index a8dda93902448fa1bd21b719ffd9c9b500caf755..574a03680b66962ac2d6ba249d0fc491a36794cd 100644 --- a/paddle/fluid/operators/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/lars_momentum_op.h" -#include "paddle/fluid/operators/momentum_op.h" +#include "paddle/fluid/operators/optimizers/lars_momentum_op.h" +#include "paddle/fluid/operators/optimizers/momentum_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu similarity index 98% rename from paddle/fluid/operators/lars_momentum_op.cu rename to paddle/fluid/operators/optimizers/lars_momentum_op.cu index eb346851a2f690fa05422c84ddcb08307539048f..a277d6ff2bea917addac8c6ea4b24b63dcbc8dba 100644 --- a/paddle/fluid/operators/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/lars_momentum_op.h" +#include "paddle/fluid/operators/optimizers/lars_momentum_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h similarity index 100% rename from paddle/fluid/operators/lars_momentum_op.h rename to paddle/fluid/operators/optimizers/lars_momentum_op.h diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc similarity index 98% rename from paddle/fluid/operators/momentum_op.cc rename to paddle/fluid/operators/optimizers/momentum_op.cc index 7f0b51580aa2591ac7338ad7c29ee4756d909925..cde238c076b6991eb52dac328c3e30a045420c92 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/optimizers/momentum_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/momentum_op.h" +#include "paddle/fluid/operators/optimizers/momentum_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/momentum_op.cu b/paddle/fluid/operators/optimizers/momentum_op.cu similarity index 93% rename from paddle/fluid/operators/momentum_op.cu rename to paddle/fluid/operators/optimizers/momentum_op.cu index b68fec34d43f0dee834f1045f192d5c6089d9356..8ce739de8dfd74cb43f9521bf39e3127a8a21925 100644 --- a/paddle/fluid/operators/momentum_op.cu +++ b/paddle/fluid/operators/optimizers/momentum_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/momentum_op.h" +#include "paddle/fluid/operators/optimizers/momentum_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h similarity index 100% rename from paddle/fluid/operators/momentum_op.h rename to paddle/fluid/operators/optimizers/momentum_op.h diff --git a/paddle/fluid/operators/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc similarity index 98% rename from paddle/fluid/operators/proximal_adagrad_op.cc rename to paddle/fluid/operators/optimizers/proximal_adagrad_op.cc index 8d8075d76111928ec9855eb0b70fe6dbd90a979b..7b07b3b7071cb39e4e81cb4612372eec96efe489 100644 --- a/paddle/fluid/operators/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/proximal_adagrad_op.h" +#include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/proximal_adagrad_op.cu b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu similarity index 92% rename from paddle/fluid/operators/proximal_adagrad_op.cu rename to paddle/fluid/operators/optimizers/proximal_adagrad_op.cu index 7e0226c62bfd5d4804cc70c00391237deec33ebb..d1c1f747b70c3ceb806da06e6786a70b62a32995 100644 --- a/paddle/fluid/operators/proximal_adagrad_op.cu +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu @@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/proximal_adagrad_op.h" +#include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/proximal_adagrad_op.h b/paddle/fluid/operators/optimizers/proximal_adagrad_op.h similarity index 100% rename from paddle/fluid/operators/proximal_adagrad_op.h rename to paddle/fluid/operators/optimizers/proximal_adagrad_op.h diff --git a/paddle/fluid/operators/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc similarity index 98% rename from paddle/fluid/operators/proximal_gd_op.cc rename to paddle/fluid/operators/optimizers/proximal_gd_op.cc index baf9cbcba2ed89f62afc9816e0ab9e0f112e6008..dcef4f7be249e04306732213a7c6209d32602048 100644 --- a/paddle/fluid/operators/proximal_gd_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/proximal_gd_op.h" +#include "paddle/fluid/operators/optimizers/proximal_gd_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/proximal_gd_op.cu b/paddle/fluid/operators/optimizers/proximal_gd_op.cu similarity index 92% rename from paddle/fluid/operators/proximal_gd_op.cu rename to paddle/fluid/operators/optimizers/proximal_gd_op.cu index 32ee9ab74cd58fd6f48b6c34e108f31315adaf71..7aa0e1015008eba0c1cf63ba1278dc2b8049b20b 100644 --- a/paddle/fluid/operators/proximal_gd_op.cu +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cu @@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/proximal_gd_op.h" +#include "paddle/fluid/operators/optimizers/proximal_gd_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/proximal_gd_op.h b/paddle/fluid/operators/optimizers/proximal_gd_op.h similarity index 100% rename from paddle/fluid/operators/proximal_gd_op.h rename to paddle/fluid/operators/optimizers/proximal_gd_op.h diff --git a/paddle/fluid/operators/rmsprop_op.cc b/paddle/fluid/operators/optimizers/rmsprop_op.cc similarity index 99% rename from paddle/fluid/operators/rmsprop_op.cc rename to paddle/fluid/operators/optimizers/rmsprop_op.cc index f06f87e61d3a4d1fc8b864b9dd84e697fb12a006..99d1156ee6d5fc88161e25bfa581a265707e6f92 100644 --- a/paddle/fluid/operators/rmsprop_op.cc +++ b/paddle/fluid/operators/optimizers/rmsprop_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/rmsprop_op.h" +#include "paddle/fluid/operators/optimizers/rmsprop_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/rmsprop_op.cu b/paddle/fluid/operators/optimizers/rmsprop_op.cu similarity index 92% rename from paddle/fluid/operators/rmsprop_op.cu rename to paddle/fluid/operators/optimizers/rmsprop_op.cu index cdc473769598be5aac87a14613d9acdd5c1a1204..69e35a309e04f61068d9ff1b6d9f1450d2524253 100644 --- a/paddle/fluid/operators/rmsprop_op.cu +++ b/paddle/fluid/operators/optimizers/rmsprop_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/rmsprop_op.h" +#include "paddle/fluid/operators/optimizers/rmsprop_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/rmsprop_op.h b/paddle/fluid/operators/optimizers/rmsprop_op.h similarity index 100% rename from paddle/fluid/operators/rmsprop_op.h rename to paddle/fluid/operators/optimizers/rmsprop_op.h diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc similarity index 98% rename from paddle/fluid/operators/sgd_op.cc rename to paddle/fluid/operators/optimizers/sgd_op.cc index ea62acd08c5009556abf05c91726111870d1a462..690381a67f89d18fe81c3b856b7ddce25d496ed0 100644 --- a/paddle/fluid/operators/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sgd_op.h" +#include "paddle/fluid/operators/optimizers/sgd_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/optimizers/sgd_op.cu similarity index 98% rename from paddle/fluid/operators/sgd_op.cu rename to paddle/fluid/operators/optimizers/sgd_op.cu index d3f4eba3b24ec1ac0328ef270256cdf3abe499db..a9d303d55d8f681fe3a014db36ede5ef6b2742bd 100644 --- a/paddle/fluid/operators/sgd_op.cu +++ b/paddle/fluid/operators/optimizers/sgd_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/sgd_op.h" +#include "paddle/fluid/operators/optimizers/sgd_op.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { diff --git a/paddle/fluid/operators/sgd_op.h b/paddle/fluid/operators/optimizers/sgd_op.h similarity index 100% rename from paddle/fluid/operators/sgd_op.h rename to paddle/fluid/operators/optimizers/sgd_op.h diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 728197377df04df8c993a48bc282431473fe9959..6c919ee1782ebce6d56f7530daa9b748dfb26c47 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -1,3 +1,5 @@ +include(operators) + cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader) set(LOCAL_READER_LIBS) @@ -28,4 +30,10 @@ reader_library(create_py_reader_op SRCS create_py_reader_op.cc) cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) # Export local libraries to parent -set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) +# set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) + +op_library(read_op) + +foreach(src ${LOCAL_READER_LIBS}) + set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs") +endforeach() diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/reader/read_op.cc similarity index 100% rename from paddle/fluid/operators/read_op.cc rename to paddle/fluid/operators/reader/read_op.cc diff --git a/paddle/fluid/operators/reduce_ops/CMakeLists.txt b/paddle/fluid/operators/reduce_ops/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5fe4d15ae2c6254a50318813c852b6c314880aba --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/CMakeLists.txt @@ -0,0 +1,20 @@ +include(operators) +register_operators() + +if(WITH_GPU) + file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.part.cu") + string(REPLACE ".part.cu" "" OPS "${OPS}") + + foreach(src ${OPS}) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu) + set(CUDA_KERNEL_FILE ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu) + file(READ ${CUDA_KERNEL_FILE} TARGET_CONTENT) + string(REGEX MATCH "REGISTER_OP_CUDA_KERNEL\\(\\n?([^,]+),.*" MATCHED ${TARGET_CONTENT}) + if (MATCHED) + string(STRIP ${CMAKE_MATCH_1} MATCHED) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MATCHED}, CUDA);\n") + endif() + + endif() + endforeach() +endif() diff --git a/paddle/fluid/operators/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h similarity index 100% rename from paddle/fluid/operators/cub_reduce.h rename to paddle/fluid/operators/reduce_ops/cub_reduce.h diff --git a/paddle/fluid/operators/reduce_max_op.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc similarity index 96% rename from paddle/fluid/operators/reduce_max_op.cc rename to paddle/fluid/operators/reduce_ops/reduce_max_op.cc index 95d3768e1fdf6947659c7b3a1c9d57fad741472a..cb438b4a8057267015c8b3c15dd8468fca5a4b44 100644 --- a/paddle/fluid/operators/reduce_max_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_min_max_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" REGISTER_REDUCE_OP(reduce_max); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/reduce_max_op.cu b/paddle/fluid/operators/reduce_ops/reduce_max_op.cu similarity index 95% rename from paddle/fluid/operators/reduce_max_op.cu rename to paddle/fluid/operators/reduce_ops/reduce_max_op.cu index b21da178f3eeaafa41bde5f64cc4abcf7944b032..832112ede833a06e053dcff5139e82f054b127c4 100644 --- a/paddle/fluid/operators/reduce_max_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_min_max_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" REGISTER_OP_CUDA_KERNEL(reduce_max, ops::ReduceKernel -#include "paddle/fluid/operators/cub_reduce.h" -#include "paddle/fluid/operators/reduce_mean_op.h" +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_mean_op.h b/paddle/fluid/operators/reduce_ops/reduce_mean_op.h similarity index 95% rename from paddle/fluid/operators/reduce_mean_op.h rename to paddle/fluid/operators/reduce_ops/reduce_mean_op.h index 1359679c4767d2032bf3e3a90849ad2a2ef3e829..240c43bc6d0af266e3500c14f894fe30abab728e 100644 --- a/paddle/fluid/operators/reduce_mean_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/fluid/operators/reduce_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_mean_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu similarity index 95% rename from paddle/fluid/operators/reduce_mean_op.part.cu rename to paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu index 4b663bcdca7c20f8802d962a362f429d8eafe9af..9324ec1e1db6f40e463b415e5d2bdc5cfe664ef4 100644 --- a/paddle/fluid/operators/reduce_mean_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu @@ -13,7 +13,7 @@ // limitations under the License. // .part used to speed up nvcc compile -#include "paddle/fluid/operators/reduce_mean_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h" REGISTER_OP_CUDA_KERNEL( reduce_mean_grad, ops::ReduceGradKernel #include -#include "paddle/fluid/operators/reduce_op_function.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_op_function.h b/paddle/fluid/operators/reduce_ops/reduce_op_function.h similarity index 100% rename from paddle/fluid/operators/reduce_op_function.h rename to paddle/fluid/operators/reduce_ops/reduce_op_function.h diff --git a/paddle/fluid/operators/reduce_prod_op.cc b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc similarity index 96% rename from paddle/fluid/operators/reduce_prod_op.cc rename to paddle/fluid/operators/reduce_ops/reduce_prod_op.cc index 713728b99757a6f3bb128f665d5576ac64eef8ec..88935107df187da731e5b77bb6c24cd692d2994f 100644 --- a/paddle/fluid/operators/reduce_prod_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_prod_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h" REGISTER_REDUCE_OP(reduce_prod); REGISTER_OP_CPU_KERNEL(reduce_prod, diff --git a/paddle/fluid/operators/reduce_prod_op.cu b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu similarity index 95% rename from paddle/fluid/operators/reduce_prod_op.cu rename to paddle/fluid/operators/reduce_ops/reduce_prod_op.cu index d8692afb96e4d5d3206210060684dd12fb4d79a7..4434937f75397d8d5340a94abbd41efa7e7a8d4b 100644 --- a/paddle/fluid/operators/reduce_prod_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_prod_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h" REGISTER_OP_CUDA_KERNEL(reduce_prod, ops::ReduceKernel -#include "paddle/fluid/operators/reduce_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu similarity index 90% rename from paddle/fluid/operators/reduce_sum_op.part.cu rename to paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 525633f62a95b2d0d677fcbebe551b75cb2a180d..eb3295731b047391a244bfb598c9d802bca1fc0c 100644 --- a/paddle/fluid/operators/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/cub_reduce.h" -#include "paddle/fluid/operators/reduce_sum_op.h" +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" REGISTER_OP_CUDA_KERNEL( reduce_sum_grad, ops::ReduceGradKernel namespace paddle { diff --git a/paddle/fluid/operators/sequence_concat_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc similarity index 94% rename from paddle/fluid/operators/sequence_concat_op.cu.cc rename to paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc index eb6535235df80a9267b22403ae1f35c6cefb7fe7..7b8043bc4538b486bb73e005769e1585e5c4817e 100644 --- a/paddle/fluid/operators/sequence_concat_op.cu.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cu.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_concat_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h" template using Kernel = diff --git a/paddle/fluid/operators/sequence_concat_op.h b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h similarity index 100% rename from paddle/fluid/operators/sequence_concat_op.h rename to paddle/fluid/operators/sequence_ops/sequence_concat_op.h diff --git a/paddle/fluid/operators/sequence_conv_op.cc b/paddle/fluid/operators/sequence_ops/sequence_conv_op.cc similarity index 99% rename from paddle/fluid/operators/sequence_conv_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_conv_op.cc index 95a21a5d3ee6d8037431083edc25d1cddf05dedb..65cd9edbc7125f605d6fb437a2e056054eb9a6d7 100644 --- a/paddle/fluid/operators/sequence_conv_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_conv_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_conv_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h" #include diff --git a/paddle/fluid/operators/sequence_conv_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_conv_op.cu.cc similarity index 93% rename from paddle/fluid/operators/sequence_conv_op.cu.cc rename to paddle/fluid/operators/sequence_ops/sequence_conv_op.cu.cc index de482b7f10bafc4ac6f3838670e2da9a86374c26..600981b5e96c279329a67b608a8dd94dee7d88ef 100644 --- a/paddle/fluid/operators/sequence_conv_op.cu.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_conv_op.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_conv_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_conv_op.h b/paddle/fluid/operators/sequence_ops/sequence_conv_op.h similarity index 100% rename from paddle/fluid/operators/sequence_conv_op.h rename to paddle/fluid/operators/sequence_ops/sequence_conv_op.h diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc similarity index 97% rename from paddle/fluid/operators/sequence_enumerate_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc index 58e48c228bb34814700fd0f7a3d62ef4b1a435dd..1eebadc2c980ddf1cbaaefef1568dd401d0c77ed 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_enumerate_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu similarity index 97% rename from paddle/fluid/operators/sequence_enumerate_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu index bdc9a615aa9a1ecd99c1f6995361f8c5ff0aa383..28821e7129c1601f1214b0b56696fbf526a2123f 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu @@ -14,7 +14,7 @@ #include #include -#include "paddle/fluid/operators/sequence_enumerate_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { diff --git a/paddle/fluid/operators/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h similarity index 100% rename from paddle/fluid/operators/sequence_enumerate_op.h rename to paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h diff --git a/paddle/fluid/operators/sequence_erase_op.cc b/paddle/fluid/operators/sequence_ops/sequence_erase_op.cc similarity index 97% rename from paddle/fluid/operators/sequence_erase_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_erase_op.cc index 816ba123a6cbf84ec9b321d5d7cfef7fab9749b1..ddda80ee0824e261b0d737f86e03866d5fdfd77a 100644 --- a/paddle/fluid/operators/sequence_erase_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_erase_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_erase_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_erase_op.h" #include namespace paddle { diff --git a/paddle/fluid/operators/sequence_erase_op.cu b/paddle/fluid/operators/sequence_ops/sequence_erase_op.cu similarity index 98% rename from paddle/fluid/operators/sequence_erase_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_erase_op.cu index 3a58e47f1132cd1ac85584b2470e8c6cddcfb28a..619c40dbd10ad6b538f2d4e3567966b222fc5e2d 100644 --- a/paddle/fluid/operators/sequence_erase_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_erase_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/sequence_erase_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_erase_op.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { diff --git a/paddle/fluid/operators/sequence_erase_op.h b/paddle/fluid/operators/sequence_ops/sequence_erase_op.h similarity index 100% rename from paddle/fluid/operators/sequence_erase_op.h rename to paddle/fluid/operators/sequence_ops/sequence_erase_op.h diff --git a/paddle/fluid/operators/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_expand_as_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc index 33c1e1c973c80ba3943924331380d35b225ac800..3b79d0c71975bb740b4085ce80f7d95b65f600c1 100644 --- a/paddle/fluid/operators/sequence_expand_as_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_expand_as_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_expand_as_op.cu b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu similarity index 98% rename from paddle/fluid/operators/sequence_expand_as_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu index 7357f5ae6e732f28307af65d1f1b6b3cbed1f640..998bf82ab1ddcd815491de95a3f7cf987036ee65 100644 --- a/paddle/fluid/operators/sequence_expand_as_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/sequence_expand_as_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { diff --git a/paddle/fluid/operators/sequence_expand_as_op.h b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h similarity index 100% rename from paddle/fluid/operators/sequence_expand_as_op.h rename to paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc similarity index 99% rename from paddle/fluid/operators/sequence_expand_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_expand_op.cc index 944c7f85e5f43679e1875fcce813382be2ba5526..c07e6962e673ceb274ef31cbf492f378ae696137 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_expand_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cu similarity index 98% rename from paddle/fluid/operators/sequence_expand_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_expand_op.cu index 550677b22694085059e914678a5361d914b455bc..afc08c7b3f6596efd3b6e0b74c17aa3c9268c47d 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/sequence_expand_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_ops/sequence_expand_op.h similarity index 100% rename from paddle/fluid/operators/sequence_expand_op.h rename to paddle/fluid/operators/sequence_ops/sequence_expand_op.h diff --git a/paddle/fluid/operators/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc similarity index 95% rename from paddle/fluid/operators/sequence_mask_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_mask_op.cc index 798211f481659eb71248f7a6210e6522273d387f..7fc506aab4d3c6861282b68b09fdcb5fd8055f77 100644 --- a/paddle/fluid/operators/sequence_mask_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_mask_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h" REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp, paddle::operators::SequenceMaskOpMaker, diff --git a/paddle/fluid/operators/sequence_mask_op.cu b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cu similarity index 94% rename from paddle/fluid/operators/sequence_mask_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_mask_op.cu index 2ad23774579533b62b9189c1564ad7c7db5c298a..e963ce610e2c147d66087a1df59f67a04d899ccc 100644 --- a/paddle/fluid/operators/sequence_mask_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_mask_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h" REGISTER_OP_CUDA_KERNEL( sequence_mask, diff --git a/paddle/fluid/operators/sequence_mask_op.h b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h similarity index 100% rename from paddle/fluid/operators/sequence_mask_op.h rename to paddle/fluid/operators/sequence_ops/sequence_mask_op.h diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc similarity index 99% rename from paddle/fluid/operators/sequence_pad_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_pad_op.cc index 4583b26256ba2e084bf7477c54d468df860d9b43..23c7bf7cea830bb0ccf5e81f99130043c2d5f80b 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_pad_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_pad_op.cu b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cu similarity index 95% rename from paddle/fluid/operators/sequence_pad_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_pad_op.cu index ff8f81a2f0ec4a72befc3be2a5fc48c3a586c824..7fc64a530ef5442ae927faac96ad92a4126febcd 100644 --- a/paddle/fluid/operators/sequence_pad_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_pad_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_ops/sequence_pad_op.h similarity index 100% rename from paddle/fluid/operators/sequence_pad_op.h rename to paddle/fluid/operators/sequence_ops/sequence_pad_op.h diff --git a/paddle/fluid/operators/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_pool_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 7e80b8db5e90730e2df420466a33362620e15730..44b09bf7c2c776cdc455a8706cb2b2251f3be509 100644 --- a/paddle/fluid/operators/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_pool_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" #include namespace paddle { diff --git a/paddle/fluid/operators/sequence_pool_op.cu b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu similarity index 93% rename from paddle/fluid/operators/sequence_pool_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_pool_op.cu index 2bf0697af3c74ee922a832fecaa2cd2399a06849..63cd47a38a0ff6413c430c6be6284c5f4bfc2595 100644 --- a/paddle/fluid/operators/sequence_pool_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/fluid/operators/sequence_pool_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h similarity index 100% rename from paddle/fluid/operators/sequence_pool_op.h rename to paddle/fluid/operators/sequence_ops/sequence_pool_op.h diff --git a/paddle/fluid/operators/sequence_reshape_op.cc b/paddle/fluid/operators/sequence_ops/sequence_reshape_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_reshape_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_reshape_op.cc index 31d28d723498892f287246ba228df757d5b9f6c8..5421f35662b3b0a6a61748ac0b6b5f718d213b73 100644 --- a/paddle/fluid/operators/sequence_reshape_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_reshape_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_reshape_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_reshape_op.h" #include "paddle/fluid/framework/ddim.h" namespace paddle { diff --git a/paddle/fluid/operators/sequence_reshape_op.cu b/paddle/fluid/operators/sequence_ops/sequence_reshape_op.cu similarity index 95% rename from paddle/fluid/operators/sequence_reshape_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_reshape_op.cu index 232e031c0b022497d9e5141750dbf8fccffc7615..38bc599165d5f84f67e2fe08bf96ebef4b03d8a4 100644 --- a/paddle/fluid/operators/sequence_reshape_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_reshape_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_reshape_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_reshape_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_reshape_op.h b/paddle/fluid/operators/sequence_ops/sequence_reshape_op.h similarity index 100% rename from paddle/fluid/operators/sequence_reshape_op.h rename to paddle/fluid/operators/sequence_ops/sequence_reshape_op.h diff --git a/paddle/fluid/operators/sequence_reverse_op.cc b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.cc similarity index 94% rename from paddle/fluid/operators/sequence_reverse_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_reverse_op.cc index 1428cca1a6bf6150594f9cb72dbf00cd0eff7df5..dfbbf5f156983189ac1ab82fbff51d7eb4844f9a 100644 --- a/paddle/fluid/operators/sequence_reverse_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_reverse_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_reverse_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/sequence_reverse_op.cu b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.cu similarity index 94% rename from paddle/fluid/operators/sequence_reverse_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_reverse_op.cu index ce65f4799e8661adca60d212eaa9c3f0f92c4c29..0a59ed7f9fee07bc3b12909973535f31ef049a4a 100644 --- a/paddle/fluid/operators/sequence_reverse_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/sequence_reverse_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_reverse_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/sequence_reverse_op.h b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h similarity index 100% rename from paddle/fluid/operators/sequence_reverse_op.h rename to paddle/fluid/operators/sequence_ops/sequence_reverse_op.h diff --git a/paddle/fluid/operators/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_scatter_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc index adb81bffccb50069b3a2e5f391f3fdfde231b2be..c49d1ccb18427a1ec3c45f326b57bce32c60e1e2 100644 --- a/paddle/fluid/operators/sequence_scatter_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_scatter_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_scatter_op.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/gather.h" diff --git a/paddle/fluid/operators/sequence_scatter_op.h b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.h similarity index 100% rename from paddle/fluid/operators/sequence_scatter_op.h rename to paddle/fluid/operators/sequence_ops/sequence_scatter_op.h diff --git a/paddle/fluid/operators/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_slice_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_slice_op.cc index df9243dc04c584d70dfa6ca78d5fac8423796466..6f84023e26dbf1280d9622946ab20184fb835be1 100644 --- a/paddle/fluid/operators/sequence_slice_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_slice_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_slice_op.cu b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cu similarity index 92% rename from paddle/fluid/operators/sequence_slice_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_slice_op.cu index 059e802df0ebdba68f758decfb8b54a362996335..1e4a1b8323dbaacdf3f74c33e7aa4484d9be2478 100644 --- a/paddle/fluid/operators/sequence_slice_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_slice_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_slice_op.h b/paddle/fluid/operators/sequence_ops/sequence_slice_op.h similarity index 100% rename from paddle/fluid/operators/sequence_slice_op.h rename to paddle/fluid/operators/sequence_ops/sequence_slice_op.h diff --git a/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc similarity index 100% rename from paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc rename to paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc diff --git a/paddle/fluid/operators/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_softmax_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index ada3e0c8dbba38729c2b9c8b02335327835f2ef4..644a5bebc18886a2ac9210576f1c2251ad5ad0be 100644 --- a/paddle/fluid/operators/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_softmax_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h" #include namespace paddle { diff --git a/paddle/fluid/operators/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu similarity index 98% rename from paddle/fluid/operators/sequence_softmax_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu index e94ceaa170131e8bce7d1574b27f0baeaa8d1ffc..cc5e9821903fb7a726f52177df1d17757f697411 100644 --- a/paddle/fluid/operators/sequence_softmax_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include // NOLINT -#include "paddle/fluid/operators/sequence_softmax_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_softmax_op.h b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.h similarity index 100% rename from paddle/fluid/operators/sequence_softmax_op.h rename to paddle/fluid/operators/sequence_ops/sequence_softmax_op.h diff --git a/paddle/fluid/operators/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc similarity index 98% rename from paddle/fluid/operators/sequence_unpad_op.cc rename to paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc index e633e378a226ece8adea2e150cc6c1e9aa874331..2cf508e0b707ecc986886e72e5d42fde3c84894d 100644 --- a/paddle/fluid/operators/sequence_unpad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_unpad_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_unpad_op.cu b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cu similarity index 95% rename from paddle/fluid/operators/sequence_unpad_op.cu rename to paddle/fluid/operators/sequence_ops/sequence_unpad_op.cu index 75248372237ec2cb23122f6b16e64f6ce750ebf9..bf54f77f5b55cf7eb19873e352359c028207308a 100644 --- a/paddle/fluid/operators/sequence_unpad_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_unpad_op.h" +#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/sequence_unpad_op.h b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h similarity index 100% rename from paddle/fluid/operators/sequence_unpad_op.h rename to paddle/fluid/operators/sequence_ops/sequence_unpad_op.h diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index f109dd685c87ab1b0776a855bb5f510eab1f5526..c047bc78ee315201d25a7294b7dae7d766a6c968 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -86,7 +86,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { .GreaterThan(1); AddComment(R"DOC( reorg operator used in Yolo v2. - The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize, + The equation is: C2 = C1/blocksize * blocksize, W2 = W1 * blocksize + offset % blocksize, H2 = H1 * blocksize + offset / blocksize, Reshape Input(X) into the shape according to Attr(blocksize). The data in Input(X) are unchanged. diff --git a/paddle/fluid/operators/tensorrt/CMakeLists.txt b/paddle/fluid/operators/tensorrt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..eee0b90fbae216e804e62993313796e914fcef5a --- /dev/null +++ b/paddle/fluid/operators/tensorrt/CMakeLists.txt @@ -0,0 +1,5 @@ +op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) +file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n") +nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc + DEPS tensorrt_engine_op + analysis) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc similarity index 96% rename from paddle/fluid/operators/tensorrt_engine_op.cc rename to paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc index 41a5786fe8c3295390144732221280e152d0a15a..3cf2ce3c7ef87dcf75548f7d9c3a55d06ed765e8 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc @@ -17,7 +17,7 @@ #include #include -#include "paddle/fluid/operators/tensorrt_engine_op.h" +#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h" namespace paddle { diff --git a/paddle/fluid/operators/tensorrt_engine_op.cu.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc similarity index 93% rename from paddle/fluid/operators/tensorrt_engine_op.cu.cc rename to paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc index e1ddfde6d51ef719ca0b89cf286b176195ee682a..cbe1b426f65386e722a7b02ec1fdfdf75bfd770c 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cu.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/tensorrt_engine_op.h" +#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h similarity index 100% rename from paddle/fluid/operators/tensorrt_engine_op.h rename to paddle/fluid/operators/tensorrt/tensorrt_engine_op.h diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc similarity index 99% rename from paddle/fluid/operators/tensorrt_engine_op_test.cc rename to paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index e21101e8d12f210af08284dbcebe5c14c1af6dd3..56bdd6c2f2801967829f2baf889b5517a1d9d8d9 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/tensorrt_engine_op.h" +#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h" #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/lod_tensor.h" diff --git a/paddle/fluid/operators/warpctc_cudnn_op.cu.cc b/paddle/fluid/operators/warpctc_cudnn_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..a764d59410c90535dbda0b3f11e89ae9bf578c04 --- /dev/null +++ b/paddle/fluid/operators/warpctc_cudnn_op.cu.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/operators/warpctc_op.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +#if CUDNN_VERSION >= 7001 +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedCTCLossDescriptor = platform::ScopedCTCLossDescriptor; +using DataLayout = platform::DataLayout; + +template +class CudnnCTCKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // =====================Copied code from warpctc=========================== + auto* logits = ctx.Input("Logits"); + auto* label = ctx.Input("Label"); + auto* warpctc_grad = ctx.Output("WarpCTCGrad"); + auto* loss = ctx.Output("Loss"); + + const size_t level = 0; + + auto logits_lod = framework::ToAbsOffset(logits->lod()); + auto logits_dims = logits->dims(); + PADDLE_ENFORCE_EQ(logits_dims[0], + static_cast(logits_lod[level].back()), + "The first dimension of Input(Logits) should be equal to " + "the sum of all sequences' lengths."); + + auto label_lod = framework::ToAbsOffset(label->lod()); + auto label_dims = label->dims(); + PADDLE_ENFORCE_EQ( + label_dims[0], label->numel(), + "The width of each timestep in Input(Label) should be 1."); + + const size_t num_sequences = logits_lod[level].size() - 1; + PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1, + "The number of sequences of Input(Logits) should be " + "equal to that of Input(Label)."); + PADDLE_ENFORCE_LE(num_sequences, 256, + "The labelLengths must less than 256 for cudnn call."); + + const size_t sequence_width = logits->numel() / logits_dims[0]; + auto loss_dims = + framework::make_ddim({static_cast(num_sequences), 1}); + + // NOTE: cudnn takes softmax input, calculate softmax first, then do padding + auto& dev_ctx = ctx.template device_context(); + LoDTensor softmax_logits; + softmax_logits.mutable_data(logits->dims(), ctx.GetPlace()); + softmax_logits.set_lod(logits_lod); + int rank = logits->dims().size(); + Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1); + Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1); + math::SoftmaxFunctor()(dev_ctx, &in_2d, &out_2d); + + // ctc needs sequences data stored in transposed padding format + // logits and grad using padding data of layout 'TNC' + // T: max_sequence_length + // N: batch_size (num_sequences) + // C: width + LoDTensor warpctc_logits; + const size_t max_sequence_length = + math::MaximumSequenceLength(logits_lod[level]); + auto warpctc_logits_dims = + framework::make_ddim({static_cast(max_sequence_length), + static_cast(num_sequences), + static_cast(sequence_width)}); + warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); + + LoDTensor cpu_pad_value; + T* pad_value_data = + cpu_pad_value.mutable_data({1}, platform::CPUPlace()); + *pad_value_data = static_cast(0); + LoDTensor pad_value; + if (platform::is_cpu_place(ctx.GetPlace())) { + pad_value = cpu_pad_value; + } else { + TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value); + } + + math::PaddingLoDTensorFunctor()( + ctx.template device_context(), softmax_logits, + &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */, + math::kLengthBatchWidth); + const T* warpctc_logits_data = warpctc_logits.data(); + + std::vector warpctc_label_lengths(num_sequences); + std::vector warpctc_logits_lengths(num_sequences); + + for (size_t i = 0; i < num_sequences; ++i) { + warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i]; + warpctc_logits_lengths[i] = + logits_lod[level][i + 1] - logits_lod[level][i]; + } + + T* warpctc_grad_data = + warpctc_grad->mutable_data(warpctc_logits.dims(), ctx.GetPlace()); + + math::SetConstant()( + ctx.template device_context(), warpctc_grad, + static_cast(0)); + + Tensor warpctc_label; + TensorCopySync(*label, platform::CPUPlace(), &warpctc_label); + const int* warpctc_label_data = warpctc_label.data(); + // ======================================================================== + + ScopedTensorDescriptor logits_desc; + ScopedTensorDescriptor grad_desc; + ScopedCTCLossDescriptor ctcloss_desc; + // layout here doesn't have effect. + DataLayout layout = DataLayout::kNCHW; + + auto cu_logits_desc = logits_desc.descriptor( + layout, framework::vectorize2int(warpctc_logits.dims())); + auto cu_grad_desc = grad_desc.descriptor( + layout, framework::vectorize2int(warpctc_grad->dims())); + auto cu_ctcloss_desc = ctcloss_desc.descriptor(); + + auto handle = dev_ctx.cudnn_handle(); + size_t workspace_size; + + CUDNN_ENFORCE(platform::dynload::cudnnGetCTCLossWorkspaceSize( + handle, cu_logits_desc, cu_grad_desc, warpctc_label_data, + warpctc_label_lengths.data(), warpctc_logits_lengths.data(), + CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size)); + + T* loss_data = loss->mutable_data(loss_dims, ctx.GetPlace()); + + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss( + handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data, + warpctc_label_lengths.data(), warpctc_logits_lengths.data(), + loss_data, cu_grad_desc, warpctc_grad_data, + CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace, + workspace_size)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size); + } +}; + +template +class CudnnCTCGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* warpctc_grad = ctx.Input("WarpCTCGrad"); + auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); + const Tensor* loss_grad = ctx.Input(framework::GradVarName("Loss")); + + logits_grad->mutable_data(ctx.GetPlace()); + bool norm_by_times = ctx.Attr("norm_by_times"); + math::UnpaddingLoDTensorFunctor()( + ctx.template device_context(), *warpctc_grad, + logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth); + + const T* loss_grad_data = loss_grad->data(); + math::ScaleLoDTensorFunctor()( + ctx.template device_context(), loss_grad_data, + logits_grad); + } +}; + +#endif +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +#if CUDNN_VERSION >= 7001 +REGISTER_OP_KERNEL( + warpctc, CUDNN, plat::CUDAPlace, + ops::CudnnCTCKernel); +REGISTER_OP_KERNEL( + warpctc_grad, CUDNN, plat::CUDAPlace, + ops::CudnnCTCGradKernel); +#endif diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index e06c8c962f45a4e91b7efed7431571f0fc6870a3..6a257cebf523bfeb1951b709480140e733126f6a 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -14,6 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/warpctc_op.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + namespace paddle { namespace operators { @@ -45,9 +49,16 @@ class WarpCTCOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_CUDA + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; + } +#endif + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType( framework::ToDataType(ctx.Input("Logits")->type()), - ctx.device_context()); + ctx.device_context(), layout_, library_); } }; @@ -86,6 +97,10 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker { "normalize the gradients by the number of time-step, " "which is also the sequence's length.") .SetDefault(false); + AddAttr("use_cudnn", + "(bool, default: false), whether to " + "use cudnn kernel.") + .SetDefault(false); AddComment(R"DOC( An operator integrating the open-source [warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 07bb02be1962f758e50cab1f27de43e89f3953c3..f174a7bc4867634df01367cc0091b0ec55849985 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -380,5 +380,28 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { return use_cudnn; } +#if CUDNN_VERSION >= 7001 +class ScopedCTCLossDescriptor { + public: + ScopedCTCLossDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreateCTCLossDescriptor(&desc_)); + } + ~ScopedCTCLossDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroyCTCLossDescriptor(desc_)); + } + + template + inline cudnnCTCLossDescriptor_t descriptor() { + PADDLE_ENFORCE( + dynload::cudnnSetCTCLossDescriptor(desc_, CudnnDataType::type)); + return desc_; + } + + private: + cudnnCTCLossDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor); +}; +#endif + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index c26143d2f2780f3042f66b99808c6b85866f9dc4..db2e28bc911a05a077c786c57c4f0e4c34bd61f7 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -154,7 +154,13 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #if CUDNN_VERSION >= 7001 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ __macro(cudnnSetConvolutionGroupCount); \ - __macro(cudnnSetConvolutionMathType); + __macro(cudnnSetConvolutionMathType); \ + __macro(cudnnCreateCTCLossDescriptor); \ + __macro(cudnnDestroyCTCLossDescriptor); \ + __macro(cudnnGetCTCLossDescriptor); \ + __macro(cudnnSetCTCLossDescriptor); \ + __macro(cudnnGetCTCLossWorkspaceSize); \ + __macro(cudnnCTCLoss); CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index f97a6d0d989a8debeac43287a7e118ed24a4ea05..e486930be3fc87c1cfb4723ea358be4b58e89f8a 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -11,12 +11,12 @@ if(WITH_PYTHON) hip_library(paddle_pybind SHARED SRCS ${PYBIND_SRCS} DEPS ${PYBIND_DEPS} - ${GLOB_OP_LIB}) + ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) else() cc_library(paddle_pybind SHARED SRCS ${PYBIND_SRCS} DEPS ${PYBIND_DEPS} - ${GLOB_OP_LIB}) + ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) if(NOT APPLE AND NOT ANDROID AND NOT WIN32) target_link_libraries(paddle_pybind rt) endif(NOT APPLE AND NOT ANDROID AND NOT WIN32) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f60f3731636d805be06319e3f082cbed35be4940..af96f5de4f055ea49fee03612cfc1a1b3c6af2f8 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4187,7 +4187,7 @@ def ctc_greedy_decoder(input, blank, name=None): return ctc_out -def warpctc(input, label, blank=0, norm_by_times=False): +def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False): """ An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc) @@ -4212,6 +4212,7 @@ def warpctc(input, label, blank=0, norm_by_times=False): by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if warpctc layer was follewed by a mean_op. + use_cudnn (bool, default false): Whether to use cudnn. Returns: Variable: The Connectionist Temporal Classification (CTC) loss, @@ -4235,8 +4236,11 @@ def warpctc(input, label, blank=0, norm_by_times=False): 'Label': [label]}, outputs={'WarpCTCGrad': [grad_out], 'Loss': [loss_out]}, - attrs={'blank': blank, - 'norm_by_times': norm_by_times}) + attrs={ + 'blank': blank, + 'norm_by_times': norm_by_times, + 'use_cudnn': use_cudnn + }) return loss_out @@ -4309,7 +4313,10 @@ def nce(input, param_attr=None, bias_attr=None, num_neg_samples=None, - name=None): + name=None, + sampler="uniform", + custom_dist=None, + seed=0): """ ${comment} @@ -4332,6 +4339,14 @@ def nce(input, num_neg_samples (int): ${num_neg_samples_comment} name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. + sampler (str): The sampler used to sample class from negtive classes. + It can be 'uniform', 'log_uniform' or 'custom_dist'. + default: 'uniform'. + custom_dist (Variable): A tensor with shape [num_total_classes]. + It is used when sampler is set to 'custom_dist'. + custom_dist[i] is the probsbility of i-th class to be sampled. + default: None. + seed (int): The seed used in sampler. default: 0. Returns: Variable: The output nce loss. @@ -4361,6 +4376,16 @@ def nce(input, loss = layers.nce(input=embs, label=words[label_word], num_total_classes=dict_size, param_attr='nce.w', bias_attr='nce.b') + + #or use custom distribution + dist = fluid.layers.assign(input=np.array([0.05,0.5,0.1,0.3,0.05]).astype("float32")) + loss = layers.nce(input=embs, label=words[label_word], + num_total_classes=5, param_attr='nce.w', + bias_attr='nce.b', + num_neg_samples=3, + sampler="custom_dist", + custom_dist=dist) + """ helper = LayerHelper('nce', **locals()) assert isinstance(input, Variable) @@ -4395,9 +4420,31 @@ def nce(input, else: num_neg_samples = int(num_neg_samples) + inputs = { + 'Input': input, + 'Label': label, + 'Weight': w, + 'Bias': b, + 'SampleWeight': sample_weight if sample_weight is not None else [] + } + + if sampler == "uniform": + sampler = 0 + elif sampler == "log_uniform": + sampler = 1 + elif sampler == "custom_dist": + assert custom_dist is not None + assert isinstance(custom_dist, Variable) + inputs['CustomDistribution'] = custom_dist + sampler = 2 + else: + raise Exception("Unsupported sampler type.") + attrs = { 'num_total_classes': int(num_total_classes), - 'num_neg_samples': num_neg_samples + 'num_neg_samples': num_neg_samples, + 'seed': seed, + 'sampler': sampler } helper.append_op( diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py index 42ab9b231153f7ede7b8f8dd4e754f8cc92f65fe..3d40b762281ae09d3214f2d2bc496c4966984866 100644 --- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py @@ -38,7 +38,7 @@ depth = 8 mix_hidden_lr = 1e-3 IS_SPARSE = True -PASS_NUM = 1 +PASS_NUM = 2 BATCH_SIZE = 10 embedding_name = 'emb' @@ -196,7 +196,7 @@ def train(use_cuda, save_dirname=None, is_local=True): print("second per batch: " + str((time.time( ) - start_time) / batch_id)) # Set the threshold low to speed up the CI test - if float(cost) < 60.0: + if float(cost) < 80.0: if save_dirname is not None: # TODO(liuyiqun): Change the target to crf_decode fluid.io.save_inference_model(save_dirname, [ @@ -208,6 +208,10 @@ def train(use_cuda, save_dirname=None, is_local=True): batch_id = batch_id + 1 + raise RuntimeError( + "This model should save_inference_model and return, but not reach here, please check!" + ) + if is_local: train_loop(fluid.default_main_program()) else: diff --git a/python/paddle/fluid/tests/unittests/test_infer_shape.py b/python/paddle/fluid/tests/unittests/test_infer_shape.py index fdff22cacc28731a91ff4fd17407bd9edbdd9d8b..9d5e064e6adabe09094350db2976f83d835520eb 100644 --- a/python/paddle/fluid/tests/unittests/test_infer_shape.py +++ b/python/paddle/fluid/tests/unittests/test_infer_shape.py @@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase): mul_op_desc.infer_shape(block) self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) + def test_expand_op(self): + prog = core.ProgramDesc() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + shape = [-1, 20] + expand_times = [3, 1] + + # prepare input/output + x1 = block.var(six.b("x")) + x1.set_type(core.VarDesc.VarType.LOD_TENSOR) + x1.set_shape(shape) + + out = block.var(six.b("out")) + out.set_type(core.VarDesc.VarType.LOD_TENSOR) + + # prepare the operator + sum_op_desc = block.append_op() + sum_op_desc.set_type("expand") + sum_op_desc.set_input("X", ["x"]) + sum_op_desc.set_output("Out", ["out"]) + sum_op_desc._set_attr('expand_times', expand_times) + + sum_op_desc.check_attrs() + sum_op_desc.infer_shape(block) + self.assertEqual(out.shape(), shape) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py index 0745bd274f73715b6fdec236819b8d89827e1346..c01fdd5dddc139bdefc07b91e9816d62febd7f20 100644 --- a/python/paddle/fluid/tests/unittests/test_nce.py +++ b/python/paddle/fluid/tests/unittests/test_nce.py @@ -68,7 +68,9 @@ class TestNCE(OpTest): self.attrs = { 'num_total_classes': num_classes, 'num_neg_samples': num_neg_samples, - 'custom_neg_classes': list(range(num_neg_samples)) + 'custom_neg_classes': list(range(num_neg_samples)), + 'seed': 0, + 'sampler': 0 } self.inputs = { 'Input': input, diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index 5e3aa13546d0c4fdcde4a3d6378d5a1748327814..ec0592baa22b6215035d2b9ad80e00081eb59126 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -183,6 +183,7 @@ class TestWarpCTCOp(OpTest): self.labels_lod = [[3, 1, 4, 4]] self.blank = self.num_classes - 1 self.norm_by_times = False + self.use_cudnn = False def setUp(self): self.op_type = "warpctc" @@ -215,7 +216,11 @@ class TestWarpCTCOp(OpTest): "Label": (labels, self.labels_lod) } self.outputs = {"Loss": loss} - self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times} + self.attrs = { + "blank": self.blank, + "norm_by_times": self.norm_by_times, + "use_cudnn": self.use_cudnn + } def test_check_output(self): self.check_output() @@ -233,6 +238,22 @@ class TestWarpCTCOpCase1(TestWarpCTCOp): self.labels_lod = [[3, 1, 4, 4]] self.blank = 0 self.norm_by_times = False + self.use_cudnn = False + + +class TestCudnnCTCOp(TestWarpCTCOp): + def config(self): + self.batch_size = 4 + self.num_classes = 8 + self.logits_lod = [[4, 1, 3, 3]] + self.labels_lod = [[3, 1, 4, 4]] + self.blank = 0 + self.norm_by_times = False + self.use_cudnn = True + + def test_check_grad(self): + self.outputs['WarpCTCGrad'] = self.gradient + self.check_grad(["Logits"], "Loss", max_relative_error=0.01) if __name__ == "__main__": diff --git a/python/paddle/fluid/transpiler/details/checkport.py b/python/paddle/fluid/transpiler/details/checkport.py index 7bad4b427a2d53bd14c7a1f870ce74a883158d04..6b78ceeaeec4d9b3db6524a5b5e939f88267340c 100644 --- a/python/paddle/fluid/transpiler/details/checkport.py +++ b/python/paddle/fluid/transpiler/details/checkport.py @@ -34,6 +34,7 @@ def wait_server_ready(endpoints): """ while True: all_ok = True + not_ready_endpoints = [] for ep in endpoints: ip_port = ep.split(":") with closing(socket.socket(socket.AF_INET, @@ -42,8 +43,11 @@ def wait_server_ready(endpoints): result = sock.connect_ex((ip_port[0], int(ip_port[1]))) if result != 0: all_ok = False + not_ready_endpoints.append(ep) if not all_ok: sys.stderr.write("pserver not ready, wait 3 sec to retry...\n") + sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) + + "\n") sys.stderr.flush() time.sleep(3) else: