提交 a5cb79c6 编写于 作者: B barrierye

commit data_feed for pull

set(PART_CUDA_KERNEL_FILES)
function(op_library TARGET)
# op_library is a function to create op library. The interface is same as
# cc_library. But it handle split GPU/CPU code and link some common library
# for ops.
set(cc_srcs)
set(cu_srcs)
set(hip_cu_srcs)
set(miopen_hip_cc_srcs)
set(cu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(CUDNN_FILE)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
set(op_common_deps operator op_registry math_function)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(pybind_flag 0)
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
list(LENGTH op_library_SRCS op_library_SRCS_len)
if (${op_library_SRCS_len} EQUAL 0)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
list(APPEND cc_srcs ${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
list(APPEND hip_cu_srcs ${TARGET}.hip.cu)
endif()
string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
endif()
if(WITH_AMD_GPU)
string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc)
list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc)
endif()
endif()
if(WITH_MKLDNN)
string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc)
list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc)
endif()
endif()
else()
foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.hip.cu$")
list(APPEND hip_cu_srcs ${src})
elseif (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
list(APPEND cudnn_cu_cc_srcs ${src})
elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$")
list(APPEND miopen_hip_cc_srcs ${src})
elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$")
list(APPEND mkldnn_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
else()
message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
endif()
endforeach()
endif()
list(LENGTH cc_srcs cc_srcs_len)
if (${cc_srcs_len} EQUAL 0)
message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
endif()
if (WIN32)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
"fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
endforeach()
endif(WIN32)
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
list(LENGTH op_library_DEPS op_library_DEPS_len)
if (${op_library_DEPS_len} GREATER 0)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
endif()
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
elseif (WITH_AMD_GPU)
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
endif()
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
endforeach()
# The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
# And for detail pybind information, please see generated paddle/pybind/pybind.h.
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}")
if (one_register STREQUAL "")
string(REPLACE "_op" "" TARGET "${TARGET}")
else ()
string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
string(REPLACE "," "" TARGET "${TARGET}")
endif()
# pybind USE_NO_KERNEL_OP
# HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
string(REPLACE "_op" "" TARGET "${TARGET}")
if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
# pybind USE_CPU_ONLY_OP
list(LENGTH cu_srcs cu_srcs_len)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
list(LENGTH hip_cu_srcs hip_cu_srcs_len)
list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MIOPEN
if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
endif()
# pybind USE_OP
if (${pybind_flag} EQUAL 0)
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "fake_quantize")
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
# HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
endif()
endfunction()
function(register_operators)
set(options "")
set(oneValueArgs "")
set(multiValueArgs EXCLUDES DEPS)
cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
string(REPLACE "_mkldnn" "" OPS "${OPS}")
string(REPLACE ".cc" "" OPS "${OPS}")
list(REMOVE_DUPLICATES OPS)
list(LENGTH register_operators_DEPS register_operators_DEPS_len)
foreach(src ${OPS})
list(FIND register_operators_EXCLUDES ${src} _index)
if (${_index} EQUAL -1)
if (${register_operators_DEPS_len} GREATER 0)
op_library(${src} DEPS ${register_operators_DEPS})
else()
op_library(${src})
endif()
endif()
endforeach()
endfunction()
...@@ -93,11 +93,11 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', ...@@ -93,11 +93,11 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized',
paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)) paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None))
paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)) paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None))
paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times'], varargs=None, keywords=None, defaults=(0, False)) paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False))
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
......
...@@ -36,7 +36,7 @@ add_subdirectory(details) ...@@ -36,7 +36,7 @@ add_subdirectory(details)
endif (NOT WIN32) endif (NOT WIN32)
# ddim lib # ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_param SRCS async_executor_param.proto) proto_library(async_executor_proto SRCS data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version) ...@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto) cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
if(NOT WIN32)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler) shape_inference data_transform lod_tensor profiler)
endif(NOT WIN32)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
if (NOT WIN32)
py_proto_compile(framework_py_proto SRCS framework.proto) py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
if (NOT WIN32) add_custom_command(TARGET framework_py_proto POST_BUILD
add_custom_command(TARGET framework_py_proto POST_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/
COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/ COMMENT "Copy generated python proto into directory paddle/fluid/proto."
COMMENT "Copy generated python proto into directory paddle/fluid/proto." WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else(NOT WIN32)
string(REPLACE "/" "\\" proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/")
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND copy /Y *.py ${proto_dstpath}
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif(NOT WIN32) endif(NOT WIN32)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
...@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE) ...@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
if(NOT WIN32) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
else(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
endif(NOT WIN32)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
...@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS ...@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
endif() # NOT WIN32 endif() # NOT WIN32
cc_library(async_executor cc_library(async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc
DEPS op_registry device_context scope framework_proto glog DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass lod_rank_table feed_fetch_method graph_to_program_pass
async_executor_param) async_executor_proto)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -36,43 +36,13 @@ limitations under the License. */ ...@@ -36,43 +36,13 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
bool AsyncExecutor::workers_initialized_ = false;
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<Scope>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
var_type);
}
}
static void ReadBinaryFile(const std::string& filename, static void ReadBinaryFile(const std::string& filename,
std::string* content) { std::string* content) {
std::string &contents = *content; std::string &contents = *content;
...@@ -139,343 +109,100 @@ static void SaveModel( ...@@ -139,343 +109,100 @@ static void SaveModel(
} }
} // end SaveModel } // end SaveModel
void ExecutorThreadWorker::Reset() { AsyncExecutor::AsyncExecutor(Scope& scope, const platform::Place& place)
inspect_values_.clear(); : root_scope_(scope), place_(place) {}
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
}
}
}
void ExecutorThreadWorker::SetDataFeed(DataFeed& datafeed) {
if (typeid(datafeed) == typeid(TextClassDataFeed)) {
local_reader_.reset(
new TextClassDataFeed(dynamic_cast<TextClassDataFeed &>(datafeed)));
local_reader_->SetThreadId(thread_id_);
}
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = local_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void ExecutorThreadWorker::SetInspectVarNames( void AsyncExecutor::CreateThreads(
const std::vector<std::string>& inspect_var_names) { ExecutorThreadWorker* worker,
inspect_var_names_.clear(); const ProgramDesc& main_program,
inspect_var_names_.insert(inspect_var_names_.end(), const std::shared_ptr<DataFeed>& reader,
inspect_var_names.begin(), inspect_var_names.end()); const std::vector<std::string>& fetch_var_names,
Scope& root_scope,
const int thread_index) {
worker->SetThreadId(thread_index);
worker->SetRootScope(&root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
} }
void ExecutorThreadWorker::SetModelParamNames( void AsyncExecutor::CheckFiles(
const std::vector<std::string>& param_names) { const std::vector<std::string>& files) {
model_param_names_ = param_names; // function for user to check file formats
} // should be exposed to users
void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47
};
unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i];
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(proc, &mask);
if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i;
} else {
CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i
<< " is running on processor " << proc
<< "...";
}
}
}
}
void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train";
SetDevice();
int inspect_var_num = inspect_var_names_.size();
inspect_values_.clear();
inspect_values_.resize(inspect_var_num, 0);
local_reader_->WaitNextEpoch();
int epoch = local_reader_->GetCurrentEpoch();
LOG(ERROR) << "epoch: " << epoch;
int batch_num = 1;
while (true) {
const char *file = local_reader_->PickOneFile();
if (file == NULL) {
break;
}
if (!local_reader_->SetFile(file)) {
break;
}
while (true) {
bool flag = local_reader_->ReadBatch();
if (!flag) {
break;
}
for (unsigned int i = 0; i < ops_.size(); ++i) {
ops_[i]->Run(*thread_scope_, place_);
}
batch_num++;
float avg_inspect = 0.0;
for (int i = 0; i < inspect_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(inspect_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
inspect_values_[i] += avg_inspect;
}
thread_scope_->DropKids();
}
local_reader_->UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << epoch + 1
<< " called: " << memory::memory_usage(place_);
}
for (int i = 0; i < inspect_var_num; ++i) {
inspect_values_[i] /= batch_num;
std::string var = inspect_var_names_[i].substr(
0,
inspect_var_names_[i].find_first_of("_"));
LOG(ERROR) << "mean " << var.c_str()
<< " of epoch " << i + 1 << ": " << inspect_values_[i];
}
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0], sizeof(modelfile), "%s_epoch%d.model",
model_prefix_.c_str(), epoch);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG(ERROR) << "Going to save model " << modelfile;
SaveModel(main_program_,
thread_scope_,
model_param_names_,
model_filename,
true);
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
thread_id_ = tid;
}
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
place_ = place;
}
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
main_program_.reset(new ProgramDesc(main_program_desc));
}
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
AsyncExecutor::AsyncExecutor(ProgramDesc& main_program,
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed,
unsigned int thread_num,
const platform::Place& place)
: thread_num_(thread_num),
place_(place),
main_program_(main_program),
data_feed_(data_feed) {
model_param_names_.clear();
model_param_names_.insert(model_param_names_.end(),
param_names.begin(),
param_names.end());
}
void AsyncExecutor::InitRootScope(Scope* scope) {
root_scope_ = scope;
}
void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
} }
void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) { void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix; model_prefix_ = model_prefix;
} }
void AsyncExecutor::RunStartupProgram(const ProgramDesc& program, std::vector<float> AsyncExecutor::RunFromFile(
Scope* scope) { const ProgramDesc& main_program,
auto& block = program.Block(0); const DataFeedDesc& data_feed_desc,
for (auto& var : block.AllVars()) { const std::vector<std::string>& filelist,
if (var->Persistable()) { const int thread_num,
auto* ptr = scope->Var(var->Name()); const std::vector<std::string>& fetch_var_names) {
CreateTensor(ptr, var->GetType()); std::vector<std::thread> threads;
// LOGERR("Persistable Var Name:%s", var->Name().c_str());
}
}
std::map<std::string, int> param_dict; /*
std::vector<OperatorBase *> ops; readerDesc: protobuf description for reader initlization
for (auto& op_desc : block.AllOps()) { argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
std::vector<std::string> param_name_vec = op_desc->OutputArgumentNames();
bool need_to_run = false; reader:
for (auto& name : param_name_vec) { 1) each thread has a reader, reader will read input data and
if (param_dict.find(name) == param_dict.end()) { put it into input queue
param_dict[name] = 1; 2) each reader has a Next() iterface, that can fetch an instance
need_to_run = true; from the input queue
} */
} // todo: should be factory method for creating datafeed
if (need_to_run) { std::vector<std::shared_ptr<DataFeed> > readers;
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc); readers.resize(thread_num);
OperatorBase* local_op_ptr = local_op.release(); for (unsigned int i = 0; i < readers.size(); ++i) {
ops.push_back(local_op_ptr); readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
}
} }
// LOGERR("There are %d parameters in startup program, %d op needs to run",
// param_dict.size(), ops.size());
for (auto& op : ops) { std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
op->Run(*scope, place_); workers.resize(thread_num);
for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker);
} }
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for (auto& op : ops) {
delete op;
}
// LOGERR("run startup program done.");
}
std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile( // prepare thread resource here
const std::string& f) { for (int thidx = 0; thidx < thread_num; ++thidx) {
std::string program_desc_str; CreateThreads(workers[thidx].get(), main_program,
ReadBinaryFile(f, &program_desc_str); readers[thidx], fetch_var_names, root_scope_, thidx);
std::unique_ptr<ProgramDesc> program(new ProgramDesc(program_desc_str));
return program;
}
void AsyncExecutor::SetInspectVarNames(
const std::vector<std::string>& inspect_var_names) {
inspect_var_names_.clear();
inspect_var_names_.insert(inspect_var_names_.end(),
inspect_var_names.begin(), inspect_var_names.end());
}
void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i].reset(new ExecutorThreadWorker);
workers_[i]->SetThreadId(i);
workers_[i]->CreateThreadOperators(host_program);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->SetPlace(place_);
workers_[i]->SetMaxTrainingEpoch(max_epoch_);
workers_[i]->CreateThreadScope(host_program);
workers_[i]->SetInspectVarNames(inspect_var_names_);
workers_[i]->SetModelParamNames(model_param_names_);
workers_[i]->SetMainProgram(host_program);
workers_[i]->SetModelPrefix(model_prefix_);
//
// new a datafeed here
workers_[i]->SetDataFeed(data_feed_);
workers_[i]->BindingDataFeedMemory();
} }
}
std::vector<float>& AsyncExecutor::Run( // start executing ops in multiple threads
const std::vector<std::string>& inspect_var_names) { for (int thidx = 0; thidx < thread_num; ++thidx) {
SetInspectVarNames(inspect_var_names); threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles,
threads_.clear(); workers[thidx].get()));
// thread binding here?
if (workers_initialized_ == false) {
PrepareThreads(main_program_);
workers_initialized_ = true;
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->Reset();
workers_[i]->SetInspectVarNames(inspect_var_names);
threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
workers_[i].get()));
} }
for (auto& th : threads_) { for (auto& th : threads) {
th.join(); th.join();
} }
inspect_values_.clear(); std::vector<float> fetch_values;
inspect_values_.resize(inspect_var_names_.size(), 0); fetch_values.resize(fetch_var_names.size(), 0);
std::vector<std::vector<float>*> inspect_value_vectors; std::vector<std::vector<float>*> fetch_value_vectors;
inspect_value_vectors.resize(thread_num_); fetch_value_vectors.resize(thread_num);
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num; ++i) {
inspect_value_vectors[i] = &workers_[i]->GetInspectValues(); fetch_value_vectors[i] = &workers[i]->GetFetchValues();
} }
for (unsigned int i = 0; i < inspect_var_names_.size(); ++i) { for (unsigned int i = 0; i < fetch_var_names.size(); ++i) {
float value = 0.0; float value = 0.0;
for (int j = 0; j < thread_num_; ++j) { for (int j = 0; j < thread_num; ++j) {
value += inspect_value_vectors[j]->at(i); value += fetch_value_vectors[j]->at(i);
} }
value /= thread_num_; value /= thread_num;
inspect_values_[i] = value; fetch_values[i] = value;
} }
return inspect_values_; return fetch_values;
} }
void AsyncExecutor::LoadInitModel() { void AsyncExecutor::LoadInitModel() {
......
...@@ -23,7 +23,8 @@ limitations under the License. */ ...@@ -23,7 +23,8 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include <typeinfo> #include <typeinfo>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/datafeed_creator.h" #include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -31,93 +32,13 @@ limitations under the License. */ ...@@ -31,93 +32,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
void CreateThreadScope(const ProgramDesc& program);
void SetThreadId(int tid);
void CreateThreadOperators(const ProgramDesc& program);
void SetRootScope(Scope* g_scope);
void SetDevice();
void AddFidSet();
void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
void AddTrainFile(const std::string& filename);
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
void SetMaxTrainingEpoch(const int max_epoch);
void BindingDataFeedMemory();
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
void SetModelParamNames(const std::vector<std::string>& param_names);
void SetDataFeed(DataFeed& datafeed); // NOLINT
void Train();
const char* PickOneFile();
void UpdateEpochNum();
void Reset();
void Initialize() {}
std::vector<float>& GetInspectValues() {return inspect_values_;}
protected:
// thread index
int thread_id_;
// max epoch for each thread
unsigned int max_epoch_;
// instances learned currently
int comm_batch_;
std::string model_prefix_;
std::vector<std::string> op_names_;
// local ops for forward and backward
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<ProgramDesc> main_program_;
// binary data reader
std::unique_ptr<DataFeed> local_reader_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
// execution place
platform::Place place_;
// root scope for model parameters
Scope* root_scope_;
// a thread scope, father scope is global score which is shared
Scope* thread_scope_;
private:
std::vector<float> inspect_values_;
};
class AsyncExecutor { class AsyncExecutor {
public: public:
explicit AsyncExecutor(ProgramDesc& main_program, // NOLINT explicit AsyncExecutor(Scope& scope, const platform::Place& place); // NOLINT
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed, // NOLINT
unsigned int thread_num,
const platform::Place& place);
virtual ~AsyncExecutor() {} virtual ~AsyncExecutor() {}
static std::unique_ptr<ProgramDesc> LoadDescFromFile( static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename); const std::string& filename);
void InitRootScope(Scope* scope); Scope* GetRootScope() { return &root_scope_; }
void SetMaxTrainingEpoch(const int max_epoch);
Scope* GetRootScope() { return root_scope_; }
void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch;
}
void SetModelPath(const std::string& model_path) { void SetModelPath(const std::string& model_path) {
model_path_ = model_path; model_path_ = model_path;
...@@ -132,38 +53,32 @@ class AsyncExecutor { ...@@ -132,38 +53,32 @@ class AsyncExecutor {
} }
void SetModelPrefix(const std::string& model_prefix); void SetModelPrefix(const std::string& model_prefix);
virtual void PrepareThreads(const ProgramDesc& host_program);
void RunStartupProgram(const ProgramDesc& program, Scope* scope); void RunStartupProgram(const ProgramDesc& program, Scope* scope);
std::vector<float>& Run(const std::vector<std::string>& inspect_var_names); std::vector<float> RunFromFile(const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_names);
void CheckFiles(const std::vector<std::string>& files);
void LoadInitModel(); void LoadInitModel();
private: private:
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names); void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope& root_scope, // NOLINT
const int thread_index);
public: public:
int thread_num_;
int max_epoch_;
int batch_size_;
int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> threads_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
std::string model_prefix_; std::string model_prefix_;
std::string model_path_; std::string model_path_;
std::string init_prog_file_; std::string init_prog_file_;
std::string init_model_file_; std::string init_model_file_;
Scope* root_scope_; Scope& root_scope_;
platform::Place place_; platform::Place place_;
private:
ProgramDesc& main_program_;
TextClassDataFeed& data_feed_;
std::vector<float> inspect_values_;
private:
static bool workers_initialized_;
}; };
} // namespace framework } // namespace framework
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
......
...@@ -43,7 +43,7 @@ size_t DataFeed::file_idx_; ...@@ -43,7 +43,7 @@ size_t DataFeed::file_idx_;
std::mutex DataFeed::mutex_for_pick_file_; std::mutex DataFeed::mutex_for_pick_file_;
void DataFeed::AddFeedVar(Variable* var, const std::string& name) { void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
if (CheckInit() == false) {return;} CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) { if (name == use_slots_[i]) {
if (use_slots_is_dense_[i]) { if (use_slots_is_dense_[i]) {
...@@ -56,7 +56,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) { ...@@ -56,7 +56,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
} }
bool DataFeed::SetFileList(const std::vector<std::string>& files) { bool DataFeed::SetFileList(const std::vector<std::string>& files) {
if (CheckInit() == false) {return false;} CheckInit();
if (files.size() == 0) { if (files.size() == 0) {
LOG(ERROR) << "error: you have set an empty filelist"; LOG(ERROR) << "error: you have set an empty filelist";
return false; return false;
...@@ -77,27 +77,27 @@ bool DataFeed::PickOneFile(std::string& filename) { ...@@ -77,27 +77,27 @@ bool DataFeed::PickOneFile(std::string& filename) {
return true; return true;
} }
bool DataFeed::CheckInit() { void DataFeed::CheckInit() {
if (finish_init_) {return true;} if (finish_init_) {return;}
LOG(ERROR) << "error: initialization did not succeed"; LOG(ERROR) << "error: initialization did not succeed";
return false; exit(-1);
} }
bool DataFeed::CheckSetFileList() { void DataFeed::CheckSetFileList() {
if (finish_set_filelist_) {return true;} if (finish_set_filelist_) {return;}
LOG(ERROR) << "error: set filelist did not succeed"; LOG(ERROR) << "error: set filelist did not succeed";
return false; exit(-1);
} }
bool DataFeed::CheckStart() { void DataFeed::CheckStart() {
if (finish_start_) {return true;} if (finish_start_) {return;}
LOG(ERROR) << "error: Datafeed has not started running yet"; LOG(ERROR) << "error: Datafeed has not started running yet";
return false; exit(-1);
} }
template<typename T> template<typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
if (!CheckInit()) {return;} CheckInit();
if (queue_size <= 0) { if (queue_size <= 0) {
LOG(ERROR) << "error: illegal queue size: " << queue_size; LOG(ERROR) << "error: illegal queue size: " << queue_size;
return; return;
...@@ -108,7 +108,7 @@ void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { ...@@ -108,7 +108,7 @@ void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
template<typename T> template<typename T>
bool PrivateQueueDataFeed<T>::Start() { bool PrivateQueueDataFeed<T>::Start() {
if (!(CheckSetFileList())) {return false;} CheckSetFileList();
read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this); read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this);
read_thread_.detach(); read_thread_.detach();
...@@ -121,8 +121,9 @@ void PrivateQueueDataFeed<T>::ReadThread(){ ...@@ -121,8 +121,9 @@ void PrivateQueueDataFeed<T>::ReadThread(){
std::string filename; std::string filename;
while (PickOneFile(filename)) { while (PickOneFile(filename)) {
file_.open(filename.c_str()); // is_text_feed file_.open(filename.c_str()); // is_text_feed
if (!file_.is_open()) { if (!file_.good()) {
LOG(ERROR) << "error: open file<" << filename << "> fail"; LOG(ERROR) << "error: open file<" << filename << "> fail";
continue;
} }
T instance; T instance;
while (ParseOneInstance(instance)) { while (ParseOneInstance(instance)) {
...@@ -135,7 +136,7 @@ void PrivateQueueDataFeed<T>::ReadThread(){ ...@@ -135,7 +136,7 @@ void PrivateQueueDataFeed<T>::ReadThread(){
template<typename T> template<typename T>
bool PrivateQueueDataFeed<T>::Next(){ bool PrivateQueueDataFeed<T>::Next(){
if (!CheckStart()) {return false;} CheckStart();
int index = 0; int index = 0;
T instance; T instance;
T ins_vec(use_slots_.size()); T ins_vec(use_slots_.size());
...@@ -150,15 +151,16 @@ bool PrivateQueueDataFeed<T>::Next(){ ...@@ -150,15 +151,16 @@ bool PrivateQueueDataFeed<T>::Next(){
return batch_size_ != 0; return batch_size_ != 0;
} }
void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) { void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false; finish_init_ = false;
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
/*
if (!data_feed_desc.has_multi_slot_desc()){ if (!data_feed_desc.has_multi_slot_desc()){
LOG(ERROR) << "error: multi_slot_desc has not been set"; LOG(ERROR) << "error: multi_slot_desc has not been set";
return ; exit(-1);
} }
paddle::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc(); paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
size_t all_slot_num = multi_slot_desc.slots_size(); size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num); all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num); all_slots_type_.resize(all_slot_num);
...@@ -176,10 +178,16 @@ void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) { ...@@ -176,10 +178,16 @@ void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) {
} }
} }
feed_vec_.resize(use_slots_.size()); feed_vec_.resize(use_slots_.size());
*/
finish_init_ = true; finish_init_ = true;
} }
bool MultiSlotDataFeed::CheckFile(const char* filename) {
// check with protobuf ?
std::cerr << "Check error" << std::endl;
return false;
}
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) { bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
std::string line; std::string line;
if (getline(file_, line)) { if (getline(file_, line)) {
...@@ -233,6 +241,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec, ...@@ -233,6 +241,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
ins_vec[i].AddIns(instance[i]); ins_vec[i].AddIns(instance[i]);
} }
} }
void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) { void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) {
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
auto& type = ins_vec[i].GetType(); auto& type = ins_vec[i].GetType();
...@@ -253,7 +262,7 @@ void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) { ...@@ -253,7 +262,7 @@ void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) {
feed_vec_[i].GetLoDTensor()->set_lod(data_lod); feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
} }
} else if (type[0] == 'u') { // uint64 } else if (type[0] == 'u') { // uint64
// no uint64_t type // no uint64_t type in paddle
auto& feasign = ins_vec[i].GetUint64Data(); auto& feasign = ins_vec[i].GetUint64Data();
if (feed_vec_[i].IsDense()) { if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_; int size_in_each_batch = total_instance / batch_size_;
...@@ -263,7 +272,7 @@ void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) { ...@@ -263,7 +272,7 @@ void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) {
} else { } else {
int64_t* tensor_ptr = feed_vec_[i].GetLoDTensor()-> int64_t* tensor_ptr = feed_vec_[i].GetLoDTensor()->
mutable_data<int64_t>({total_instance, 1}, platform::CPUPlace()); mutable_data<int64_t>({total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(uint64_t)); memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
LoD data_lod{offset}; LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod); feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
} }
......
...@@ -53,14 +53,14 @@ class MixTensor { ...@@ -53,14 +53,14 @@ class MixTensor {
LoDTensor* GetLoDTensor(){ LoDTensor* GetLoDTensor(){
if (is_dense_) { if (is_dense_) {
LOG(ERROR) << "error: let a dense var return a LoDTensor ptr"; LOG(ERROR) << "error: let a dense var return a LoDTensor ptr";
return NULL; exit(-1);
} }
return lodtensor_; return lodtensor_;
} }
Tensor* GetTensor(){ Tensor* GetTensor(){
if (!is_dense_) { if (!is_dense_) {
LOG(ERROR) << "error: let a sparse var return a Tensor ptr"; LOG(ERROR) << "error: let a sparse var return a Tensor ptr";
return NULL; exit(-1);
} }
return tensor_; return tensor_;
} }
...@@ -155,7 +155,7 @@ class DataFeed { ...@@ -155,7 +155,7 @@ class DataFeed {
public: public:
DataFeed() {} DataFeed() {}
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0; virtual void Init(paddle::framework::DataFeedDesc& data_feed_desc) = 0;
// for some datafeeds may not be able to implement this interface // for some datafeeds may not be able to implement this interface
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
LOG(ERROR) << "error: The function CheckFile is not implemented"; LOG(ERROR) << "error: The function CheckFile is not implemented";
...@@ -169,10 +169,12 @@ class DataFeed { ...@@ -169,10 +169,12 @@ class DataFeed {
// for subclass with queue // for subclass with queue
virtual void SetQueueSize(int queue_size) { virtual void SetQueueSize(int queue_size) {
LOG(ERROR) << "error: The function SetQueueSize is not implemented"; LOG(ERROR) << "error: The function SetQueueSize is not implemented";
exit(-1);
} }
// for subclass with buffer // for subclass with buffer
virtual void SetBufferSize(int buffer_size) { virtual void SetBufferSize(int buffer_size) {
LOG(ERROR) << "error: The function SetBufferSize is not implemented"; LOG(ERROR) << "error: The function SetBufferSize is not implemented";
exit(-1);
} }
virtual const std::vector<std::string>& GetAllSlots() {return all_slots_;} virtual const std::vector<std::string>& GetAllSlots() {return all_slots_;}
virtual const std::vector<std::string>& GetUseSlots() {return use_slots_;} virtual const std::vector<std::string>& GetUseSlots() {return use_slots_;}
...@@ -181,9 +183,9 @@ class DataFeed { ...@@ -181,9 +183,9 @@ class DataFeed {
protected: protected:
// Check if it is executed in this order: // Check if it is executed in this order:
// Init -> SetFileList/BindingMemory -> Start -> Next // Init -> SetFileList/BindingMemory -> Start -> Next
virtual bool CheckInit(); virtual void CheckInit();
virtual bool CheckSetFileList(); virtual void CheckSetFileList();
virtual bool CheckStart(); virtual void CheckStart();
virtual bool PickOneFile(std::string& filename); virtual bool PickOneFile(std::string& filename);
static std::vector<std::string> filelist_; static std::vector<std::string> filelist_;
...@@ -213,7 +215,7 @@ class PrivateQueueDataFeed : public DataFeed { ...@@ -213,7 +215,7 @@ class PrivateQueueDataFeed : public DataFeed {
public: public:
PrivateQueueDataFeed() {} PrivateQueueDataFeed() {}
virtual ~PrivateQueueDataFeed() {} virtual ~PrivateQueueDataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0; virtual void Init(paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start(); virtual bool Start();
virtual bool Next(); // no buffer virtual bool Next(); // no buffer
virtual void SetQueueSize(int queue_size); virtual void SetQueueSize(int queue_size);
...@@ -247,28 +249,28 @@ class MultiSlotType { ...@@ -247,28 +249,28 @@ class MultiSlotType {
} }
~MultiSlotType() {} ~MultiSlotType() {}
void SetType(std::string& type) { void SetType(std::string& type) {
if (!CheckType(type)) {return;} CheckType(type);
type_ = type; type_ = type;
} }
std::vector<size_t>& GetOffset() { std::vector<size_t>& GetOffset() {
return offset_; return offset_;
} }
void AddValue(float v) { void AddValue(float v) {
if (!CheckFloat()) {return;} CheckFloat();
float_feasign_.push_back(v); float_feasign_.push_back(v);
} }
void AddValue(uint64_t v) { void AddValue(uint64_t v) {
if (!CheckUint64()) {return;} CheckUint64();
uint64_feasign_.push_back(v); uint64_feasign_.push_back(v);
} }
void AddIns(MultiSlotType& ins) { void AddIns(MultiSlotType& ins) {
if (ins.GetType()[0] == 'f') { //float if (ins.GetType()[0] == 'f') { //float
if (!CheckFloat()) {return;} CheckFloat();
auto& vec = ins.GetFloatData(); auto& vec = ins.GetFloatData();
offset_.push_back(offset_.back() + vec.size()); offset_.push_back(offset_.back() + vec.size());
float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end()); float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
} else if (ins.GetType()[0] == 'u') { //uint64 } else if (ins.GetType()[0] == 'u') { //uint64
if (!CheckUint64()) {return;} CheckUint64();
auto& vec = ins.GetUint64Data(); auto& vec = ins.GetUint64Data();
offset_.push_back(offset_.back() + vec.size()); offset_.push_back(offset_.back() + vec.size());
uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end()); uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
...@@ -284,27 +286,24 @@ class MultiSlotType { ...@@ -284,27 +286,24 @@ class MultiSlotType {
return type_; return type_;
} }
private: private:
bool CheckType(std::string& type) { void CheckType(std::string& type) {
if (type != "uint64" && type != "float") { if (type != "uint64" && type != "float") {
// check in here // check in here
LOG(ERROR) << "error: here is no this type"; LOG(ERROR) << "error: here is no this type";
return false; exit(-1);
} }
return true;
} }
bool CheckFloat() { void CheckFloat() {
if (type_[0] != 'f') { //float if (type_[0] != 'f') { //float
LOG(ERROR) << "error: add " << type_ << " value to float slot"; LOG(ERROR) << "error: add " << type_ << " value to float slot";
return false; exit(-1);
} }
return true;
} }
bool CheckUint64() { void CheckUint64() {
if (type_[0] != 'u') { //uint64 if (type_[0] != 'u') { //uint64
LOG(ERROR) << "error: add " << type_ << " value to uint64 slot"; LOG(ERROR) << "error: add " << type_ << " value to uint64 slot";
return false; exit(-1);
} }
return true;
} }
std::vector<float> float_feasign_; std::vector<float> float_feasign_;
std::vector<uint64_t> uint64_feasign_; std::vector<uint64_t> uint64_feasign_;
...@@ -316,8 +315,8 @@ class MultiSlotDataFeed : public PrivateQueueDataFeed<std::vector<MultiSlotType> ...@@ -316,8 +315,8 @@ class MultiSlotDataFeed : public PrivateQueueDataFeed<std::vector<MultiSlotType>
public: public:
MultiSlotDataFeed() {} MultiSlotDataFeed() {}
virtual ~MultiSlotDataFeed() {} virtual ~MultiSlotDataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc); virtual void Init(paddle::framework::DataFeedDesc& data_feed_desc);
//TODO: virtual bool CheckFile(); virtual bool CheckFile(const char* filename);
protected: protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>& vec_ins, virtual void AddInstanceToInsVec(std::vector<MultiSlotType>& vec_ins,
std::vector<MultiSlotType>& instance, int index); std::vector<MultiSlotType>& instance, int index);
...@@ -325,39 +324,6 @@ class MultiSlotDataFeed : public PrivateQueueDataFeed<std::vector<MultiSlotType> ...@@ -325,39 +324,6 @@ class MultiSlotDataFeed : public PrivateQueueDataFeed<std::vector<MultiSlotType>
virtual void PutToFeedVec(std::vector<MultiSlotType>& ins_vec); virtual void PutToFeedVec(std::vector<MultiSlotType>& ins_vec);
}; };
//TODO: to be deleted
class TextClassDataFeed : public DataFeed {
public:
virtual ~TextClassDataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc) {}
virtual bool Start() {return false;}; //TODO
virtual bool Next() {return false;}; //TODO
virtual bool ReadBatch() {return false;}
virtual void AddFeedVar(Variable* feed, const std::string& name) {}
virtual void BindScope(Scope* scope) {}
virtual bool SetFile(const char* filename) {return false;}
virtual bool CheckFile(const char* filename) {
// TODO(xxx)
return false;
}
void SetBatchSize(int batch) {batch_size_ = batch;}
private:
int ReadWholeFile(const std::string& filename, char* buffer) {return -1;}
char* file_content_buffer_;
char* file_content_buffer_ptr_;
int* batch_id_buffer_;
int* label_ptr_;
int file_size_;
std::vector<std::string> names_;
std::shared_ptr<char> file_content_buffer_host_;
std::shared_ptr<int> batch_id_host_;
std::shared_ptr<int> label_host_;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
package paddle; package paddle.framework;
message DataFeedDesc { message DataFeedDesc {
optional string name = 1; optional string name = 1;
...@@ -28,5 +28,5 @@ message Slot { ...@@ -28,5 +28,5 @@ message Slot {
required string name = 1; required string name = 1;
required string type = 2; required string type = 2;
optional bool dense = 3 [default = false]; optional bool dense = 3 [default = false];
optional bool use = 4 [default = true]; optional bool use = 4 [default = false];
} }
...@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
typedef shared_ptr<DataFeed> (*Createdata_feedFunction)(); typedef std::shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap; typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
data_feedMap g_data_feed_map; data_feedMap g_data_feed_map;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \ #define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \ namespace { \
shared_ptr<DataFeed> Creator_##data_feed_class() { \ std::shared_ptr<DataFeed> Creator_##data_feed_class() { \
return shared_ptr<DataFeed>(new data_feed_class); \ return std::shared_ptr<DataFeed>(new data_feed_class); \
} \ } \
class __Registerer_##data_feed_class { \ class __Registerer_##data_feed_class { \
public: \ public: \
...@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map; ...@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map;
} // namespace } // namespace
string DataFeedFactory::DataFeedTypeList() { std::string DataFeedFactory::DataFeedTypeList() {
string data_feed_types; std::string data_feed_types;
for (auto iter = g_data_feed_map.begin(); for (auto iter = g_data_feed_map.begin();
iter != g_data_feed_map.end(); ++iter) { iter != g_data_feed_map.end(); ++iter) {
if (iter != g_data_feed_map.begin()) { if (iter != g_data_feed_map.begin()) {
...@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() { ...@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() {
return data_feed_types; return data_feed_types;
} }
shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed( std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
const char* data_feed_class) { std::string data_feed_class) {
if (g_data_feed_map.count(string(data_feed_class)) < 1) { if (g_data_feed_map.count(data_feed_class) < 1) {
exit(-1); exit(-1);
} }
return g_data_feed_map[data_feed_class](); return g_data_feed_map[data_feed_class]();
......
...@@ -16,14 +16,15 @@ limitations under the License. */ ...@@ -16,14 +16,15 @@ limitations under the License. */
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_ #define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#include <string> #include <string>
#include "paddle/framework/data_feed.h" #include <memory>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class DataFeedFactory { class DataFeedFactory {
public: public:
static std::string DataFeedTypeList(); static std::string DataFeedTypeList();
static shared_ptr<DataFeed> CreateDataFeed(const char* data_feed_class); static std::shared_ptr<DataFeed> CreateDataFeed(std::string data_feed_class);
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource( ...@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource(
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_,
"root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
if (var->Persistable()) { if (var->Persistable()) {
...@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { ...@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
void ExecutorThreadWorker::SetDataFeed( void ExecutorThreadWorker::SetDataFeed(
const std::shared_ptr<DataFeed>& datafeed) { const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed; thread_reader_ = datafeed;
} }
void ExecutorThreadWorker::BindingDataFeedMemory() { void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias(); thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) { for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name); thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
} }
} }
void ExecutorThreadWorker::SetFetchVarNames(
const std::vector<std::string>& fetch_var_names) {
fetch_var_names_.clear();
fetch_var_names_.insert(fetch_var_names_.end(),
fetch_var_names.begin(), fetch_var_names.end());
}
void ExecutorThreadWorker::SetDevice() { void ExecutorThreadWorker::SetDevice() {
// at most 48 threads binding currently // at most 48 threads binding currently
static unsigned priority[] = { static unsigned priority[] = {
...@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() { ...@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() {
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
// todo: configurable // todo: configurable
SetDevice(); SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0);
thread_reader_->Start(); thread_reader_->Start();
while (int cur_batch = thread_reader_->Next()) {
int cur_batch;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here // executor run here
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
} }
......
...@@ -43,6 +43,9 @@ class ExecutorThreadWorker { ...@@ -43,6 +43,9 @@ class ExecutorThreadWorker {
void SetDevice(); void SetDevice();
void BindingDataFeedMemory(); void BindingDataFeedMemory();
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed); void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void TrainFiles();
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
std::vector<float>& GetFetchValues() {return fetch_values_;}
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
...@@ -66,9 +69,13 @@ class ExecutorThreadWorker { ...@@ -66,9 +69,13 @@ class ExecutorThreadWorker {
Scope* root_scope_; Scope* root_scope_;
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* thread_scope_; Scope* thread_scope_;
private:
std::vector<std::string> fetch_var_names_;
std::vector<float> fetch_values_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_ #endif // PADDLE_FLUID_FRAMEWORK_EXECUTOR_THREAD_WORKER_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ /* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
...@@ -15,7 +15,7 @@ cc_library(inference_io SRCS io.cc) ...@@ -15,7 +15,7 @@ cc_library(inference_io SRCS io.cc)
# TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal? # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
cc_library(paddle_fluid_api cc_library(paddle_fluid_api
SRCS io.cc SRCS io.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
......
...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { ...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout", "split"}); "elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"});
if (!node->IsOp()) return false; if (!node->IsOp()) return false;
if (teller_set.count(node->Op()->Type())) { if (teller_set.count(node->Op()->Type())) {
......
...@@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat); ...@@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split); USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
#endif #endif
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context) nv_library(tensorrt_engine SRCS engine.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine) nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(plugin) add_subdirectory(plugin)
......
...@@ -2,35 +2,38 @@ ...@@ -2,35 +2,38 @@
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc pad_op.cc split_op.cc prelu_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter) ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op SERIAL)
nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine conv_op conv_transpose_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine elementwise_add_op SERIAL)
nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine softmax_op SERIAL)
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine dropout_op SERIAL)
nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pad_op SERIAL)
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
split_op concat_op SERIAL) split_op concat_op SERIAL)
nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
prelu_op SERIAL)
...@@ -18,92 +18,139 @@ namespace paddle { ...@@ -18,92 +18,139 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
bool to_skip_merging_optimize(TensorRTEngine* engine_, bool to_skip_merging_optimize(TensorRTEngine* engine,
const std::vector<int>& filters, const std::vector<int>& filters,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
std::string input_name) { std::string input_name) {
if (engine_->itensor_quote_num[input_name] > 0) { if (engine->itensor_quote_num[input_name] > 0) {
return true; return true;
} }
if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 &&
strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0)
engine_->itensor_quote_num[input_name] += 1; engine->itensor_quote_num[input_name] += 1;
return false; return false;
} }
template <typename RegistFunc, typename SetDilationFunc>
void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode,
RegistFunc fadd_layer, SetDilationFunc fset_dilation,
const std::string& name) {
VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias";
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1);
PADDLE_ENFORCE(engine != nullptr);
auto* X = engine->GetITensor(op_desc.Input("Input").front());
// Declare weights
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get());
auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
const int n_output = weight_tensor->dims()[0];
const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
const std::vector<int> dilations =
boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
const std::vector<int> strides =
boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
const std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(weight_tensor->numel())};
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output, n_input,
nv_ksize, weight, bias);
PADDLE_ENFORCE(layer != nullptr);
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
layer->setNbGroups(groups);
// set dilations
fset_dilation(layer, nv_dilations);
auto output_name = op_desc.Output("Output").front();
layer->setName((name + " (Output: " + output_name + ")").c_str());
engine->weight_map[op_desc.Input("Filter").front()] =
std::move(weight_tensor);
layer->getOutput(0)->setName(output_name.c_str());
engine->SetITensor(output_name, layer->getOutput(0));
if (test_mode ||
to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings,
op_desc.Input("Input").front())) {
engine->DeclareOutput(output_name);
}
}
class Conv2dOpConverter : public OpConverter { class Conv2dOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid conv2d op to tensorrt conv layer without bias"; ConvertConv2d(
engine_, op, scope, test_mode,
framework::OpDesc op_desc(op, nullptr); [&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); int n_input, /* Conv input maps */
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer =
auto* X = engine_->GetITensor(op_desc.Input("Input").front()); TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
ksize, weight.get(), bias.get());
// Declare weights return layer;
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); },
PADDLE_ENFORCE_NOT_NULL(Y_v); [](nvinfer1::IConvolutionLayer* layer, nvinfer1::DimsHW& dilations) {
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); layer->setDilation(dilations);
},
platform::CPUPlace cpu_place; "conv2d");
std::unique_ptr<framework::LoDTensor> weight_tensor( }
new framework::LoDTensor()); };
weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); class Deconv2dOpConverter : public OpConverter {
public:
auto* weight_data = void operator()(const framework::proto::OpDesc& op,
weight_tensor->mutable_data<float>(platform::CPUPlace()); const framework::Scope& scope, bool test_mode) override {
ConvertConv2d(
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); engine_, op, scope, test_mode,
const int n_output = weight_tensor->dims()[0]; [&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */
const int filter_h = weight_tensor->dims()[2]; int n_input, /* Deconv output maps */
const int filter_w = weight_tensor->dims()[3]; nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
const int groups = boost::get<int>(op_desc.GetAttr("groups")); auto* layer =
const std::vector<int> dilations = TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input,
boost::get<std::vector<int>>(op_desc.GetAttr("dilations")); ksize, weight.get(), bias.get());
const std::vector<int> strides = return layer;
boost::get<std::vector<int>>(op_desc.GetAttr("strides")); },
const std::vector<int> paddings = [](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::DimsHW& dilations) {
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); PADDLE_ENFORCE(
dilations.d[0] == 1 && dilations.d[1] == 1,
nvinfer1::DimsHW nv_ksize(filter_h, filter_w); "Dilations must be (1, 1) for tensorRT, but given (%d, %d)",
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); dilations.d[0], dilations.d[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]); },
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); "conv2d_transpose");
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
weight_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *const_cast<nvinfer1::ITensor*>(X), n_output,
nv_ksize, weight.get(), bias.get());
PADDLE_ENFORCE(layer != nullptr);
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
layer->setDilation(nv_dilations);
layer->setNbGroups(groups);
auto output_name = op_desc.Output("Output").front();
layer->setName(("conv2d (Output: " + output_name + ")").c_str());
engine_->weight_map[op_desc.Input("Filter").front()] =
std::move(weight_tensor);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode ||
to_skip_merging_optimize(engine_, {filter_h, filter_w}, strides,
paddings, op_desc.Input("Input").front())) {
engine_->DeclareOutput(output_name);
}
} }
}; };
...@@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter { ...@@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter {
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter);
...@@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* X = engine_->GetITensor(op_desc.Input("X").front()); auto* X = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims dims_x = X->getDimensions(); nvinfer1::Dims dims_x = X->getDimensions();
PADDLE_ENFORCE(dims_x.nbDims >= 3); PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims experts 3, but %d is given.",
dims_x.nbDims);
auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v); PADDLE_ENFORCE_NOT_NULL(Y_v);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* PRelu converter from fluid to tensorRT.
*/
class PReluOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE(input_num == 1);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE(output_num == 1);
// Get attrs
std::string mode = boost::get<std::string>(op_desc.GetAttr("mode"));
//
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
PADDLE_ENFORCE_NOT_NULL(alpha_var);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
platform::CUDAPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
new framework::LoDTensor());
alpha_tensor_device->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
// Transform alpha to TensorRTEngine::Weight
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data),
alpha_tensor_device->numel());
PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
// keep alpha tensor to avoid release it's memory
engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_device);
std::string layer_name = "prelu (Output: ";
auto output_name = op_desc.Output("Out")[0];
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer_name += output_name;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
layer->setName((layer_name + ")").c_str());
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter);
...@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter { ...@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(40) << "convert a fluid split op to tensorrt split layer"; VLOG(4) << "convert a fluid split op to tensorrt split layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
......
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_OP(conv2d);
USE_OP(conv2d_transpose);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
...@@ -51,7 +54,37 @@ TEST(conv2d_op, test) { ...@@ -51,7 +54,37 @@ TEST(conv2d_op, test) {
validator.Execute(3); validator.Execute(3);
} }
TEST(conv2d_transpose_op, test) {
std::unordered_set<std::string> parameters({"deconv2d-Y"});
framework::Scope scope;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
validator.DeclInputVar("deconv2d-X", nvinfer1::Dims3(3, 5, 5));
validator.DeclParamVar("deconv2d-Y", nvinfer1::Dims4(3, 2, 3, 3));
validator.DeclOutputVar("deconv2d-Out", nvinfer1::Dims3(2, 5, 5));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("conv2d_transpose");
desc.SetInput("Input", {"deconv2d-X"});
desc.SetInput("Filter", {"deconv2d-Y"});
desc.SetOutput("Output", {"deconv2d-Out"});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
const int groups = 1;
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("dilations", dilations);
desc.SetAttr("groups", groups);
validator.SetOp(*desc.Proto());
validator.Execute(3);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(conv2d);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(prelu_op, test_channel_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("channel"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(prelu_op, test_element_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("element"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(prelu_op, test_scalar) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("all"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// USE_OP(prelu);
USE_CPU_ONLY_OP(prelu);
...@@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst, ...@@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
Buffer &TensorRTEngine::buffer(const std::string &name) { Buffer &TensorRTEngine::buffer(const std::string &name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name); auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end()); PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s",
name);
auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
return buffers_[slot_offset]; return buffers_[slot_offset];
} }
......
...@@ -40,6 +40,7 @@ class TensorRTEngine : public EngineBase { ...@@ -40,6 +40,7 @@ class TensorRTEngine : public EngineBase {
// Weight is model parameter. // Weight is model parameter.
class Weight { class Weight {
public: public:
Weight() = default;
Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) { Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) {
w_.type = dtype; w_.type = dtype;
w_.values = value; w_.values = value;
......
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce) nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
float *output, int channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float *out = output + offset;
float scale = alpha[blockIdx.x % channel];
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
const float *scale = alpha + offset;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale[i] * x;
}
}
__global__ void PReluScalarKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float scale = *alpha;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
static inline void PReluChannelWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, dims.d[0], spatial_size);
}
static inline void PReluElementWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
static inline void PReluScalar(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size, const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
float *output = reinterpret_cast<float **>(outputs)[0];
if (mode_ == "channel") {
PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
} else if (mode_ == "element") {
PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
} else {
PReluScalar(stream, input, alpha, output, batchSize, input_dims);
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PReluPlugin : public PluginTensorRT {
TensorRTEngine::Weight alpha_;
std::string mode_;
protected:
size_t getSerializationSize() override {
// return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
return 0;
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
// serializeBase(buffer);
// SerializeValue(&buffer, alpha_);
// SerializeValue(&buffer, mode_);
}
public:
PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
: alpha_(alpha), mode_(mode) {}
// It was used for tensorrt deserialization.
// It should not be called by users.
PReluPlugin(void const *serialData, size_t serialLength) {
// deserializeBase(serialData, serialLength);
// DeserializeValue(&serialData, &serialLength, &alpha_);
// DeserializeValue(&serialData, &serialLength, &mode_);
}
PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }
const char *getPluginType() const override { return "prelu"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") include(operators)
string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}")
string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}")
list(REMOVE_DUPLICATES GENERAL_OPS)
set(DEPS_OPS "")
set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h)
file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operator/CMakeLists.txt. DO NOT EDIT!\n\n")
set(PART_CUDA_KERNEL_FILES)
function(op_library TARGET)
# op_library is a function to create op library. The interface is same as
# cc_library. But it handle split GPU/CPU code and link some common library
# for ops.
set(cc_srcs)
set(cu_srcs)
set(hip_cu_srcs)
set(miopen_hip_cc_srcs)
set(cu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(CUDNN_FILE)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
set(op_common_deps operator op_registry math_function)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(pybind_flag 0)
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
list(LENGTH op_library_SRCS op_library_SRCS_len)
if (${op_library_SRCS_len} EQUAL 0)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
list(APPEND cc_srcs ${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
list(APPEND hip_cu_srcs ${TARGET}.hip.cu)
endif()
string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
endif()
if(WITH_AMD_GPU)
string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc)
list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc)
endif()
endif()
if(WITH_MKLDNN)
string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc)
list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc)
endif()
endif()
else()
foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.hip.cu$")
list(APPEND hip_cu_srcs ${src})
elseif (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
list(APPEND cudnn_cu_cc_srcs ${src})
elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$")
list(APPEND miopen_hip_cc_srcs ${src})
elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$")
list(APPEND mkldnn_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
else()
message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
endif()
endforeach()
endif()
list(LENGTH cc_srcs cc_srcs_len)
if (${cc_srcs_len} EQUAL 0)
message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
endif()
if (WIN32)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
"fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op"
"fusion_seqexpand_concat_fc_op" "attention_lstm_op" "fused_embedding_fc_lstm_op" "fc_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
endforeach()
endif(WIN32)
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)
list(LENGTH op_library_DEPS op_library_DEPS_len) # clean cache and pybind_file content first when rebuild
if (${op_library_DEPS_len} GREATER 0) unset(GLOB_OP_LIB CACHE)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) unset(OP_LIBRARY CACHE)
endif() set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h CACHE INTERNAL "pybind.h file")
if (WITH_GPU) file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operator/CMakeLists.txt. DO NOT EDIT!\n\n")
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
elseif (WITH_AMD_GPU)
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
endif()
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
endforeach()
# The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
# And for detail pybind information, please see generated paddle/pybind/pybind.h.
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}")
if (one_register STREQUAL "")
string(REPLACE "_op" "" TARGET "${TARGET}")
else ()
string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
string(REPLACE "," "" TARGET "${TARGET}")
endif()
# pybind USE_NO_KERNEL_OP
# HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
string(REPLACE "_op" "" TARGET "${TARGET}")
if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
# pybind USE_CPU_ONLY_OP
list(LENGTH cu_srcs cu_srcs_len)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
list(LENGTH hip_cu_srcs hip_cu_srcs_len)
list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MIOPEN
if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
endif()
# pybind USE_OP
if (${pybind_flag} EQUAL 0)
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "fake_quantize")
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
# HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
endif()
endfunction()
add_subdirectory(math) add_subdirectory(math)
if (NOT WIN32) add_subdirectory(controlflow)
add_subdirectory(nccl) add_subdirectory(csp)
if(WITH_GPU) add_subdirectory(detection)
op_library(nccl_op DEPS nccl_common) add_subdirectory(elementwise)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") add_subdirectory(fused)
else() add_subdirectory(metrics)
set(DEPS_OPS ${DEPS_OPS} nccl_op) add_subdirectory(optimizers)
endif() add_subdirectory(reduce_ops)
endif() # NOT WIN32 add_subdirectory(sequence_ops)
set(DISTRIBUTE_DEPS "")
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(distributed) add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "") add_subdirectory(distributed_ops)
if(WITH_GRPC) endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY})
find_library(RDMACM_LIBRARY NAMES rdmacm)
ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY})
set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} ibverbs rdmacm)
endif()
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) if (NOT WIN32)
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op add_subdirectory(reader)
# listen_and_serv_op sum_op executor SERIAL)
if(WITH_GPU AND NOT WIN32)
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op ${DISTRIBUTE_DEPS} executor SERIAL)
if(WITH_GRPC)
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc)
else()
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_brpc)
endif()
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif() # WITH_GPU AND NOT WIN32
else()
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif() endif()
op_library(cross_entropy_op DEPS cross_entropy) if (NOT WIN32)
if(WITH_GPU) add_subdirectory(nccl)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax cub)
op_library(sequence_softmax_op DEPS cub)
else()
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
endif() endif()
op_library(softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND) if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) add_subdirectory(tensorrt)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n")
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op
analysis)
else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif() endif()
op_library(hash_op DEPS xxhash)
op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) register_operators(EXCLUDES warpctc_op)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor) # warpctc_cudnn need cudnn 7 above
op_library(print_op DEPS lod_tensor)
op_library(adagrad_op DEPS selected_rows_functor)
op_library(maxout_op DEPS maxouting)
op_library(unpool_op DEPS unpooling)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op DEPS lod_rank_table)
op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
if (NOT WIN32)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
endif(NOT WIN32)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op)
op_library(fake_quantize_op DEPS memory)
if (NOT WIN32)
op_library(crf_decoding_op DEPS jit_kernel)
op_library(fusion_lstm_op DEPS jit_kernel)
endif(NOT WIN32)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(layer_norm_op DEPS cub) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
op_library(reduce_mean_op DEPS cub) else()
op_library(affine_channel_op DEPS cub) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
else() else()
op_library(conv_op DEPS vol2col im2col) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
op_library(conv_transpose_op DEPS vol2col im2col)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_and_split)
op_library(tensor_array_to_tensor_op DEPS concat_op)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
set(COMMON_OP_DEPS "")
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor dynload_warpctc sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
if (NOT WIN32) if (NOT WIN32)
add_subdirectory(reader) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
endif(NOT WIN32) endif()
foreach(src ${READER_LIBRARY}) if (WITH_GPU)
set(OP_LIBRARY ${src} ${OP_LIBRARY}) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv cub)
endforeach() endif()
add_subdirectory(detection) # FIXME(typhoonzero): operator deps may not needed.
foreach(src ${DETECTION_LIBRARY}) # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
set(OP_LIBRARY ${src} ${OP_LIBRARY}) # op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
endforeach() # op_library(unsqueeze_op DEPS reshape_op)
# op_library(squeeze_op DEPS reshape_op)
# op_library(flatten_op DEPS reshape_op)
# op_library(unstack_op DEPS stack_op)
# op_library(tensor_array_to_tensor_op DEPS concat_op)
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency") set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
...@@ -362,18 +76,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea ...@@ -362,18 +76,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
if(NOT WIN32)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if(WITH_GPU) set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
foreach(CUDA_KERNEL_FILE ${PART_CUDA_KERNEL_FILES})
file(READ ${CUDA_KERNEL_FILE} TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OP_CUDA_KERNEL\\(\\n?([^,]+),.*" MATCHED ${TARGET_CONTENT})
if (MATCHED)
string(STRIP ${CMAKE_MATCH_1} MATCHED)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MATCHED}, CUDA);\n")
endif()
endforeach()
endif()
include(operators)
register_operators()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/compare_op.h" #include "paddle/fluid/operators/controlflow/compare_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/compare_op.h" #include "paddle/fluid/operators/controlflow/compare_op.h"
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <math.h> #include <math.h>
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/logical_op.h" #include "paddle/fluid/operators/controlflow/logical_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/logical_op.h" #include "paddle/fluid/operators/controlflow/logical_op.h"
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CUDA, REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CUDA,
paddle::operators::LogicalAndFunctor); paddle::operators::LogicalAndFunctor);
......
include(operators)
register_operators()
...@@ -40,4 +40,8 @@ endif() ...@@ -40,4 +40,8 @@ endif()
detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu) detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
#Export local libraries to parent #Export local libraries to parent
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) # set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
foreach(src ${LOCAL_DETECTION_LIBS})
set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs")
endforeach()
include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY})
find_library(RDMACM_LIBRARY NAMES rdmacm)
ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY})
set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} ibverbs rdmacm)
endif()
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
list(REMOVE_DUPLICATES OPS)
foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS})
if(WITH_GPU AND NOT WIN32)
set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common)
op_library(gen_nccl_id_op ${DISTRIBUTE_DEPS} nccl_common)
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE)
set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency")
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
......
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send"); DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get"); DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get");
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/merge_ids_op.h" #include "paddle/fluid/operators/distributed_ops/merge_ids_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h" #include "paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h"
#include <string> #include <string>
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h" #include "paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
ref_by_trainer_id, ref_by_trainer_id,
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/split_byref_op.h" #include "paddle/fluid/operators/distributed_ops/split_byref_op.h"
#include "paddle/fluid/operators/split_op.h" #include "paddle/fluid/operators/split_op.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/split_byref_op.h" #include "paddle/fluid/operators/distributed_ops/split_byref_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
split_byref, split_byref,
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/split_ids_op.h" #include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -22,14 +22,14 @@ limitations under the License. */ ...@@ -22,14 +22,14 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_GRPC #ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#endif #endif
USE_NO_KERNEL_OP(listen_and_serv); USE_NO_KERNEL_OP(listen_and_serv);
......
include(operators)
register_operators()
...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add); REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out", REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y"); REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)"); REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)"); REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_pow_op.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_pow_op.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cmath> #include <cmath>
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub); REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_sub, "Sub", "Out = X - Y", "Out", REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_sub, "Sub", "Out = X - Y", "Out",
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
out_shape[i] = x_dims[i] * expand_times[i]; out_shape[i] = x_dims[i] * expand_times[i];
} }
// set the first dim to -1 in compile time
if (!ctx->IsRuntime()) {
out_shape[0] = x_dims[0];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) { if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
...@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel {
ctx->Attrs().Get<std::vector<int>>("expand_times"); ctx->Attrs().Get<std::vector<int>>("expand_times");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
for (size_t i = 0; i < expand_times.size(); ++i) { size_t start_pos = 0u;
if (!ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0],
"The first dimension size of Input(Out@GRAD) should be "
"equal to the crroresponding dimension size of Input(X)");
start_pos = 1u;
}
for (size_t i = start_pos; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i], PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be " "Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension " "equal to multiplication of crroresponding dimension "
......
include(operators)
register_operators()
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h" #include "paddle/fluid/operators/fused/fused_elemwise_activation_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h" #include "paddle/fluid/operators/fused/fused_elemwise_activation_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/functors.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h" #include "paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h" #include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h" #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h" #include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h"
#include <algorithm> // for min, max #include <algorithm> // for min, max
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h" #include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
......
...@@ -41,6 +41,7 @@ math_library(cross_entropy) ...@@ -41,6 +41,7 @@ math_library(cross_entropy)
math_library(cos_sim_functor) math_library(cos_sim_functor)
math_library(depthwise_conv) math_library(depthwise_conv)
math_library(im2col) math_library(im2col)
math_library(sampler)
if (NOT WIN32) # windows do not support avx functions yet. if (NOT WIN32) # windows do not support avx functions yet.
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
......
...@@ -33,11 +33,11 @@ namespace math { ...@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define AVX_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4 #define AVX_DOUBLE_BLOCK 4
#define AVX2_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4 #define AVX2_DOUBLE_BLOCK 4
#define AVX512_FLOAT_BLOCK 16 #define ZMM_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8 #define AVX512_DOUBLE_BLOCK 8
template <typename T> template <typename T>
...@@ -88,7 +88,7 @@ template <> ...@@ -88,7 +88,7 @@ template <>
inline void vec_scal<float, platform::jit::avx>(const int n, const float a, inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_scal<float, platform::jit::isa_any>(n, a, x, y); vec_scal<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -142,7 +142,7 @@ template <> ...@@ -142,7 +142,7 @@ template <>
inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a, inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y); vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x, ...@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const float* y, const float* z, const float* y, const float* z,
float* out) { float* out) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out); vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
return; return;
...@@ -257,7 +257,7 @@ template <> ...@@ -257,7 +257,7 @@ template <>
inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a, inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_add_bias<float, platform::jit::isa_any>(n, a, x, y); vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
return; return;
...@@ -326,7 +326,7 @@ template <> ...@@ -326,7 +326,7 @@ template <>
inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x, inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_sigmoid<float, platform::jit::isa_any>(n, x, y); vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
return; return;
...@@ -415,7 +415,7 @@ template <> ...@@ -415,7 +415,7 @@ template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x, inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block * 4) { if (n < block * 4) {
vec_relu<float, platform::jit::isa_any>(n, x, y); vec_relu<float, platform::jit::isa_any>(n, x, y);
return; return;
......
...@@ -41,7 +41,7 @@ void VXXJitCode::generate() { ...@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
} else if (scalar_index_ == 2) { } else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]); vbroadcastss(ymm_src2, ptr[param2]);
} }
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src1, ptr[param1 + offset]);
} }
...@@ -57,9 +57,9 @@ void VXXJitCode::generate() { ...@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps(ymm_dst, ymm_zero, ymm_dst); vmaxps(ymm_dst, ymm_zero, ymm_dst);
} }
vmovups(ptr[param3 + offset], ymm_dst); vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % AVX_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
...@@ -118,18 +118,237 @@ void VXXJitCode::generate() { ...@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
ret(); ret();
} }
bool ReluJitCode::init(int d) { return MayIUse(avx); } #define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
void ReluJitCode::generate() { #define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
int offset = 0;
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx);
if (type == operand_type::relu) {
return ok;
} else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256
return ok && d % 8 == 0 && d < 256;
} else {
// TODO(TJ): support more
return ok && d % 8 == 0;
}
}
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
vmaxps(ymm_dst, ymm_zero, ymm_src);
}
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal
ymm_t ymm_fx = ymm_t(fx_idx);
ymm_t ymm_fy = ymm_t(fy_idx);
ymm_t ymm_mask = ymm_t(mask_idx);
ymm_t ymm_tmp = ymm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
// express exp(x) as exp(g + n*log(2))
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(ymm_fx, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(ymm_fx, ymm_fx, ymm_tmp);
vroundps(ymm_fy, ymm_fx, 0x01);
// if greater, substract 1
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vandps(ymm_mask, ymm_mask, ymm_tmp);
vsubps(ymm_fx, ymm_fy, ymm_mask);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(ymm_fy, ymm_fx, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
vmulps(ymm_z, ymm_fx, ymm_tmp);
vsubps(ymm_src, ymm_src, ymm_fy);
vsubps(ymm_src, ymm_src, ymm_z);
vmulps(ymm_z, ymm_src, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(ymm_dst, ymm_src, ymm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_src);
}
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmulps(ymm_dst, ymm_dst, ymm_z);
vaddps(ymm_dst, ymm_dst, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
// build 2^n
ymm_t ymm_int = ymm_fx;
vcvttps2dq(ymm_int, ymm_fx);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
if (MayIUse(avx2)) {
vpaddd(ymm_int, ymm_int, ymm_tmp);
vpslld(ymm_int, ymm_int, 23);
} else if (MayIUse(avx)) {
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
reg64_t reg_ptr_tmp = reg_ptr_global;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
// next 128bits
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp2,
ptr[reg_ptr_tmp +
(YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
// load out
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
}
vmulps(ymm_dst, ymm_dst, ymm_int);
pop(reg_ptr_global);
}
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
// y = 1 / (1 + e^-x)
ymm_t ymm_tmp = ymm_t(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
vminps(ymm_src, ymm_src, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
vmaxps(ymm_src, ymm_src, ymm_tmp);
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
vsubps(ymm_src, ymm_tmp, ymm_src);
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
pop(reg_ptr_global);
}
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
int fy_idx, int mask_idx, int tmp_idx) {
// y = 2 / (1 + e^(-2x)) - 1
ymm_t ymm_tmp = ymm_t(tmp_idx);
ymm_t ymm_zero = ymm_t(mask_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(ymm_zero, ymm_zero, ymm_zero); vxorps(ymm_zero, ymm_zero, ymm_zero);
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { vsubps(ymm_tmp, ymm_zero, ymm_tmp);
vmulps(ymm_src, ymm_src, ymm_tmp);
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vsubps(ymm_dst, ymm_dst, ymm_tmp);
pop(reg_ptr_global);
}
void VActJitCode::generate() {
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_zero = ymm_t(2);
if (type_ == operand_type::relu) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
vmaxps(ymm_dst, ymm_zero, ymm_src); switch (type_) {
case operand_type::relu:
relu_ymm(ymm_dst, ymm_src, ymm_zero);
break;
case operand_type::exp:
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::sigmoid:
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::tanh:
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
break;
case operand_type::identity:
break;
default:
break;
}
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (type_ != operand_type::relu) {
// TODO(TJ): remove me
ret();
return;
} }
int rest = num_ % AVX_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]); vmovups(xmm_src, ptr[param1 + offset]);
vmaxps(xmm_dst, xmm_zero, xmm_src); vmaxps(xmm_dst, xmm_zero, xmm_src);
...@@ -151,6 +370,7 @@ void ReluJitCode::generate() { ...@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
} }
ret(); ret();
} }
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm; ...@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm; using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label; using Label = Xbyak::Label;
typedef enum { mul = 0, add } operand_type; typedef enum {
mul = 0,
add,
sub,
relu,
exp,
sigmoid,
tanh,
identity
} operand_type;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode { class VXXJitCode : public JitCode {
...@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode { ...@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3); ymm_t ymm_zero = ymm_t(3);
}; };
class ReluJitCode : public JitCode { class VActJitCode : public JitCode {
public: public:
DECLARE_JIT_CODE(ReluJitCode); const char* name() const override {
explicit ReluJitCode(int d, size_t code_size = 256 * 1024, std::string base = "VActJitCode";
switch (type_) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
return base.c_str();
}
explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {} : JitCode(code_size, code_ptr), num_(d), type_(type) {}
static bool init(int d); static bool init(int d, operand_type type);
void generate() override; void generate() override;
private: protected:
// compute relu with ymm
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
const Xbyak::Ymm& zero);
// compute exp with ymm
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
// compute sigmoid with ymm
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
// compute tanh with ymm
void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
protected:
int num_; int num_;
operand_type type_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
reg64_t param2{abi_param2}; reg64_t param2{abi_param2};
xmm_t xmm_zero = xmm_t(0); xmm_t xmm_src = xmm_t(0);
xmm_t xmm_src = xmm_t(1); ymm_t ymm_src = ymm_t(0);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_zero = ymm_t(0); xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_src = ymm_t(1);
ymm_t ymm_dst = ymm_t(1); ymm_t ymm_dst = ymm_t(1);
}; };
......
...@@ -29,9 +29,9 @@ namespace jitkernel { ...@@ -29,9 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
#define AVX_FLOAT_BLOCK 8 #define XMM_FLOAT_BLOCK 4
#define AVX2_FLOAT_BLOCK 8 #define YMM_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16 #define ZMM_FLOAT_BLOCK 16
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
...@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel { ...@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel {
template <typename T> template <typename T>
class VActKernel : public Kernel { class VActKernel : public Kernel {
public: public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0; void (*Compute)(const T *, T *, int);
}; };
template <typename T> template <typename T>
class VReluKernel : public VActKernel<T> { class VReluKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T> template <typename T>
class VIdentityKernel : public VActKernel<T> { class VIdentityKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VExpKernel : public VActKernel<T> { class VExpKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VSigmoidKernel : public VActKernel<T> { class VSigmoidKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class VTanhKernel : public VActKernel<T> { class VTanhKernel : public VActKernel<T> {};
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T> template <typename T>
class LSTMKernel : public Kernel { class LSTMKernel : public Kernel {
......
...@@ -25,10 +25,6 @@ limitations under the License. */ ...@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) { ...@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif #endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */ /* VMUL JitKernel */
template <typename T> template <typename T>
class VMulKernelImpl : public VMulKernel<T> { class VMulKernelImpl : public VMulKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VMulKernelImpl(int d) : VMulKernel<T>() { explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) { ...@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template <typename T> template <typename T>
class VAddKernelImpl : public VAddKernel<T> { class VAddKernelImpl : public VAddKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddKernelImpl(int d) : VAddKernel<T>() { explicit VAddKernelImpl(int d) : VAddKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) { ...@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template <typename T> template <typename T>
class VAddReluKernelImpl : public VAddReluKernel<T> { class VAddReluKernelImpl : public VAddReluKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) { ...@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template <typename T> template <typename T>
class VScalKernelImpl : public VScalKernel<T> { class VScalKernelImpl : public VScalKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VScalKernelImpl(int d) : VScalKernel<T>() { explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) { ...@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template <typename T> template <typename T>
class VAddBiasKernelImpl : public VAddBiasKernel<T> { class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
sz > 4096 ? sz : 4096)); sz > 4096 ? sz : 4096));
this->Compute = this->Compute =
...@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) { ...@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template <typename T> template <typename T>
class VReluKernelImpl : public VReluKernel<T> { class VReluKernelImpl : public VReluKernel<T> {
public: public:
DECLARE_STATIC_FUNC; JITKERNEL_DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() { explicit VReluKernelImpl(int d) : VReluKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
size_t sz = 96 /*init*/ + size_t sz = 96 /* init size */ +
d / AVX_FLOAT_BLOCK * 4 /* instructions*/ * d / YMM_FLOAT_BLOCK * 4 /* instructions */ *
8 /*everage byte for each instruction*/; 8 /* average bytes for each instruction */;
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096)); jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>(); this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return; return;
} }
...@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> { ...@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> {
this->Compute = VReluRefer<T>; this->Compute = VReluRefer<T>;
} }
void ComputeDeprecated(const T* x, T* y) const override {
VReluRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif #endif
}; };
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VReluKernelImpl<float>::useJIT(int d) { bool VReluKernelImpl<float>::useJIT(int d) {
return gen::ReluJitCode::init(d); return gen::VActJitCode::init(d, gen::operand_type::relu);
} }
#endif #endif
#undef DECLARE_STATIC_FUNC template <typename T>
inline void VIdentityRefer(const T* x, T* y, int n) {}
/* An empty JitKernel */
template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = VIdentityRefer<T>;
}
};
REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel); REGISTER_JITKERNEL(vadd, VAddKernel);
...@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); ...@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel); REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
/* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
void ComputeDeprecated(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX_FLOAT_BLOCK; \ this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \ void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \ max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \ trans_offset += this->num_; \
} \ } \
UPDATE_ALPHA(AVX_FLOAT_BLOCK) \ UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \ } \
seq_offset += this->num_; \ seq_offset += this->num_; \
} \ } \
...@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \ CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \ this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \ void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \ max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \ trans_offset += this->num_; \
} \ } \
UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \ UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \ } \
seq_offset += this->num_; \ seq_offset += this->num_; \
} \ } \
...@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \ this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \ this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \ void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(AVX512_FLOAT_BLOCK) \ INIT_ALPHA(ZMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \ /* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \ int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \ constexpr int state_trans_base_idx = 2; \
...@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \ this->num_ + j_offset), \
max_j); \ max_j); \
/* Calculate the offset of next step*/ \ /* Calculate the offset of next step*/ \
j_offset += AVX512_FLOAT_BLOCK; \ j_offset += ZMM_FLOAT_BLOCK; \
if (j == this->end_ - 1) { \ if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \ if (this->rest_ > 0) { \
j_offset += last_offset; \ j_offset += last_offset; \
......
...@@ -16,6 +16,11 @@ limitations under the License. */ ...@@ -16,6 +16,11 @@ limitations under the License. */
#include <cmath> // for exp #include <cmath> // for exp
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
...@@ -30,38 +35,238 @@ namespace math { ...@@ -30,38 +35,238 @@ namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file
// Refer code only focus on correctness
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
}
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
// y = 1 / (1 + e^-x)
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
}
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
// y = 2 * sigmoid(2x) - 1
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidRefer(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template <typename T>
void VExpMKL(const T* x, T* y, int n);
template <>
void VExpMKL<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
}
template <>
void VExpMKL<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
template <typename T>
void VSigmoidMKL(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
VExpMKL(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
template <typename T>
void VTanhMKL(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidMKL(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#endif
/* VExp JitKernel */ /* VExp JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T>
class VExpKernelImpl : public VExpKernel<T> { class VExpKernelImpl : public VExpKernel<T> {
public: public:
explicit VExpKernelImpl(int d) : VExpKernel<T>() { this->num_ = d; } JITKERNEL_DECLARE_STATIC_FUNC;
void ComputeDeprecated(const T* x, T* y) const override { explicit VExpKernelImpl(int d) : VExpKernel<T>() {
for (int i = 0; i < this->num_; ++i) { #ifdef PADDLE_WITH_XBYAK
y[i] = std::exp(x[i]); if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VExpMKL<T>;
return;
} }
#endif
this->Compute = VExpRefer<T>;
} }
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
}; };
#ifdef PADDLE_WITH_XBYAK
template <>
bool VExpKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::exp);
}
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \ template <>
template <> \ bool VExpKernelImpl<float>::useMKL(int d) {
void VExpKernelImpl<float, isa, block>::ComputeDeprecated(const float* x, \ return d > 512;
float* y) const { \ }
platform::dynload::vsExp(this->num_, x, y); \
template <>
bool VExpKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VSigmoid JitKernel */
template <typename T>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if (useMKL(d)) {
this->Compute = VSigmoidMKL<T>;
return;
}
#endif
this->Compute = VSigmoidRefer<T>;
} }
#define MKL_DOUBLE(isa, block) \ #ifdef PADDLE_WITH_XBYAK
template <> \
void VExpKernelImpl<double, isa, block>::ComputeDeprecated( \ private:
const double* x, double* y) const { \ std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
platform::dynload::vdExp(this->num_, x, y); \ #endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VSigmoidKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::sigmoid);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VSigmoidKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VSigmoidKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VTanh JitKernel */
template <typename T>
class VTanhKernelImpl : public VTanhKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if (useMKL(d)) {
this->Compute = VTanhMKL<T>;
return;
}
#endif
this->Compute = VTanhRefer<T>;
} }
FOR_EACH_ISA(MKL_FLOAT, kLT8);
FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); #ifdef PADDLE_WITH_XBYAK
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE); private:
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VTanhKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::tanh);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VTanhKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VTanhKernelImpl<double>::useMKL(int d) {
return true;
}
#endif #endif
REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel);
namespace detail { namespace detail {
#ifdef __AVX__ #ifdef __AVX__
...@@ -210,334 +415,6 @@ __m256 ExpAVX2(__m256 x) { ...@@ -210,334 +415,6 @@ __m256 ExpAVX2(__m256 x) {
#endif #endif
} // namespace detail } // namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, expisa(tmp)); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = expisa(tmp0); \
tmp1 = expisa(tmp1); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL_DEPRECATED(vexp, VExpKernel);
/* VSigmoid JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
this->num_ = d;
vexp_ = KernelPool::Instance().template Get<VExpKernel<T>>(d);
}
void ComputeDeprecated(const T* x, T* y) const override {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < this->num_; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
vexp_->ComputeDeprecated(y, y);
for (int i = 0; i < this->num_; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
private:
std::shared_ptr<const VExpKernel<T>> vexp_;
};
#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::ComputeDeprecated( \
const float* x, float* y) const { \
/* TODO(TJ): try to use static const*/ \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_SIGMOID(tmp0, min, max, expisa); \
INTRI_SIGMOID(tmp1, min, max, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
vexp_ = \
KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
} \
template <> \
void VSigmoidKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \
vexp_ = \
KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
} \
template <> \
void VSigmoidKernelImpl<float, isa, kGT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx2 at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VSIGMOID
REGISTER_JITKERNEL_DEPRECATED(vsigmoid, VSigmoidKernel);
/* VTanh JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class VTanhKernelImpl : public VTanhKernel<T> {
public:
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
this->num_ = d;
vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
}
void ComputeDeprecated(const T* x, T* y) const override {
const T a = static_cast<T>(2), b = static_cast<T>(-1);
vscal_->Compute(&a, x, y, this->num_);
vsigmoid_->ComputeDeprecated(y, y);
vscal_->Compute(&a, y, y, this->num_);
vaddbias_->Compute(&b, y, y, this->num_);
}
private:
std::shared_ptr<const VScalKernel<T>> vscal_;
std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
};
#define INTRI_VTANH(tmp, expisa) \
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_VTANH(tmp0, expisa); \
INTRI_VTANH(tmp1, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
vscal_ = \
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
this->rest_); \
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
this->rest_); \
} \
template <> \
void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->ComputeDeprecated(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \
vscal_ = \
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
this->rest_); \
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
this->rest_); \
} \
template <> \
void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x, \
float* y) const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
x += this->end_; \
y += this->end_; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->ComputeDeprecated(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VTANH
REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel);
#undef JITKERNEL_NEW_ACT_IMPL
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -15,12 +15,20 @@ limitations under the License. */ ...@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \ #define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \ template <> \
std::string ker_class##Impl<float>::name(int d) { \ std::string ker_class##Impl<float>::name(int d) { \
...@@ -86,17 +94,17 @@ namespace jitkernel { ...@@ -86,17 +94,17 @@ namespace jitkernel {
namespace jit = platform::jit; namespace jit = platform::jit;
// TODO(TJ): below defines are deprecated, would be remove recently // TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ #define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \ if (d < YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \ macro_(ker, dtype, isa, kLT8); \
} else if (d == AVX_FLOAT_BLOCK) { \ } else if (d == YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ8); \ macro_(ker, dtype, isa, kEQ8); \
} else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \ macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == AVX512_FLOAT_BLOCK) { \ } else if (d == ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \ macro_(ker, dtype, isa, kEQ16); \
} else { \ } else { \
macro_(ker, dtype, isa, kGT16); \ macro_(ker, dtype, isa, kGT16); \
} }
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ #define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
......
...@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override { T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh // gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
/* C_t = C_t-1 * fgated + cand_gated * igated */ /* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/ /* C_t = igated * cgated*/
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
...@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/ /* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/ /* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/ /* C_t = igated * cgated*/
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); act_gate_d_->Compute(gates + d_, gates + d_, d_);
act_cand_d_->ComputeDeprecated(gates, gates); act_cand_d_->Compute(gates, gates, d_);
vmul_d_->Compute(gates, gates + d_, ct, d_); vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */ /* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_); act_cell_d_->Compute(ct, gates + d2_, d_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
} }
...@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
} }
void ComputeH1(T* gates, T* ht) const override { void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->ComputeDeprecated(gates, gates); act_gate_d_->Compute(gates, gates, d_);
act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_); act_state_d_->Compute(gates + d2_, gates + d2_, d_);
vmul_d_->Compute(gates, gates + d2_, ht, d_); vmul_d_->Compute(gates, gates + d2_, ht, d_);
} }
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
act_gate_d2_->ComputeDeprecated(gates, gates); act_gate_d2_->Compute(gates, gates, d2_);
vmul_d_->Compute(ht_1, gates + d_, ht, d_); vmul_d_->Compute(ht_1, gates + d_, ht, d_);
} }
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_; T* y = gates + d2_;
act_state_d_->ComputeDeprecated(y, y); act_state_d_->Compute(y, y, d_);
// out = zt*ht~ + (1-zt)*ht_1 // out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) { for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i]; ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
......
...@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) { ...@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data); // ker->Compute(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -222,7 +223,7 @@ void vsigmoid_better( ...@@ -222,7 +223,7 @@ void vsigmoid_better(
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i]; y[i] = 0.f - y[i];
} }
vexp->ComputeDeprecated(y, y); vexp->Compute(y, y, n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]); y[i] = 1.f / (1.f + y[i]);
} }
...@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data); ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -287,7 +288,7 @@ void vtanh_better( ...@@ -287,7 +288,7 @@ void vtanh_better(
const int n, const float* x, float* y) { const int n, const float* x, float* y) {
const float a = 2.f, b = -1.f; const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n); vscal->Compute(&a, x, y, n);
vsigmoid->ComputeDeprecated(y, y); vsigmoid->Compute(y, y, n);
vscal->Compute(&a, y, y, n); vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n); vaddbias->Compute(&b, y, y, n);
} }
...@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) { ...@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS(); auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data); ker->Compute(x_data, ztgt_data, d);
} }
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
...@@ -344,8 +345,8 @@ void lstm_ctht_ref( ...@@ -344,8 +345,8 @@ void lstm_ctht_ref(
const std::shared_ptr< const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1, const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) { const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->ComputeDeprecated(gates, gates); vtanh_d->Compute(gates, gates, d);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX; const float max = SIGMOID_THRESHOLD_MAX;
...@@ -355,7 +356,7 @@ void lstm_ctht_ref( ...@@ -355,7 +356,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated // H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2; float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->ComputeDeprecated(&tmp, &tmp); vexp_1->Compute(&tmp, &tmp, 1);
tmp = 2.f / (1.f + tmp) - 1.f; tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k]; ht[k] = tmp * o[k];
} }
...@@ -373,13 +374,13 @@ void lstm_ctht_better( ...@@ -373,13 +374,13 @@ void lstm_ctht_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d, const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
const int d, float* gates, const float* ct_1, float* ct, float* ht) { const int d, float* gates, const float* ct_1, float* ct, float* ht) {
int d2 = d * 2; int d2 = d * 2;
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->ComputeDeprecated(gates, gates); vtanh_d->Compute(gates, gates, d);
vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct, d); vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */ /* H_t = act_cell(C_t) * ogated */
vtanh_d->ComputeDeprecated(ct, gates + d2); vtanh_d->Compute(ct, gates + d2, d);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d); vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
} }
...@@ -736,7 +737,7 @@ void vaddrelu_better( ...@@ -736,7 +737,7 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu, const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z, int d) { const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d); vadd->Compute(x, y, z, d);
vrelu->ComputeDeprecated(z, z); vrelu->Compute(z, z, d);
} }
TEST(JitKernel, vaddrelu) { TEST(JitKernel, vaddrelu) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and ...@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/operators/math/sampler.h"
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
namespace paddle { namespace paddle {
namespace random { namespace operators {
namespace math {
Sampler::~Sampler() {} Sampler::~Sampler() {}
UniformSampler::UniformSampler(int64 range) UniformSampler::UniformSampler(int64_t range, unsigned int seed)
: Sampler(range), inv_range_(1.0 / range) { : Sampler(range, seed), inv_range_(1.0 / (range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_); random_engine_ = std::make_shared<std::mt19937_64>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range); dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
} }
UniformSampler::UniformSampler(int64 range, unsigned int seed) int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
: Sampler(range, seed), inv_range_(1.0 / range) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
}
int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
float UniformSampler::Probability(int64 value) const { return inv_range_; } float UniformSampler::Probability(int64_t value) const { return inv_range_; }
LogUniformSampler::LogUniformSampler(int64 range) LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed)
: Sampler(range), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
}
LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
: Sampler(range, seed), log_range_(log(range + 1)) { : Sampler(range, seed), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_); random_engine_ = std::make_shared<std::mt19937_64>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1); dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
} }
int64 LogUniformSampler::Sample() const {
int64_t LogUniformSampler::Sample() const {
// Got Log Uniform distribution from uniform distribution by // Got Log Uniform distribution from uniform distribution by
// inverse_transform_sampling method // inverse_transform_sampling method
// More details: // More details:
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
const int64 value = const int64_t value =
static_cast<int64>(exp((*dist_)(*random_engine_) * log_range_)) - 1; static_cast<int64_t>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
// Mathematically, value should be <= range_, but might not be due to some // Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_. // floating point roundoff, so we mod by range_.
return value % range_; return value % range_;
} }
float LogUniformSampler::Probability(int64 value) const { float LogUniformSampler::Probability(int64_t value) const {
// Given f(x) = 1/[(x+1) * log_range_] // Given f(x) = 1/[(x+1) * log_range_]
// The value's probability is integral of f(x) from value to (value + 1) // The value's probability is integral of f(x) from value to (value + 1)
// More details: // More details:
...@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const { ...@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const {
return (log((value + 2.0) / (value + 1.0))) / log_range_; return (log((value + 2.0) / (value + 1.0))) / log_range_;
} }
} // namespace random CustomSampler::CustomSampler(int64_t range, const float* probabilities,
unsigned int seed)
: Sampler(range, seed) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_);
real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
probs_ = std::make_shared<std::vector<float>>(range + 1);
std::queue<std::pair<int64_t, float>> bigs;
std::queue<std::pair<int64_t, float>> littles;
for (int64_t i = 0; i <= range; ++i) {
(*probs_)[i] = probabilities[i];
float normal_prob = probabilities[i] * (range + 1);
if (normal_prob - 1.0 > 1e-4) {
bigs.emplace(i, normal_prob);
} else if (1.0 - normal_prob > 1e-4) {
littles.emplace(i, normal_prob);
} else {
(*alias_probs_)[i] = normal_prob;
(*alias_)[i] = -1;
}
}
while ((!littles.empty()) && (!bigs.empty())) {
auto big = bigs.front();
auto little = littles.front();
bigs.pop();
littles.pop();
(*alias_probs_)[little.first] = little.second;
(*alias_)[little.first] = big.first;
auto big_left = big.second - (1 - little.second);
if (big_left - 1.0 > 1e-4) {
bigs.emplace(big.first, big_left);
} else if (1.0 - big_left > 1e-4) {
littles.emplace(big.first, big_left);
} else {
(*alias_probs_)[big.first] = big_left;
(*alias_)[big.first] = -1;
}
}
if (!littles.empty()) { // littles.second is close to 1.0
auto little = littles.front();
(*alias_probs_)[little.first] = 1.0;
(*alias_)[little.first] = -1;
}
if (!bigs.empty()) { // bigs.second is close to 1.0
auto big = bigs.front();
(*alias_probs_)[big.first] = 1.0;
(*alias_)[big.first] = -1;
}
}
int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
if (p > (*alias_probs_)[index]) {
return (*alias_)[index];
} else {
return index;
}
}
float CustomSampler::Probability(int64_t value) const {
return (*probs_)[value];
}
} // namespace math
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <random> #include <random>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -27,14 +29,14 @@ namespace math { ...@@ -27,14 +29,14 @@ namespace math {
*/ */
class Sampler { class Sampler {
public: public:
explicit Sampler(int64_t range) : range_(range) { explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
PADDLE_ENFORCE_GT(range, 0); // PADDLE_ENFORCE_GT(range, 0, "Range should be greater than 0.");
std::random_device r; if (seed == 0) {
seed_ = r(); std::random_device r;
} seed_ = r();
explicit Sampler(int64_t range, unsigned int seed) } else {
: range_(range), seed_(seed) { seed_ = seed;
PADDLE_ENFORCE_GT(range, 0); }
} }
virtual ~Sampler(); virtual ~Sampler();
// Sample a single value // Sample a single value
...@@ -42,7 +44,7 @@ class Sampler { ...@@ -42,7 +44,7 @@ class Sampler {
// The probability that a single call to Sample() returns the given value. // The probability that a single call to Sample() returns the given value.
virtual float Probability(int64_t value) const = 0; virtual float Probability(int64_t value) const = 0;
int64 range() { return range_; } int64_t range() { return range_; }
protected: protected:
const int64_t range_; const int64_t range_;
...@@ -56,13 +58,11 @@ class Sampler { ...@@ -56,13 +58,11 @@ class Sampler {
*/ */
class UniformSampler : public Sampler { class UniformSampler : public Sampler {
public: public:
explicit UniformSampler(int64_t range); explicit UniformSampler(int64_t range, unsigned int seed = 0UL);
explicit UniformSampler(int64_t range, unsigned int seed);
~UniformSampler() override {} ~UniformSampler() override {}
int64 Sample() const override; int64_t Sample() const override;
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
...@@ -79,13 +79,11 @@ class UniformSampler : public Sampler { ...@@ -79,13 +79,11 @@ class UniformSampler : public Sampler {
*/ */
class LogUniformSampler : public Sampler { class LogUniformSampler : public Sampler {
public: public:
explicit LogUniformSampler(int64_t range); explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL);
explicit LogUniformSampler(int64_t range, unsigned int seed);
~LogUniformSampler() override {} ~LogUniformSampler() override {}
int64 Sample() const override; int64_t Sample() const override;
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
...@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler { ...@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler {
std::shared_ptr<std::uniform_real_distribution<>> dist_; std::shared_ptr<std::uniform_real_distribution<>> dist_;
}; };
/**
* Sample integers from [0, range) from custom distribution.
*/
class CustomSampler : public Sampler {
public:
explicit CustomSampler(int64_t range, const float* probabilities,
unsigned int seed = 0UL);
~CustomSampler() override {}
int64_t Sample() const override;
float Probability(int64_t value) const override;
private:
std::shared_ptr<std::vector<float>> alias_probs_;
std::shared_ptr<std::vector<int64_t>> alias_;
std::shared_ptr<std::vector<float>> probs_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
include(operators)
register_operators()
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/accuracy_op.h" #include "paddle/fluid/operators/metrics/accuracy_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/reduce.h> #include <thrust/reduce.h>
#include "paddle/fluid/operators/accuracy_op.h" #include "paddle/fluid/operators/metrics/accuracy_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/auc_op.h" #include "paddle/fluid/operators/metrics/auc_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/precision_recall_op.h" #include "paddle/fluid/operators/metrics/precision_recall_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
if(WITH_GPU AND NOT WIN32) if(WITH_GPU AND NOT WIN32)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator ) nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator )
endif() endif()
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n")
set(OPERATOR_DEPS ${OPERATOR_DEPS} nccl_common PARENT_SCOPE)
endif()
if(NOT WIN32)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
...@@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("Input"); auto x_dims = ctx->GetInputDim("Input");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1; int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
...@@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"each sample. And it is a dispensable input. The default value of " "each sample. And it is a dispensable input. The default value of "
"sample is 1.") "sample is 1.")
.AsDispensable(); .AsDispensable();
AddInput(
"CustomDistribution",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddOutput("Cost", AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
AddOutput("SampleLogits", AddOutput("SampleLogits",
...@@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("num_neg_samples", AddAttr<int>("num_neg_samples",
"The number of negative classes. The default value is 10.") "The number of negative classes. The default value is 10.")
.SetDefault(10); .SetDefault(10);
AddAttr<int>("sampler",
"(int) Which sampler to be used to sample negative class."
"0: Uniform; 1: LogUniform; 2: CostumDist.")
.SetDefault(0);
AddAttr<int>("seed",
"(int) The seed used in sampler. If it is 0, "
"the sampler will generate a seed randomly.")
.SetDefault(0);
AddAttr<std::vector<int>>("custom_neg_classes", AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes " "This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes " "in this list wiil be used as negative classes "
......
...@@ -19,29 +19,28 @@ limitations under the License. */ ...@@ -19,29 +19,28 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using Sampler = math::Sampler;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext& context) { void PrepareSamples(const framework::ExecutionContext& context,
Sampler* sampler) {
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
auto label_dims = label->dims(); auto label_dims = label->dims();
int num_total_classes = context.Attr<int>("num_total_classes"); // int num_total_classes = context.Attr<int>("num_total_classes");
// for unitest // for unitest
std::vector<int> custom_neg_classes = std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes"); context.Attr<std::vector<int>>("custom_neg_classes");
// random machine
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims(); auto sample_labels_dims = sample_labels->dims();
...@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) { ...@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) {
} else { } else {
for (; j < sample_labels_dims[1]; ++j) { for (; j < sample_labels_dims[1]; ++j) {
// TODO(wanghaoshuang): support more distribution sampling // TODO(wanghaoshuang): support more distribution sampling
sample_labels_data[index++] = rand(rng); sample_labels_data[index++] = sampler->Sample();
} }
} }
} }
...@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T> ...@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T>
class NCEKernel : public framework::OpKernel<T> { class NCEKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PrepareSamples<DeviceContext, T>(context); int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed");
int num_total_classes = context.Attr<int>("num_total_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples");
Sampler* sampler;
switch (sampler_type) {
case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed);
break;
}
case 1: {
sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
break;
}
case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution");
const float* custom_dist_data = custom_dist->data<float>();
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed);
break;
}
default: { PADDLE_THROW("Unsupported SamplerType."); }
}
PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
const int64_t* sample_labels_data = sample_labels->data<int64_t>(); const int64_t* sample_labels_data = sample_labels->data<int64_t>();
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
...@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> {
} }
auto out = context.Output<Tensor>("Cost"); auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_total_classes = context.Attr<int>("num_total_classes");
int64_t num_true_class = 1; int64_t num_true_class = 1;
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_total_classes * num_neg_samples; int64_t sampled_labels_num = sample_labels->dims()[1];
// T b = 1. / num_total_classes * num_neg_samples;
// forward bias // forward bias
auto bias = context.Input<Tensor>("Bias"); auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) { if (bias != nullptr) {
...@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> {
} }
// forward cost // forward cost
for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) { for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
int64_t j = 0;
out_data[i] = 0; out_data[i] = 0;
T w = sample_weight == nullptr ? 1. : sample_weight_data[i]; T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
// for true classes for (int64_t j = 0; j < sampled_labels_num; ++j) {
for (; j < num_true_class; ++j) { int64_t target = sample_labels_data[i * sampled_labels_num + j];
T o = sample_out_data[i * sample_out->dims()[1] + j]; T o = sample_out_data[i * sampled_labels_num + j];
T cost = -log(o / (o + b)); float b = sampler->Probability(target) * num_neg_samples;
out_data[i] += w * cost; T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
}
// for sampled neg classes
for (; j < sample_labels->dims()[1]; ++j) {
T o = sample_out_data[i * sample_out->dims()[1] + j];
T cost = -log(b / (o + b));
out_data[i] += w * cost; out_data[i] += w * cost;
} }
} }
delete sampler;
} }
}; };
...@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> {
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_total_classes * num_neg_samples;
int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed");
Sampler* sampler;
switch (sampler_type) {
case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed);
break;
}
case 1: {
sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
break;
}
case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution");
const float* custom_dist_data = custom_dist->data<float>();
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed);
break;
}
default: { PADDLE_THROW("Unsupported SamplerType."); }
}
// T b = 1. / num_total_classes * num_neg_samples;
Tensor sample_grad; // tmp tensor Tensor sample_grad; // tmp tensor
T* sample_grad_data = T* sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace()); sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
// backward cost // backward cost
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
int64_t label_idx = i % sample_labels->dims()[1];
int64_t sample_idx = i / sample_labels->dims()[1];
float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
T o = sample_out_data[i]; T o = sample_out_data[i];
T w = sample_weight == nullptr T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
? 1 sample_grad_data[i] = label_idx < num_true_class
: sample_weight_data[i / sample_labels->dims()[1]];
sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
? w * (b / (o + b)) * (o - 1) ? w * (b / (o + b)) * (o - 1)
: w * (o * (1 - o) / (o + b)); : w * (o * (1 - o) / (o + b));
sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]]; sample_grad_data[i] *= d_out_data[sample_idx];
} }
// get d_bias // get d_bias
auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias")); auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
...@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
} }
} }
delete sampler;
} }
}; };
} // namespace operators } // namespace operators
......
include(operators)
register_operators()
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/adadelta_op.h" #include "paddle/fluid/operators/optimizers/adadelta_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/adadelta_op.h" #include "paddle/fluid/operators/optimizers/adadelta_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/adagrad_op.h" #include "paddle/fluid/operators/optimizers/adagrad_op.h"
#include <vector> #include <vector>
#include <cmath> #include <cmath>
......
...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/adagrad_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/optimizers/adagrad_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/adam_op.h" #include "paddle/fluid/operators/optimizers/adam_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/adam_op.h" #include "paddle/fluid/operators/optimizers/adam_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/adamax_op.h" #include "paddle/fluid/operators/optimizers/adamax_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/adamax_op.h" #include "paddle/fluid/operators/optimizers/adamax_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/decayed_adagrad_op.h" #include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/decayed_adagrad_op.h" #include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/ftrl_op.h" #include "paddle/fluid/operators/optimizers/ftrl_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the ...@@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */ specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/ftrl_op.h" #include "paddle/fluid/operators/optimizers/ftrl_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/lars_momentum_op.h" #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/operators/momentum_op.h" #include "paddle/fluid/operators/optimizers/momentum_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lars_momentum_op.h" #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/momentum_op.h" #include "paddle/fluid/operators/optimizers/momentum_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h" #include "paddle/fluid/operators/optimizers/momentum_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/proximal_adagrad_op.h" #include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the ...@@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */ specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/proximal_adagrad_op.h" #include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/proximal_gd_op.h" #include "paddle/fluid/operators/optimizers/proximal_gd_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the ...@@ -12,7 +12,7 @@ CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */ specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/proximal_gd_op.h" #include "paddle/fluid/operators/optimizers/proximal_gd_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/rmsprop_op.h" #include "paddle/fluid/operators/optimizers/rmsprop_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/rmsprop_op.h" #include "paddle/fluid/operators/optimizers/rmsprop_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sgd_op.h" #include "paddle/fluid/operators/optimizers/sgd_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/sgd_op.h" #include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
include(operators)
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader) cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
set(LOCAL_READER_LIBS) set(LOCAL_READER_LIBS)
...@@ -28,4 +30,10 @@ reader_library(create_py_reader_op SRCS create_py_reader_op.cc) ...@@ -28,4 +30,10 @@ reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
# Export local libraries to parent # Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) # set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
op_library(read_op)
foreach(src ${LOCAL_READER_LIBS})
set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs")
endforeach()
include(operators)
register_operators()
if(WITH_GPU)
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.part.cu")
string(REPLACE ".part.cu" "" OPS "${OPS}")
foreach(src ${OPS})
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu)
set(CUDA_KERNEL_FILE ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu)
file(READ ${CUDA_KERNEL_FILE} TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OP_CUDA_KERNEL\\(\\n?([^,]+),.*" MATCHED ${TARGET_CONTENT})
if (MATCHED)
string(STRIP ${CMAKE_MATCH_1} MATCHED)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MATCHED}, CUDA);\n")
endif()
endif()
endforeach()
endif()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_max); REGISTER_REDUCE_OP(reduce_max);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_max, REGISTER_OP_CUDA_KERNEL(reduce_max,
ops::ReduceKernel<paddle::platform::CUDADeviceContext, ops::ReduceKernel<paddle::platform::CUDADeviceContext,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_mean_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
REGISTER_REDUCE_OP(reduce_mean); REGISTER_REDUCE_OP(reduce_mean);
REGISTER_OP_CPU_KERNEL(reduce_mean, REGISTER_OP_CPU_KERNEL(reduce_mean,
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#include <vector> #include <vector>
#include "paddle/fluid/operators/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_mean_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "paddle/fluid/operators/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
// .part used to speed up nvcc compile // .part used to speed up nvcc compile
#include "paddle/fluid/operators/reduce_mean_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
reduce_mean_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, reduce_mean_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/operators/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_min); REGISTER_REDUCE_OP(reduce_min);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_min, REGISTER_OP_CUDA_KERNEL(reduce_min,
ops::ReduceKernel<paddle::platform::CUDADeviceContext, ops::ReduceKernel<paddle::platform::CUDADeviceContext,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
reduce_min_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, reduce_min_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/reduce_op_function.h" #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_prod_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
REGISTER_REDUCE_OP(reduce_prod); REGISTER_REDUCE_OP(reduce_prod);
REGISTER_OP_CPU_KERNEL(reduce_prod, REGISTER_OP_CPU_KERNEL(reduce_prod,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_prod_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_prod, REGISTER_OP_CUDA_KERNEL(reduce_prod,
ops::ReduceKernel<paddle::platform::CUDADeviceContext, ops::ReduceKernel<paddle::platform::CUDADeviceContext,
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "paddle/fluid/operators/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_prod_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
reduce_prod_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, reduce_prod_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_sum_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
REGISTER_REDUCE_OP(reduce_sum); REGISTER_REDUCE_OP(reduce_sum);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_sum_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/operators/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_sum_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, reduce_sum_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
......
include(operators)
register_operators()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_concat_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h"
#include <vector> #include <vector>
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_concat_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h"
template <typename T> template <typename T>
using Kernel = using Kernel =
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_conv_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h"
#include <algorithm> #include <algorithm>
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_conv_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_enumerate_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/fluid/operators/sequence_enumerate_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_erase_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_erase_op.h"
#include <vector> #include <vector>
namespace paddle { namespace paddle {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/fluid/operators/sequence_erase_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_erase_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_expand_as_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/sequence_expand_as_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_expand_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_mask_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h"
REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp, REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp,
paddle::operators::SequenceMaskOpMaker, paddle::operators::SequenceMaskOpMaker,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_mask_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
sequence_mask, sequence_mask,
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_pad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_pad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_pool_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
#include <string> #include <string>
namespace paddle { namespace paddle {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/sequence_pool_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_reshape_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_reshape_op.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_reshape_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_reshape_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_reverse_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_reverse_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_scatter_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_scatter_op.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_slice_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_slice_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_softmax_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
#include <string> #include <string>
namespace paddle { namespace paddle {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <cub/cub.cuh> // NOLINT #include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/operators/sequence_softmax_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_unpad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_unpad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -86,7 +86,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -86,7 +86,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
.GreaterThan(1); .GreaterThan(1);
AddComment(R"DOC( AddComment(R"DOC(
reorg operator used in Yolo v2. reorg operator used in Yolo v2.
The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize, The equation is: C2 = C1/blocksize * blocksize, W2 = W1 * blocksize + offset % blocksize, H2 = H1 * blocksize + offset / blocksize,
Reshape Input(X) into the shape according to Attr(blocksize). The Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged. data in Input(X) are unchanged.
......
op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n")
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op
analysis)
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/warpctc_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 7001
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedCTCLossDescriptor = platform::ScopedCTCLossDescriptor;
using DataLayout = platform::DataLayout;
template <typename DeviceContext, typename T>
class CudnnCTCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// =====================Copied code from warpctc===========================
auto* logits = ctx.Input<LoDTensor>("Logits");
auto* label = ctx.Input<LoDTensor>("Label");
auto* warpctc_grad = ctx.Output<LoDTensor>("WarpCTCGrad");
auto* loss = ctx.Output<LoDTensor>("Loss");
const size_t level = 0;
auto logits_lod = framework::ToAbsOffset(logits->lod());
auto logits_dims = logits->dims();
PADDLE_ENFORCE_EQ(logits_dims[0],
static_cast<int64_t>(logits_lod[level].back()),
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths.");
auto label_lod = framework::ToAbsOffset(label->lod());
auto label_dims = label->dims();
PADDLE_ENFORCE_EQ(
label_dims[0], label->numel(),
"The width of each timestep in Input(Label) should be 1.");
const size_t num_sequences = logits_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label).");
PADDLE_ENFORCE_LE(num_sequences, 256,
"The labelLengths must less than 256 for cudnn call.");
const size_t sequence_width = logits->numel() / logits_dims[0];
auto loss_dims =
framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
// NOTE: cudnn takes softmax input, calculate softmax first, then do padding
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
LoDTensor softmax_logits;
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
softmax_logits.set_lod(logits_lod);
int rank = logits->dims().size();
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);
// ctc needs sequences data stored in transposed padding format
// logits and grad using padding data of layout 'TNC'
// T: max_sequence_length
// N: batch_size (num_sequences)
// C: width
LoDTensor warpctc_logits;
const size_t max_sequence_length =
math::MaximumSequenceLength(logits_lod[level]);
auto warpctc_logits_dims =
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
LoDTensor cpu_pad_value;
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
LoDTensor pad_value;
if (platform::is_cpu_place(ctx.GetPlace())) {
pad_value = cpu_pad_value;
} else {
TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
}
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), softmax_logits,
&warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
math::kLengthBatchWidth);
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
std::vector<int> warpctc_logits_lengths(num_sequences);
for (size_t i = 0; i < num_sequences; ++i) {
warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
warpctc_logits_lengths[i] =
logits_lod[level][i + 1] - logits_lod[level][i];
}
T* warpctc_grad_data =
warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), warpctc_grad,
static_cast<T>(0));
Tensor warpctc_label;
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>();
// ========================================================================
ScopedTensorDescriptor logits_desc;
ScopedTensorDescriptor grad_desc;
ScopedCTCLossDescriptor ctcloss_desc;
// layout here doesn't have effect.
DataLayout layout = DataLayout::kNCHW;
auto cu_logits_desc = logits_desc.descriptor<T>(
layout, framework::vectorize2int(warpctc_logits.dims()));
auto cu_grad_desc = grad_desc.descriptor<T>(
layout, framework::vectorize2int(warpctc_grad->dims()));
auto cu_ctcloss_desc = ctcloss_desc.descriptor<T>();
auto handle = dev_ctx.cudnn_handle();
size_t workspace_size;
CUDNN_ENFORCE(platform::dynload::cudnnGetCTCLossWorkspaceSize(
handle, cu_logits_desc, cu_grad_desc, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));
T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss(
handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
loss_data, cu_grad_desc, warpctc_grad_data,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace,
workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
}
};
template <typename DeviceContext, typename T>
class CudnnCTCGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss_grad_data,
logits_grad);
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#if CUDNN_VERSION >= 7001
REGISTER_OP_KERNEL(
warpctc, CUDNN, plat::CUDAPlace,
ops::CudnnCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_KERNEL(
warpctc_grad, CUDNN, plat::CUDAPlace,
ops::CudnnCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
#endif
...@@ -14,6 +14,10 @@ limitations under the License. */ ...@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/warpctc_op.h" #include "paddle/fluid/operators/warpctc_op.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -45,9 +49,16 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -45,9 +49,16 @@ class WarpCTCOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()), framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
ctx.device_context()); ctx.device_context(), layout_, library_);
} }
}; };
...@@ -86,6 +97,10 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -86,6 +97,10 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
"normalize the gradients by the number of time-step, " "normalize the gradients by the number of time-step, "
"which is also the sequence's length.") "which is also the sequence's length.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_cudnn",
"(bool, default: false), whether to "
"use cudnn kernel.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
An operator integrating the open-source An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in [warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
......
...@@ -380,5 +380,28 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { ...@@ -380,5 +380,28 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
return use_cudnn; return use_cudnn;
} }
#if CUDNN_VERSION >= 7001
class ScopedCTCLossDescriptor {
public:
ScopedCTCLossDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateCTCLossDescriptor(&desc_));
}
~ScopedCTCLossDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyCTCLossDescriptor(desc_));
}
template <typename T>
inline cudnnCTCLossDescriptor_t descriptor() {
PADDLE_ENFORCE(
dynload::cudnnSetCTCLossDescriptor(desc_, CudnnDataType<T>::type));
return desc_;
}
private:
cudnnCTCLossDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor);
};
#endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -154,7 +154,13 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) ...@@ -154,7 +154,13 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#if CUDNN_VERSION >= 7001 #if CUDNN_VERSION >= 7001
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionGroupCount); \ __macro(cudnnSetConvolutionGroupCount); \
__macro(cudnnSetConvolutionMathType); __macro(cudnnSetConvolutionMathType); \
__macro(cudnnCreateCTCLossDescriptor); \
__macro(cudnnDestroyCTCLossDescriptor); \
__macro(cudnnGetCTCLossDescriptor); \
__macro(cudnnSetCTCLossDescriptor); \
__macro(cudnnGetCTCLossWorkspaceSize); \
__macro(cudnnCTCLoss);
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif #endif
......
...@@ -11,12 +11,12 @@ if(WITH_PYTHON) ...@@ -11,12 +11,12 @@ if(WITH_PYTHON)
hip_library(paddle_pybind SHARED hip_library(paddle_pybind SHARED
SRCS ${PYBIND_SRCS} SRCS ${PYBIND_SRCS}
DEPS ${PYBIND_DEPS} DEPS ${PYBIND_DEPS}
${GLOB_OP_LIB}) ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
else() else()
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
SRCS ${PYBIND_SRCS} SRCS ${PYBIND_SRCS}
DEPS ${PYBIND_DEPS} DEPS ${PYBIND_DEPS}
${GLOB_OP_LIB}) ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
if(NOT APPLE AND NOT ANDROID AND NOT WIN32) if(NOT APPLE AND NOT ANDROID AND NOT WIN32)
target_link_libraries(paddle_pybind rt) target_link_libraries(paddle_pybind rt)
endif(NOT APPLE AND NOT ANDROID AND NOT WIN32) endif(NOT APPLE AND NOT ANDROID AND NOT WIN32)
......
...@@ -30,36 +30,32 @@ limitations under the License. */ ...@@ -30,36 +30,32 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
namespace py = pybind11; namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
void BindAsyncExecutor(py::module* m) { void BindAsyncExecutor(py::module* m) {
py::class_<framework::DataFeed>(*m, "DataFeed"); py::class_<pd::DataFeedDesc>(*m, "DataFeedDesc")
py::class_<framework::TextClassDataFeed, .def(pybind11::init<>())
framework::DataFeed>(*m, "TextDataFeed") .def("set_name", (set_name_func)&pd::DataFeedDesc::set_name)
.def(py::init()) .def("set_batch", &pd::DataFeedDesc::set_batch)
.def("set_filelist", .def("set_field_names",
[] (framework::TextClassDataFeed &self, const char *data_list_file) { [] (pd::DataFeedDesc& self, const std::vector<std::string> &fields) {
self.SetFileList(data_list_file); for (auto field : fields) {
}) self.add_field_names(field);
.def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize) }
.def("set_field_names", &framework::TextClassDataFeed::SetFieldNames) });
.def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch);
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor") py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<framework::ProgramDesc&, .def(py::init<pd::Scope&, const platform::Place&>())
std::vector<std::string>&, .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
framework::TextClassDataFeed&, .def("check_file", &framework::AsyncExecutor::CheckFiles);
unsigned int,
const platform::Place&>())
.def("init_root_scope", &framework::AsyncExecutor::InitRootScope)
.def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
.def("run", &framework::AsyncExecutor::Run);
} // end BindAsyncExecutor } // end BindAsyncExecutor
} // end namespace pybind } // end namespace pybind
} // end namespace paddle } // end namespace paddle
......
...@@ -19,30 +19,26 @@ import contextlib ...@@ -19,30 +19,26 @@ import contextlib
import six import six
from .framework import Program, default_main_program, Variable from .framework import Program, default_main_program, Variable
from . import core from . import core
from . import Executor from .executor import global_scope
__all__ = ['TextDataFeed', 'AsyncExecutor'] __all__ = ['MultiSlotDataFeed', 'AsyncExecutor']
g_scope = core.Scope() g_scope = core.Scope()
class TextDataFeed(): class DataFeedDesc(object):
def __init__(self): def __init__(self):
self.feed = core.TextDataFeed() self.desc = core.DataFeedDesc()
def set_filelist(self, filelist):
self.feed.set_filelist(filelist)
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.feed.set_batch_size(batch_size) self.desc.set_batch(batch_size)
def set_field_name(self, field_names):
def set_field_names(self, field_names): if isinstance(field_names, str):
if isinstance(field_names, Variable):
field_names = [field_names] field_names = [field_names]
self.desc.set_field_names(field_names)
self.feed.set_field_names(field_names) class MultiSlotDataFeed(DataFeedDesc):
def __init__(self):
def start_an_epoch(self): super(MultiSlotDataFeed, self).__init__()
self.feed.start_one_epoch() self.desc.set_name("MultiSlotDataFeed")
class AsyncExecutor(object): class AsyncExecutor(object):
""" """
...@@ -55,45 +51,19 @@ class AsyncExecutor(object): ...@@ -55,45 +51,19 @@ class AsyncExecutor(object):
They has the exactly same arguments, and expected the same results. They has the exactly same arguments, and expected the same results.
""" """
def __init__(self, def __init__(self, place=None):
program,
param_names,
data_feed,
thread_num,
place=None,
scope=None):
if program is None:
program = default_main_program()
program_desc = program.desc
if not isinstance(data_feed, TextDataFeed):
raise ValueError("data_feed for AsyncExecutor.run() type error")
if place is None: if place is None:
place = core.CPUPlace() place = core.CPUPlace()
if not isinstance(place, core.CPUPlace): if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device") raise ValueError("AsyncExecutor only supports CPU device")
if isinstance(param_names, Variable):
param_names = [param_names]
p = core.Place() p = core.Place()
p.set_place(place) p.set_place(place)
self.executor = core.AsyncExecutor(program_desc, param_names, data_feed.feed, thread_num, p)
def run_startup_program(self,
program=None,
scope=None):
if program is None:
program = default_startup_program()
program_desc = program._get_desc()
if scope is None: scope = global_scope()
scope = g_scope self.executor = core.AsyncExecutor(scope, p)
self.executor.run_startup_program(program_desc, scope) def run(self, program, data_feed, filelist, thread_num, fetch):
def run(self, inspect_vars, scope=None):
""" """
Run program by this Executor. Feed data by feed map, fetch result by fetch_list. Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according Python executor takes a program, add feed operators and fetch operators to this program according
...@@ -136,16 +106,27 @@ class AsyncExecutor(object): ...@@ -136,16 +106,27 @@ class AsyncExecutor(object):
>>> feed={'X': x}, >>> feed={'X': x},
>>> fetch_list=[loss.name]) >>> fetch_list=[loss.name])
""" """
if inspect_vars is not None: if program is None:
if isinstance(inspect_vars, Variable): program = default_main_program()
inspect_vars = [inspect_vars] program_desc = program.desc
inspect_var_names = [var.name for var in inspect_vars]
if data_feed is None:
raise ValueError('ValueError: data_feed should be provided')
if filelist is None:
raise ValueError('ValueError: filelist should be provided')
if isinstance(filelist, str):
filelist = [filelist]
if scope is None: if not isinstance(thread_num, int):
scope = g_scope raise TypeError('TypeError: thread_num should be a positive number')
self.executor.init_root_scope(scope) if fetch is not None:
if isinstance(fetch, Variable):
fetch = [fetch]
fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run(inspect_var_names) evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names)
return evaluation return evaluation
...@@ -4187,7 +4187,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4187,7 +4187,7 @@ def ctc_greedy_decoder(input, blank, name=None):
return ctc_out return ctc_out
def warpctc(input, label, blank=0, norm_by_times=False): def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False):
""" """
An operator integrating the open source Warp-CTC library An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc) (https://github.com/baidu-research/warp-ctc)
...@@ -4212,6 +4212,7 @@ def warpctc(input, label, blank=0, norm_by_times=False): ...@@ -4212,6 +4212,7 @@ def warpctc(input, label, blank=0, norm_by_times=False):
by the number of time-step, which is also the sequence's length. by the number of time-step, which is also the sequence's length.
There is no need to normalize the gradients if warpctc layer was There is no need to normalize the gradients if warpctc layer was
follewed by a mean_op. follewed by a mean_op.
use_cudnn (bool, default false): Whether to use cudnn.
Returns: Returns:
Variable: The Connectionist Temporal Classification (CTC) loss, Variable: The Connectionist Temporal Classification (CTC) loss,
...@@ -4235,8 +4236,11 @@ def warpctc(input, label, blank=0, norm_by_times=False): ...@@ -4235,8 +4236,11 @@ def warpctc(input, label, blank=0, norm_by_times=False):
'Label': [label]}, 'Label': [label]},
outputs={'WarpCTCGrad': [grad_out], outputs={'WarpCTCGrad': [grad_out],
'Loss': [loss_out]}, 'Loss': [loss_out]},
attrs={'blank': blank, attrs={
'norm_by_times': norm_by_times}) 'blank': blank,
'norm_by_times': norm_by_times,
'use_cudnn': use_cudnn
})
return loss_out return loss_out
...@@ -4309,7 +4313,10 @@ def nce(input, ...@@ -4309,7 +4313,10 @@ def nce(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
num_neg_samples=None, num_neg_samples=None,
name=None): name=None,
sampler="uniform",
custom_dist=None,
seed=0):
""" """
${comment} ${comment}
...@@ -4332,6 +4339,14 @@ def nce(input, ...@@ -4332,6 +4339,14 @@ def nce(input,
num_neg_samples (int): ${num_neg_samples_comment} num_neg_samples (int): ${num_neg_samples_comment}
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None. will be named automatically. Default: None.
sampler (str): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes].
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled.
default: None.
seed (int): The seed used in sampler. default: 0.
Returns: Returns:
Variable: The output nce loss. Variable: The output nce loss.
...@@ -4361,6 +4376,16 @@ def nce(input, ...@@ -4361,6 +4376,16 @@ def nce(input,
loss = layers.nce(input=embs, label=words[label_word], loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w', num_total_classes=dict_size, param_attr='nce.w',
bias_attr='nce.b') bias_attr='nce.b')
#or use custom distribution
dist = fluid.layers.assign(input=np.array([0.05,0.5,0.1,0.3,0.05]).astype("float32"))
loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=5, param_attr='nce.w',
bias_attr='nce.b',
num_neg_samples=3,
sampler="custom_dist",
custom_dist=dist)
""" """
helper = LayerHelper('nce', **locals()) helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable) assert isinstance(input, Variable)
...@@ -4395,9 +4420,31 @@ def nce(input, ...@@ -4395,9 +4420,31 @@ def nce(input,
else: else:
num_neg_samples = int(num_neg_samples) num_neg_samples = int(num_neg_samples)
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if sampler == "uniform":
sampler = 0
elif sampler == "log_uniform":
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
assert isinstance(custom_dist, Variable)
inputs['CustomDistribution'] = custom_dist
sampler = 2
else:
raise Exception("Unsupported sampler type.")
attrs = { attrs = {
'num_total_classes': int(num_total_classes), 'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples 'num_neg_samples': num_neg_samples,
'seed': seed,
'sampler': sampler
} }
helper.append_op( helper.append_op(
......
...@@ -38,7 +38,7 @@ depth = 8 ...@@ -38,7 +38,7 @@ depth = 8
mix_hidden_lr = 1e-3 mix_hidden_lr = 1e-3
IS_SPARSE = True IS_SPARSE = True
PASS_NUM = 1 PASS_NUM = 2
BATCH_SIZE = 10 BATCH_SIZE = 10
embedding_name = 'emb' embedding_name = 'emb'
...@@ -196,7 +196,7 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -196,7 +196,7 @@ def train(use_cuda, save_dirname=None, is_local=True):
print("second per batch: " + str((time.time( print("second per batch: " + str((time.time(
) - start_time) / batch_id)) ) - start_time) / batch_id))
# Set the threshold low to speed up the CI test # Set the threshold low to speed up the CI test
if float(cost) < 60.0: if float(cost) < 80.0:
if save_dirname is not None: if save_dirname is not None:
# TODO(liuyiqun): Change the target to crf_decode # TODO(liuyiqun): Change the target to crf_decode
fluid.io.save_inference_model(save_dirname, [ fluid.io.save_inference_model(save_dirname, [
...@@ -208,6 +208,10 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -208,6 +208,10 @@ def train(use_cuda, save_dirname=None, is_local=True):
batch_id = batch_id + 1 batch_id = batch_id + 1
raise RuntimeError(
"This model should save_inference_model and return, but not reach here, please check!"
)
if is_local: if is_local:
train_loop(fluid.default_main_program()) train_loop(fluid.default_main_program())
else: else:
......
...@@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase): ...@@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.infer_shape(block) mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
def test_expand_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
shape = [-1, 20]
expand_times = [3, 1]
# prepare input/output
x1 = block.var(six.b("x"))
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape)
out = block.var(six.b("out"))
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
sum_op_desc = block.append_op()
sum_op_desc.set_type("expand")
sum_op_desc.set_input("X", ["x"])
sum_op_desc.set_output("Out", ["out"])
sum_op_desc._set_attr('expand_times', expand_times)
sum_op_desc.check_attrs()
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -68,7 +68,9 @@ class TestNCE(OpTest): ...@@ -68,7 +68,9 @@ class TestNCE(OpTest):
self.attrs = { self.attrs = {
'num_total_classes': num_classes, 'num_total_classes': num_classes,
'num_neg_samples': num_neg_samples, 'num_neg_samples': num_neg_samples,
'custom_neg_classes': list(range(num_neg_samples)) 'custom_neg_classes': list(range(num_neg_samples)),
'seed': 0,
'sampler': 0
} }
self.inputs = { self.inputs = {
'Input': input, 'Input': input,
......
...@@ -183,6 +183,7 @@ class TestWarpCTCOp(OpTest): ...@@ -183,6 +183,7 @@ class TestWarpCTCOp(OpTest):
self.labels_lod = [[3, 1, 4, 4]] self.labels_lod = [[3, 1, 4, 4]]
self.blank = self.num_classes - 1 self.blank = self.num_classes - 1
self.norm_by_times = False self.norm_by_times = False
self.use_cudnn = False
def setUp(self): def setUp(self):
self.op_type = "warpctc" self.op_type = "warpctc"
...@@ -215,7 +216,11 @@ class TestWarpCTCOp(OpTest): ...@@ -215,7 +216,11 @@ class TestWarpCTCOp(OpTest):
"Label": (labels, self.labels_lod) "Label": (labels, self.labels_lod)
} }
self.outputs = {"Loss": loss} self.outputs = {"Loss": loss}
self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times} self.attrs = {
"blank": self.blank,
"norm_by_times": self.norm_by_times,
"use_cudnn": self.use_cudnn
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -233,6 +238,22 @@ class TestWarpCTCOpCase1(TestWarpCTCOp): ...@@ -233,6 +238,22 @@ class TestWarpCTCOpCase1(TestWarpCTCOp):
self.labels_lod = [[3, 1, 4, 4]] self.labels_lod = [[3, 1, 4, 4]]
self.blank = 0 self.blank = 0
self.norm_by_times = False self.norm_by_times = False
self.use_cudnn = False
class TestCudnnCTCOp(TestWarpCTCOp):
def config(self):
self.batch_size = 4
self.num_classes = 8
self.logits_lod = [[4, 1, 3, 3]]
self.labels_lod = [[3, 1, 4, 4]]
self.blank = 0
self.norm_by_times = False
self.use_cudnn = True
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -34,6 +34,7 @@ def wait_server_ready(endpoints): ...@@ -34,6 +34,7 @@ def wait_server_ready(endpoints):
""" """
while True: while True:
all_ok = True all_ok = True
not_ready_endpoints = []
for ep in endpoints: for ep in endpoints:
ip_port = ep.split(":") ip_port = ep.split(":")
with closing(socket.socket(socket.AF_INET, with closing(socket.socket(socket.AF_INET,
...@@ -42,8 +43,11 @@ def wait_server_ready(endpoints): ...@@ -42,8 +43,11 @@ def wait_server_ready(endpoints):
result = sock.connect_ex((ip_port[0], int(ip_port[1]))) result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0: if result != 0:
all_ok = False all_ok = False
not_ready_endpoints.append(ep)
if not all_ok: if not all_ok:
sys.stderr.write("pserver not ready, wait 3 sec to retry...\n") sys.stderr.write("pserver not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) +
"\n")
sys.stderr.flush() sys.stderr.flush()
time.sleep(3) time.sleep(3)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册