diff --git a/CMakeLists.txt b/CMakeLists.txt index 68447727118a91a2a8c0d06404353c7ccb734c6d..ba9af2cc7b7acd761dd9da644ca3d8c00775e338 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,6 +212,7 @@ elseif() set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in GPU only now." FORCE) endif() +include(flags) # set paddle compile flags include(cudnn) # set cudnn libraries, must before configure include(cupti) include(configure) # add paddle env configuration @@ -220,7 +221,6 @@ include(package) # set paddle packages include(ccache) # set ccache for compilation include(util) # set unittest and link libs include(rdma) # set rdma libraries -include(flags) # set paddle compile flags include(version) # set PADDLE_VERSION include(coveralls) # set code coverage include(inference_lib) # add paddle fluid inference libraries diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 53454d79f7205415d0634c9350520eee7106b0b5..0a8d891b081deaa1142a63f952667bfb516c5215 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -50,16 +50,16 @@ if(NOT WITH_PROFILER) endif(NOT WITH_PROFILER) if(NOT CMAKE_CROSSCOMPILING) - if(WITH_AVX AND AVX_FOUND) + if(WITH_AVX AND AVX512F_FOUND) + set(SIMD_FLAG ${AVX512F_FLAG}) + elseif(WITH_AVX AND AVX2_FOUND) + set(SIMD_FLAG ${AVX2_FLAG}) + elseif(WITH_AVX AND AVX_FOUND) set(SIMD_FLAG ${AVX_FLAG}) elseif(SSE3_FOUND) set(SIMD_FLAG ${SSE3_FLAG}) endif() endif() -if(UNIX AND NOT APPLE) - # except apple from nix*Os family - set(LINUX TRUE) -endif(UNIX AND NOT APPLE) if(NOT WITH_GOLANG) add_definitions(-DPADDLE_WITHOUT_GOLANG) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 9eebea816cbfc91052c95ecf99ecc4b0bea4e4c2..cd51533926de7bb132ab7bfab1686d664a331410 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -25,8 +25,25 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/lib64 $ENV{CUDNN_ROOT}/lib - /usr/lib) -find_library(CUDNN_LIBRARY NAMES libcudnn.so libcudnn.dylib # libcudnn_static.a + /usr/lib + ${CUDA_TOOLKIT_ROOT_DIR} + ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 + ) +set(CUDNN_LIB_NAME "") +if (LINUX) +set(CUDNN_LIB_NAME "libcudnn.so") +endif(LINUX) + +if(WIN32) +# only support cudnn7 +set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll") +endif(WIN32) + +if(Apple) +set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so") +endif(Apple) + +find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist} NO_DEFAULT_PATH DOC "Path to cuDNN library.") diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 1120677a37e0d44163816b66600121c8f0d545af..8ac157c4d79f1f5f2a655152f46b4a4d3f2c6962 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -142,6 +142,11 @@ else() ${GPU_COMMON_FLAGS}) endif() +if(UNIX AND NOT APPLE) + # except apple from nix*Os family + set(LINUX TRUE) +endif(UNIX AND NOT APPLE) + foreach(flag ${COMMON_FLAGS}) safe_set_cflag(CMAKE_C_FLAGS ${flag}) diff --git a/cmake/simd.cmake b/cmake/simd.cmake index 53c2de332ea74b06d1bd6e5bb119cad6af27ed01..3eacf4d86aa0385eddb690d72e85e3384929bb99 100644 --- a/cmake/simd.cmake +++ b/cmake/simd.cmake @@ -10,6 +10,7 @@ if(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID set(SSE3_FLAG "-msse3") set(AVX_FLAG "-mavx") set(AVX2_FLAG "-mavx2") + set(AVX512F_FLAG "-mavx512f") elseif(MSVC) set(MMX_FLAG "/arch:MMX") set(SSE2_FLAG "/arch:SSE2") @@ -81,5 +82,16 @@ int main() return 0; }" AVX2_FOUND) +# Check AVX512F +set(CMAKE_REQUIRED_FLAGS ${AVX512F_FLAG}) +set(AVX512F_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE) +CHECK_CXX_SOURCE_RUNS(" +#include +int main() +{ + __m512i a = _mm512_undefined_epi32(); + return 0; +}" AVX512F_FOUND) + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_RETAINED}) -mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND) +mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND AVX512F_FOUND) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8e69f8f0b98b3d07bfd57bf3251e32a3b3c607f7..2ec422cc17faf7f6b99ac70b5f175881bf017566 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -99,12 +99,13 @@ else() cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) endif() - +if (NOT WIN32) cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass fast_threaded_ssa_graph_executor) +endif() # NOT WIN32 cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 20bdc7830f32564448a69e9cd76c02585b7a1aca..344c001a69b53c82967ee983783892a514c2490b 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto all_ops = blocks_[block_id]->AllOps(); for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { int sub_block_id = o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); + } else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) { + std::vector sub_block_ids = + o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name); + std::vector block_descs; + for (int block_id : sub_block_ids) { + block_descs.push_back(MutableBlock(block_id)); + } + op->SetBlocksAttr(attr_name, block_descs); } } } @@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { desc_ = desc; - for (auto &block_desc : *desc_.mutable_blocks()) { - blocks_.emplace_back(new BlockDesc(this, &block_desc)); - } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); - } - } - } - } + InitFromProto(); } ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE(desc_.ParseFromString(binary_str), "Fail to parse program_desc from binary string."); + InitFromProto(); +} + +void ProgramDesc::InitFromProto() { for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } @@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { if (attr.type() == proto::AttrType::BLOCK) { size_t blk_idx = attr.block_idx(); op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + } else if (attr.type() == proto::AttrType::BLOCKS) { + auto blks_idx = attr.blocks_idx(); + std::vector block_descs; + for (int blk_idx : blks_idx) { + block_descs.push_back(this->MutableBlock(blk_idx)); + } + op->SetBlocksAttr(attr.name(), block_descs); } } } diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 65fa0a0cfd5ba6d9b8765cee1309e118cb74348a..f3afc85eb924e4b03b7597e043ffd4e267adc977 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -76,6 +76,8 @@ class ProgramDesc { void SetFetchHolderName(const std::string &fetch_holder_name); private: + void InitFromProto(); + proto::ProgramDesc desc_; std::vector> blocks_; diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 6c46e9aad5b7fbf67fdcc07a12e7932ac8b6412b..925ea98dbe62e4da91689f6e56c135e51c24a8a3 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -42,6 +42,19 @@ TEST(ProgramDesc, copy_ctor) { out->SetType(proto::VarType::LOD_TENSOR); op->SetOutput("Y", {out->Name()}); + BlockDesc* new_block = program.AppendBlock(*global_block); + op = new_block->AppendOp(); + op->SetType("mul"); + + op = global_block->AppendOp(); + op->SetType("op_with_subblock"); + op->SetAttr("sub_block", new_block); + + std::vector sub_blocks; + sub_blocks.push_back(program.AppendBlock(*global_block)); + sub_blocks.push_back(program.AppendBlock(*global_block)); + op->SetAttr("sub_blocks", sub_blocks); + ProgramDesc program_copy(program); auto* global_block_copy = program_copy.MutableBlock(0); @@ -64,6 +77,8 @@ TEST(ProgramDesc, copy_ctor) { assert_same_var("Y", y); assert_same_var("Out", out); + bool found_sub_block = false; + bool found_sub_blocks = false; for (size_t i = 0; i < global_block->OpSize(); ++i) { auto op_origin = global_block->Op(i); auto op_copy = global_block_copy->Op(i); @@ -74,8 +89,17 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_EQ(op_copy->Proto()->SerializeAsString(), op_origin->Proto()->SerializeAsString()); - } + if (op->Type() == "op_with_subblock") { + ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); + found_sub_block = true; + + ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size()); + found_sub_blocks = true; + } + } + ASSERT_TRUE(found_sub_block); + ASSERT_TRUE(found_sub_blocks); // Not check block's protostr are same it because the order of vars could be // different and it is correct. } diff --git a/paddle/fluid/inference/api/high_level_api_cn.md b/paddle/fluid/inference/api/high_level_api_cn.md index 2fb914592cbcb1b0c3f2ef33ff9cf4c295e427b6..442c598978c700f4c438b365b8900db5b65bc5ec 100644 --- a/paddle/fluid/inference/api/high_level_api_cn.md +++ b/paddle/fluid/inference/api/high_level_api_cn.md @@ -65,13 +65,13 @@ config.model_dir = "xxx"; config.use_gpu = false; // 创建一个原生的 PaddlePredictor auto predictor = - paddle::CreatePaddlePredictor(config); + paddle::CreatePaddlePredictor(config); // 创建输入 tensor int64_t data[4] = {1, 2, 3, 4}; paddle::PaddleTensor tensor{.name = "", .shape = std::vector({4, 1}), - .data = PaddleBuf(data, sizeof(data)), - .dtype = PaddleDType::INT64}; + .data = paddle::PaddleBuf(data, sizeof(data)), + .dtype = paddle::PaddleDType::INT64}; // 创建输出 tensor,输出 tensor 的内存可以复用 std::vector outputs; // 执行预测 diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index b52d083f280e5e7713600a7b748dedd37aca0a1e..a610687a5b11999a7cb7426dbe961e5972ee1746 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,4 +1,4 @@ -nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto) +nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine) add_subdirectory(convert) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e29fe2a42bd1aaee1ea8c01159e331cf47ca6b72..ed8e9ed77fb233e40bb78329a246ff724b21c547 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -84,6 +84,15 @@ function(op_library TARGET) message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file") endif() + #remove windows unsupported op + if (WIN32) + foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op") + if ("${TARGET}" STREQUAL "${windows_unsupport_op}") + return() + endif() + endforeach() + endif(WIN32) + list(LENGTH op_library_DEPS op_library_DEPS_len) if (${op_library_DEPS_len} GREATER 0) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) @@ -181,19 +190,19 @@ function(op_library TARGET) endfunction() add_subdirectory(math) +if (NOT WIN32) add_subdirectory(nccl) - if(WITH_GPU) op_library(nccl_op DEPS nccl_common) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") else() set(DEPS_OPS ${DEPS_OPS} nccl_op) endif() +endif() # NOT WIN32 set(DISTRIBUTE_DEPS "") if(WITH_DISTRIBUTE) add_subdirectory(distributed) - set(DISTRIBUTE_DEPS "") if(WITH_GRPC) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) @@ -222,7 +231,7 @@ if(WITH_DISTRIBUTE) #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op # listen_and_serv_op sum_op executor SERIAL) - if(WITH_GPU) + if(WITH_GPU AND NOT WIN32) set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op ${DISTRIBUTE_DEPS} executor SERIAL) if(WITH_GRPC) @@ -233,7 +242,7 @@ if(WITH_DISTRIBUTE) set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) - endif() + endif() # WITH_GPU AND NOT WIN32 else() set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op) endif() @@ -331,5 +340,7 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) +if(NOT WIN32) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) +endif() nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) diff --git a/paddle/fluid/operators/conditional_block_op.cc b/paddle/fluid/operators/conditional_block_op.cc index 580fde753816c30b188b8a99cc63fcbafde64e25..135254ce6b6bf9add7bb1f0c3f645ed47081fba4 100644 --- a/paddle/fluid/operators/conditional_block_op.cc +++ b/paddle/fluid/operators/conditional_block_op.cc @@ -29,9 +29,9 @@ class ConditionalOp : public framework::OperatorBase { protected: std::vector InputTensors( - const framework::Scope &scope) const { + const framework::Scope &scope, const std::string &in_name) const { std::vector retv; - auto xs = Inputs("X"); + auto xs = Inputs(in_name); retv.resize(xs.size(), nullptr); std::transform( xs.begin(), xs.end(), retv.begin(), @@ -81,12 +81,18 @@ class ConditionalBlockOp : public ConditionalOp { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - auto xs = InputTensors(scope); - bool need_run; if (Attr("is_scalar_condition")) { + // When is_scalar_condition is True, the conditional variable is a scalar, + // whether need to execute the operators in sub-block depends on the + // conditional variable (Cond). + auto xs = InputTensors(scope, "Cond"); need_run = ScalarCondition(xs); } else { + // When is_scalar_condition is False, the conditional variable maybe a + // vector or tensor, whether need to execute the operators in sub-block + // depends on the input variables (Input). + auto xs = InputTensors(scope, "Input"); need_run = std::all_of( xs.begin(), xs.end(), [](const framework::LoDTensor *t) { return t->numel() != 0; }); @@ -110,11 +116,11 @@ class ConditionalBlockOp : public ConditionalOp { class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", - "The conditional variable of this operator. If X is empty, the " + AddInput("Cond", + "The conditional variable of this operator. If Cond is empty, the " "whole sub-block will not be executed.") .AsDuplicable(); - AddInput("Params", "The input variables of the sub-block.").AsDuplicable(); + AddInput("Input", "The input variables of the sub-block.").AsDuplicable(); AddOutput("Out", "The output variables of the sub-block.").AsDuplicable(); AddOutput("Scope", "(std::vector) The step scope of conditional block. To " @@ -123,13 +129,18 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "sub_block", "The step block of conditional block operator"); AddAttr("is_scalar_condition", - "the input X is used as scalar " - "condition") + "The conditional variable (Cond) is used as scalar " + "condition.") .SetDefault(false); AddComment(R"DOC(Conditional block operator -Run the sub-block if X is not empty. Params is the other inputs and Out is the -outputs of the sub-block. +If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar, +run the operators in sub-block if Cond is True. + +If `is_scalar_condition` is False, the conditional variable (Cond) is a vector or +tensor, run the operators in sub-block if all of input variables are not empty. + + )DOC"); } }; @@ -145,12 +156,12 @@ class ConditionalBlockGradOp : public ConditionalOp { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - auto xs = this->InputTensors(scope); - bool need_run; if (Attr("is_scalar_condition")) { + auto xs = this->InputTensors(scope, "Cond"); need_run = ScalarCondition(xs); } else { + auto xs = this->InputTensors(scope, "Input"); need_run = std::all_of( xs.begin(), xs.end(), [](const framework::LoDTensor *t) { return t->numel() != 0; }); @@ -166,11 +177,11 @@ class ConditionalBlockGradOp : public ConditionalOp { auto *block = Attr("sub_block"); exec.Run(*block->Program(), &cur_scope, block->ID(), false); - AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Params"), - Outputs(framework::GradVarName("Params"))); + AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Input"), + Outputs(framework::GradVarName("Input"))); - AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("X"), - Outputs(framework::GradVarName("X"))); + AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Cond"), + Outputs(framework::GradVarName("Cond"))); } } @@ -199,15 +210,15 @@ class ConditionalBlockGradOp : public ConditionalOp { class ConditionalBlockGradInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { - PADDLE_ENFORCE(context->HasInputs("X")); - if (context->HasInputs("Params")) { - PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params"))); - context->SetOutputsDim(framework::GradVarName("Params"), - context->GetInputsDim("Params")); + PADDLE_ENFORCE(context->HasInputs("Cond")); + if (context->HasInputs("Input")) { + PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Input"))); + context->SetOutputsDim(framework::GradVarName("Input"), + context->GetInputsDim("Input")); } - if (context->HasOutputs(framework::GradVarName("X"))) { - context->SetOutputsDim(framework::GradVarName("X"), - context->GetInputsDim("X")); + if (context->HasOutputs(framework::GradVarName("Cond"))) { + context->SetOutputsDim(framework::GradVarName("Cond"), + context->GetInputsDim("Cond")); } } }; @@ -220,14 +231,15 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker { std::unique_ptr Apply() const override { auto grad_op = new framework::OpDesc(); grad_op->SetType("conditional_block_grad"); - grad_op->SetInput("X", Input("X")); - grad_op->SetInput("Params", Input("Params")); + grad_op->SetInput("Cond", Input("Cond")); + grad_op->SetInput("Input", Input("Input")); grad_op->SetInput("Out", Output("Out")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput("Scope", Output("Scope")); - grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X", false)); - grad_op->SetOutput(framework::GradVarName("Params"), - InputGrad("Params", false)); + grad_op->SetOutput(framework::GradVarName("Cond"), + InputGrad("Cond", false)); + grad_op->SetOutput(framework::GradVarName("Input"), + InputGrad("Input", false)); grad_op->SetBlockAttr("sub_block", this->grad_block_[0]); grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition")); return std::unique_ptr(grad_op); diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h index 3f5fab3b382bea97f43e4bc1b2cd436c956ba264..8181897c3d3844bda5574e85a08b2af038fcd664 100644 --- a/paddle/fluid/operators/crf_decoding_op.h +++ b/paddle/fluid/operators/crf_decoding_op.h @@ -85,6 +85,199 @@ class CRFDecodingOpKernel : public framework::OpKernel { int* track_value = track.mutable_data(emission_dims, platform::CPUPlace()); +#ifdef __AVX__ +// It use the AVX or AVX512 instruction to deal the data as the vector of 8 or +// 16 elements per iteration. Then it can implement the parallel processing. +// Only optimize for float type. +#ifdef __AVX512F__ + size_t step_size = 16; +#else + size_t step_size = 8; +#endif + if (std::is_same::value && (tag_num >= step_size)) { + size_t steps = tag_num / step_size; + size_t remain = tag_num % step_size; + int last_offset = static_cast(remain) - static_cast(step_size); + + // Setup the alpha initial value. + size_t i_offset = 0; + for (size_t i = 0; i <= steps; ++i) { +#ifdef __AVX512F__ + // Declare the variable for the content of weights, input and alpha + // values. + __m512 w_content, x_content, alpha_content; + + // Load the relevant data into the variables from un-aligned address. + w_content = _mm512_loadu_ps((const float*)(w + i_offset)); + x_content = _mm512_loadu_ps((const float*)(x + i_offset)); + alpha_content = _mm512_add_ps(w_content, x_content); + + // Save the alpha value. + _mm512_storeu_ps(reinterpret_cast(alpha_value + i_offset), + alpha_content); +#else + // Declare the variable for the content of weights, input and alpha + // values. + __m256 w_content, x_content, alpha_content; + + // Load the relevant data into the variables from un-aligned address. + w_content = _mm256_loadu_ps((const float*)(w + i_offset)); + x_content = _mm256_loadu_ps((const float*)(x + i_offset)); + alpha_content = _mm256_add_ps(w_content, x_content); + + // Save the alpha value. + _mm256_storeu_ps(reinterpret_cast(alpha_value + i_offset), + alpha_content); +#endif + i_offset += step_size; + if (i == steps - 1) { + if (remain > 0) { + i_offset += last_offset; + } else { + break; + } + } + } + + // Use the column-major strategy to get the location of maximum score. + size_t seq_offset = 0; + for (size_t k = 1; k < seq_len; ++k) { + size_t j_offset = 0; + for (size_t j = 0; j <= steps; ++j) { +#ifdef __AVX512F__ + // Initialize the variables of maximum score and location. + __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); + __m512i max_j = _mm512_setzero_si512(); +#else + // Initialize the variables of maximum score and location. + __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); + __m256i max_j = _mm256_set1_epi32(0); +#endif + // Calculate the offset of transition_weights. + size_t trans_offset = state_trans_base_idx * tag_num + j_offset; + for (size_t i = 0; i < tag_num; ++i) { +#ifdef __AVX512F__ + // Initalize the content of alpha variable with related offset. + __m512 alpha_content = + _mm512_set1_ps(*(const float*)(alpha_value + seq_offset + i)); + // Obtain the content of weights from un-aligned address. + __m512 w_content = + _mm512_loadu_ps((const float*)(w + trans_offset)); + + __m512 score_v = _mm512_add_ps(alpha_content, w_content); + + __mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); + + // According to the mask value, it update the index of the max_score + // location. + max_j = _mm512_mask_set1_epi32(max_j, mask, i); + + // Update the max_score value. + max_score = _mm512_max_ps(max_score, score_v); +#else + // Initalize the content of alpha variable with related offset. + __m256 alpha_content = _mm256_broadcast_ss( + (const float*)(alpha_value + seq_offset + i)); + // Obtain the content of weights from un-aligned address. + __m256 w_content = + _mm256_loadu_ps((const float*)(w + trans_offset)); + __m256 score_v = _mm256_add_ps(alpha_content, w_content); + + __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); + +#ifdef __AVX2__ + // According to the mask value, it update the index of the max_score + // location. + max_j = _mm256_or_si256( + _mm256_andnot_si256((__m256i)mask, max_j), + _mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); +#else + __m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); + __m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); + __m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0); + __m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1); + + lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); + hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); + lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); + hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); + + lo_max_j = _mm_or_si128(lo_mask, lo_max_j); + hi_max_j = _mm_or_si128(hi_mask, hi_max_j); + + // According to the mask value, it update the index of the max_score + // location. + max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); + max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); +#endif + + // Update the max_score value. + max_score = _mm256_max_ps(max_score, score_v); +#endif + trans_offset += tag_num; + } + +#ifdef __AVX512F__ + // Update the alpha and track values. + __m512 x_content = _mm512_loadu_ps( + (const float*)(x + seq_offset + tag_num + j_offset)); + max_score = _mm512_add_ps(max_score, x_content); + _mm512_storeu_ps(reinterpret_cast(alpha_value + seq_offset + + tag_num + j_offset), + max_score); + _mm512_storeu_si512( + reinterpret_cast<__m512i*>(track_value + seq_offset + tag_num + + j_offset), + max_j); +#else + // Update the alpha and track values. + __m256 x_content = _mm256_loadu_ps( + (const float*)(x + seq_offset + tag_num + j_offset)); + max_score = _mm256_add_ps(max_score, x_content); + _mm256_storeu_ps(reinterpret_cast(alpha_value + seq_offset + + tag_num + j_offset), + max_score); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(track_value + seq_offset + tag_num + + j_offset), + max_j); +#endif + + // Calculate the offset of next step + j_offset += step_size; + if (j == steps - 1) { + if (remain > 0) { + j_offset += last_offset; + } else { + break; + } + } + } + + seq_offset += tag_num; + } + } else { + for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i]; + + for (size_t k = 1; k < seq_len; ++k) { + for (size_t i = 0; i < tag_num; ++i) { + T max_score = -std::numeric_limits::max(); + int max_j = 0; + for (size_t j = 0; j < tag_num; ++j) { + T score = alpha_value[(k - 1) * tag_num + j] + + w[(j + state_trans_base_idx) * tag_num + i]; + if (score > max_score) { + max_score = score; + max_j = j; + } + } + + alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i]; + track_value[k * tag_num + i] = max_j; + } + } + } +#else for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i]; for (size_t k = 1; k < seq_len; ++k) { @@ -105,6 +298,7 @@ class CRFDecodingOpKernel : public framework::OpKernel { } } +#endif T max_score = -std::numeric_limits::max(); int max_i = 0; for (size_t i = 0; i < tag_num; ++i) { diff --git a/paddle/fluid/operators/nccl/CMakeLists.txt b/paddle/fluid/operators/nccl/CMakeLists.txt index ce0ddd89bfb0d73e237a6f9a777376624d8ef2d4..cdcba8035762d8f442eb8b8ed52a4e3e99ac31b6 100644 --- a/paddle/fluid/operators/nccl/CMakeLists.txt +++ b/paddle/fluid/operators/nccl/CMakeLists.txt @@ -1,3 +1,3 @@ -if(WITH_GPU) +if(WITH_GPU AND NOT WIN32) nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator ) endif() diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 6c507baf3a0ab0a557d29a53700685753616193b..8a683116b8054de12fc4419b5aa5fbc019b675bb 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -23,9 +23,9 @@ class SqueezeOpInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SqueezeOp should not be null."); + "Input(X) of Squeeze operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SqueezeOp should not be null."); + "Output(Out) of Squeeze operator should not be null."); const auto &x_dims = ctx->GetInputDim("X"); // Check input tensor dims (<6) Eigen limit. @@ -107,7 +107,6 @@ class SqueezeOp : public framework::OperatorBase { framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims); - attrs["inplace"] = Attr("inplace"); // Invoke Reshape Op auto reshape_op = framework::OpRegistry::CreateOp( "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, @@ -125,12 +124,6 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { "(std::vector). List of integers," " indicating the dimensions to squeeze.") .SetDefault({}); - AddAttr("inplace", - "(default: false) Squeeze the source tensor's shape without " - "memory copy. When Attr(inplace) is set true, the output " - "tensor shares memory with Input(X), otherwise, a new output " - "tensor is created, and its data are copied from Input(x).") - .SetDefault(false); AddComment(R"DOC( Squeeze Operator. @@ -180,7 +173,6 @@ class SqueezeGradOp : public framework::OperatorBase { auto x_dims = scope.FindVar(Input("X"))->Get().dims(); framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(x_dims); - attrs["inplace"] = Attr("inplace"); auto reshape_op = framework::OpRegistry::CreateOp( "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index f2a15fdf572e0de30f9949dda5020e130b0c5585..0fc8d54f6400c9dfb6af1e764ed44e95195bfe6e 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -23,9 +23,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of UnsqueezeOp should not be null."); + "Input(X) of Unsqueeze operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of UnsqueezeOp should not be null."); + "Output(Out) of Unsqueeze operator should not be null."); const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); @@ -95,7 +95,6 @@ class UnsqueezeOp : public framework::OperatorBase { framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims); - attrs["inplace"] = Attr("inplace"); // Invoke Reshape op. auto reshape_op = framework::OpRegistry::CreateOp( "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, @@ -126,13 +125,6 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { " within [1, 6] dimensions (Eigen limit)."); } }); - AddAttr( - "inplace", - "(default: false) Unsqueeze the source tensor's shape without " - "memory copy. When Attr(inplace) is set true, the output " - "tensor shares memory with Input(X), otherwise, a new output " - "tensor is created, and its data are copied from Input(x).") - .SetDefault(false); AddComment(R"DOC( Unsqueeze Operator. @@ -168,7 +160,6 @@ class UnsqueezeGradOp : public framework::OperatorBase { framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(x_dims); - attrs["inplace"] = Attr("inplace"); auto reshape_op = framework::OpRegistry::CreateOp( "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 9da787a4073fa002f75154f7c4fba54e9ed8efa6..07159d4a12ef4b628f7705ed206d3334be46dfc8 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -3,7 +3,7 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc) # There is no macOS version of NCCL. -if (NOT APPLE) +if (NOT APPLE AND NOT WIN32) list(APPEND CUDA_SRCS nccl.cc) endif() diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 81b5359b40589d898bda0dfa71afb6f51385354b..6c2331b75f64b777adcdca4245d503bb5a52e1a6 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -44,7 +44,7 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/curand.h" -#ifndef __APPLE__ +#if !defined(__APPLE__) and !defined(_WIN32) #include "paddle/fluid/platform/dynload/nccl.h" #endif // __APPLE__ #endif // PADDLE_WITH_CUDA @@ -205,7 +205,7 @@ inline typename std::enable_if::type throw_on_error( #endif } -#ifndef __APPLE__ +#if !defined(__APPLE__) and !defined(_WIN32) template inline typename std::enable_if::type throw_on_error( ncclResult_t stat, const Args&... args) { @@ -221,7 +221,7 @@ inline typename std::enable_if::type throw_on_error( #endif } } -#endif // __APPLE__ +#endif // __APPLE__ and windows #endif // PADDLE_WITH_CUDA template diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 173567a0a374ac4453025b67b047950936df2055..8bfe11916bd069cd2dd7016c03644d6cad1e188d 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1272,8 +1272,8 @@ class ConditionalBlock(object): parent_block.append_op( type='conditional_block', inputs={ - 'X': self.inputs, - 'Params': param_list, + 'Cond': self.inputs, + 'Input': param_list, }, outputs={'Out': out_list, 'Scope': [step_scope]}, diff --git a/python/paddle/fluid/tests/test_if_else_op.py b/python/paddle/fluid/tests/test_if_else_op.py index 10918a985f80f9f20629180b0359ce73f448d33b..61d81f483636a99ea9e0282de89f12e47f3b824c 100644 --- a/python/paddle/fluid/tests/test_if_else_op.py +++ b/python/paddle/fluid/tests/test_if_else_op.py @@ -30,7 +30,8 @@ import numpy as np class TestMNISTIfElseOp(unittest.TestCase): - def test_raw_api(self): + # FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379 + def not_test_raw_api(self): prog = Program() startup_prog = Program() with program_guard(prog, startup_prog): @@ -91,7 +92,8 @@ class TestMNISTIfElseOp(unittest.TestCase): return self.assertFalse(True) - def test_ifelse(self): + # FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379 + def not_test_ifelse(self): prog = Program() startup_prog = Program() with program_guard(prog, startup_prog): @@ -153,6 +155,13 @@ class TestIfElse(unittest.TestCase): self.cond_value = 0.5 self.data = np.random.rand(25, 1).astype(np.float32) + def numpy_cal(self): + s1 = self.data[np.where(self.data < self.cond_value)] + res = np.sum(np.exp(s1)) + s2 = self.data[np.where(self.data >= self.cond_value)] + res += np.sum(np.tanh(s2)) + return res + def compare_ifelse_op_and_numpy(self, place): self.set_test_case() @@ -166,10 +175,12 @@ class TestIfElse(unittest.TestCase): ie = layers.IfElse(ifcond) with ie.true_block(): true_target = ie.input(src) + true_target = fluid.layers.exp(true_target) ie.output(true_target) with ie.false_block(): false_target = ie.input(src) + false_target = fluid.layers.tanh(false_target) ie.output(false_target) if_out = ie() out = layers.reduce_sum(if_out) @@ -180,7 +191,8 @@ class TestIfElse(unittest.TestCase): o1, = exe.run(fluid.default_main_program(), feed={'data': self.data}, fetch_list=[out]) - o2 = np.sum(self.data) + o2 = self.numpy_cal() + self.assertTrue( np.allclose( o1, o2, atol=1e-8), diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py index a2a5584459a5d4dc416b9c542a4bb0567982e765..2be8e24a0fae6945351eb767ac924d7ca70848ab 100644 --- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -41,7 +41,7 @@ class TestSqueezeOp(OpTest): self.new_shape = (3, 5) def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": False} + self.attrs = {"axes": self.axes} # Correct: There is mins axis. @@ -68,49 +68,5 @@ class TestSqueezeOp3(TestSqueezeOp): self.new_shape = (3, 5, 1, 4) -# Correct: Inplace. -class TestSqueezeOpInplace1(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = (1, 3, 1, 5) - self.axes = (0, 2) - self.new_shape = (3, 5) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - -# Correct: Inplace. There is mins axis. -class TestSqueezeOpInplace2(TestSqueezeOp): - def inti_test_case(self): - self.ori_shape = (1, 3, 1, 5) - self.axes = (0, -2) - self.new_shape = (3, 5) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - -# Correct: Inplace. No axes input. -class TestSqueezeOpInplace3(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = (1, 3, 1, 5) - self.axes = () - self.new_shape = (3, 5) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - -# Correct: Inpalce. Just part of axes be squeezed. -class TestSqueezeOpInplace4(TestSqueezeOp): - def init_test_case(self): - self.ori_shape = (3, 1, 5, 1, 4, 1) - self.axes = (1, -1) - self.new_shape = (3, 5, 1, 4) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 5fcabe4c83457a91ff0e9cd5568904698353b62a..a324438ba5a3c3b57fd956bd11189ef7d50267e2 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -41,7 +41,7 @@ class TestUnsqueezeOp(OpTest): self.new_shape = (3, 1, 1, 5) def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": False} + self.attrs = {"axes": self.axes} # Correct: Single input index. @@ -76,38 +76,5 @@ class TestUnsqueezeOp4(TestUnsqueezeOp): self.new_shape = (3, 1, 1, 2, 5, 1) -# Correct: Inplace. -class TestUnsqueezeOpInplace1(TestUnsqueezeOp): - def init_test_case(self): - self.ori_shape = (3, 5) - self.axes = (0, 2) - self.new_shape = (1, 3, 1, 5) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - -# Correct: Inplace. There is mins index. -class TestUnsqueezeOpInplace2(TestUnsqueezeOp): - def init_test_case(self): - self.ori_shape = (3, 5) - self.axes = (0, -2) - self.new_shape = (1, 3, 1, 5) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - -# Correct: Inplace. There is duplicated axis. -class TestUnsqueezeOpInplace3(TestUnsqueezeOp): - def init_test_case(self): - self.ori_shape = (3, 2, 5) - self.axes = (0, 3, 3) - self.new_shape = (1, 3, 2, 1, 1, 5) - - def init_attrs(self): - self.attrs = {"axes": self.axes, "inplace": True} - - if __name__ == "__main__": unittest.main()