diff --git a/CMakeLists.txt b/CMakeLists.txt index fd552698a837dc02f48d4c20ba7b802d6b4d6602..996a79fbbc3005680205e9fc0442b6bc6199bebb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,8 @@ option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) option(WITH_ANAKIN "Compile with Anakin library" OFF) +option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF) +option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(ON_INFER "Turn on inference optimization." OFF) diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 84354c446e2f54fa13b90fa37221eed90968b251..06fc6061bc98eec8c4c71860333f7d3456952aeb 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -58,19 +58,21 @@ ExternalProject_Add( -DPROTOBUF_ROOT=${THIRD_PARTY_PATH}/install/protobuf -DMKLML_ROOT=${THIRD_PARTY_PATH}/install/mklml -DENABLE_OP_TIMER=${ANAKIN_ENABLE_OP_TIMER} + -DBUILD_FAT_BIN=${ANAKIN_BUILD_FAT_BIN} + -DBUILD_CROSS_PLANTFORM=${ANAKIN_BUILD_CROSS_PLANTFORM} ${EXTERNAL_OPTIONAL_ARGS} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ANAKIN_INSTALL_DIR} ) message(STATUS "Anakin for inference is enabled") message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}") - +add_dependencies(extern_anakin protobuf mklml) add_library(anakin_shared SHARED IMPORTED GLOBAL) set_property(TARGET anakin_shared PROPERTY IMPORTED_LOCATION ${ANAKIN_SHARED_LIB}) -add_dependencies(anakin_shared extern_anakin protobuf mklml) +add_dependencies(anakin_shared extern_anakin) add_library(anakin_saber SHARED IMPORTED GLOBAL) set_property(TARGET anakin_saber PROPERTY IMPORTED_LOCATION ${ANAKIN_SABER_LIB}) -add_dependencies(anakin_saber extern_anakin protobuf mklml) +add_dependencies(anakin_saber extern_anakin) list(APPEND external_project_dependencies anakin_shared anakin_saber) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2bab3a15b11c6f2c24a568152741fc5798021804..cb9057672cc2c29af21b662edc189004bb0a4866 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -136,6 +136,11 @@ cc_library(version SRCS version.cc) cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) +cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto) +if(NOT WIN32) +cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog + shape_inference data_transform lod_tensor profiler) +endif(NOT WIN32) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) @@ -170,10 +175,14 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) + if(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator) + else(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) + endif(NOT WIN32) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() - + if (NOT WIN32) cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 47a221a9446cd238be94c72aef6844928c4823e3..0313a6a1e3d11b9c43714544db15b092bbc586b3 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/ngraph_operator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/operators/detail/macros.h" @@ -25,6 +26,7 @@ limitations under the License. */ DECLARE_bool(benchmark); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); +DEFINE_bool(use_ngraph, false, "Use NGRAPH to run"); namespace paddle { namespace framework { @@ -81,6 +83,24 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, } } +static void EnableFusedOp(ExecutorPrepareContext* ctx) { +#ifdef PADDLE_WITH_NGRAPH + VLOG(3) << "use_ngraph=True"; + auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_); + for (auto& interval : intervals) { + auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_, + interval.at(0), interval.at(1)); + *interval[0] = std::unique_ptr(fused_op); + } + for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { + ctx->ops_.erase(it->at(0) + 1, it->at(1)); + } +#else + LOG(WARNING) + << "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option"; +#endif +} + Executor::Executor(const platform::Place& place) : place_(place) {} void Executor::Close() { @@ -338,6 +358,7 @@ std::unique_ptr Executor::Prepare( for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); } + if (FLAGS_use_ngraph) EnableFusedOp(ctx.get()); return ctx; } @@ -486,6 +507,5 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) { << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; #endif } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc new file mode 100644 index 0000000000000000000000000000000000000000..8177436d0bd90c3bcf8f91d5c55b66be188b19f9 --- /dev/null +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_NGRAPH +#include +#include + +#include "paddle/fluid/framework/ngraph_bridge.h" + +#include "ngraph/ngraph.hpp" + +namespace paddle { +namespace framework { + +std::map&, + std::shared_ptr>>)>> + NgraphBridge::NG_NODE_MAP = {}; + +void NgraphBridge::build_graph(const std::shared_ptr& op) { + auto& op_type = op->Type(); + NG_NODE_MAP[op_type](op, ngb_node_map); +} + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/framework/ngraph_bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..55bf0d21f3471013b1fb780e852d813313345f03 --- /dev/null +++ b/paddle/fluid/framework/ngraph_bridge.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_NGRAPH + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +#include "ngraph/ngraph.hpp" + +namespace paddle { +namespace framework { + +class NgraphBridge { + public: + static std::map< + std::string, + std::function&, + std::shared_ptr>>)>> + NG_NODE_MAP; + + explicit NgraphBridge( + std::shared_ptr< + std::unordered_map>> + var_node_map) + : ngb_node_map(var_node_map) {} + + void build_graph(const std::shared_ptr& op); + + private: + std::shared_ptr< + std::unordered_map>> + ngb_node_map; +}; + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc new file mode 100644 index 0000000000000000000000000000000000000000..d967b2780c21713a2f9a73a3402964103f44269e --- /dev/null +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -0,0 +1,220 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_NGRAPH +#include + +#include +#include + +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/ngraph_operator.h" +#include "paddle/fluid/framework/shape_inference.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/framework/var_type.h" + +namespace paddle { +namespace framework { + +static std::map pd2ng_type_map = { + {proto::VarType::FP32, ngraph::element::f32}, + {proto::VarType::FP64, ngraph::element::f64}, + {proto::VarType::INT32, ngraph::element::i32}, + {proto::VarType::INT64, ngraph::element::i64}, + {proto::VarType::BOOL, ngraph::element::boolean}, +}; + +typedef enum { /* nGraph support state on ops */ + FULL_TRAIN, /* Support full ops for train */ + PARTIAL_TRAIN, /* Support partial ops for train */ + FULL_TEST, /* Support full list of ops for test */ + PARTIAL_TEST /* Support partial list of ops for test */ +} op_state; + +class NgraphOperator { + public: + explicit NgraphOperator(const Scope& scope, const platform::Place& place, + const std::vector>& ops, + const std::unordered_map< + std::string, ngraph::element::Type>& var_type_map, + const std::unordered_set& persist, + const std::unordered_set& fetches, + const std::unordered_set& post_op_inputs, + op_state ng_op_state) + : scope_(scope), + place_(place), + fused_ops_(ops), + var_type_map_(var_type_map), + persistables_(persist), + fetches_(fetches), + post_op_inputs_(post_op_inputs), + ng_op_state_(ng_op_state) {} + + void Run(const Scope& scope, const platform::Place& place) const; + + private: + static std::unordered_map> + func_cache; + const Scope& scope_; + const platform::Place& place_; + std::vector> fused_ops_; + std::unordered_map var_type_map_; + std::unordered_set persistables_; + std::unordered_set fetches_; + std::unordered_set post_op_inputs_; + op_state ng_op_state_; +}; + +std::vector>::iterator>> +FusedOperator::FusedOpIntervals( + std::vector>* ops) { + std::vector>::iterator>> + intervals; + if (ops->empty()) { + return intervals; + } + size_t size = ops->size(); + size_t left = 0; + while (left < size && ops.at(left)->Type() != kFeedOpType) { + ++left; + } + if (left == size) { + return intervals; + } + while (left < size && ops->at(left)->Type() == kFeedOpType) { + ++left; + } + + size_t right = left; + while (right < size && ops->at(right)->Type() != kFetchOpType) { + ++right; + } + if (right == size) { + return intervals; + } + if (left >= right) return intervals; + + // (left, right - 1) represents indices between feed and fetch + size_t pivot = left; + while (pivot < right) { + auto op_type = ops->at(pivot)->Type(); + if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) == + paddle::framework::NgraphBridge::NG_NODE_MAP.end()) { + ++pivot; + } else { + size_t start = pivot, end = start; + while (pivot < right && + (paddle::framework::NgraphBridge::NG_NODE_MAP.find( + ops.at(pivot)->Type()) != + paddle::framework::NgraphBridge::NG_NODE_MAP.end())) { + ++pivot; + ++end; + } + std::vector>::iterator> + interval = {ops->begin() + start, ops->begin() + end}; + intervals.push_back(interval); + } + } // end while + + return intervals; +} + +FusedOperator::FusedOperator( + const ProgramDesc& prog, size_t block_id, + std::vector>::iterator start, + std::vector>::iterator end, + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) { + for (std::vector>::iterator it = start; + it != end; ++it) { + fused_ops_.push_back(std::move(*it)); + } + + for (std::vector>::iterator it = end; + (*it)->Type() != kFetchOpType; ++it) { + for (auto& var_name_item : (*it)->Inputs()) { + for (auto& var_name : var_name_item.second) { + post_op_inputs_.insert(var_name); + } + } + } + + if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) { + is_complete = true; + } + + Process(); +} + +void FusedOperator::Process() { + auto& bdesc = pdesc_.Block(block_); + for (auto& var : bdesc.AllVars()) { + if (!(var->GetType() == proto::VarType::SELECTED_ROWS || + var->GetType() == proto::VarType::LOD_TENSOR || + var->GetType() == proto::VarType::LOD_TENSOR_ARRAY)) { + continue; + } + + auto var_name = var->Name(); + if (var->Name() == framework::kEmptyVarName) { + continue; + } + + if (var_name != "fetch" && var_name != "feed") { + auto pd_type = var->GetDataType(); + if (pd2ng_type_map.find(pd_type) == pd2ng_type_map.end()) { + PADDLE_THROW("Data type of var %s not found in pd2ng_type_map", + var_name); + } + var_type_map_[var_name] = pd2ng_type_map[pd_type]; + } + + if (var->Persistable()) { + persistables_.insert(var->Name()); + } + } + + for (auto* op : bdesc.AllOps()) { + if (op->Type() == kFetchOpType) { + std::string fetch_target_name = op->Input("X")[0]; + fetches_.insert(fetch_target_name); + } + } +} + +void FusedOperator::RunImpl(const Scope& scope, + const platform::Place& place) const { + op_state ng_op_state = PARTIAL_TEST; + auto& bdesc = pdesc_.Block(block_); + for (auto* op : bdesc.AllOps()) { + if (op->Type().find("_grad") != std::string::npos) { + ng_op_state = PARTIAL_TRAIN; + break; + } + } + + if (is_full) { + ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN; + } + + NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_, + persistables_, fetches_, post_op_inputs_, + ng_op_state); + ngraph_op.Run(scope, place); +} + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/ngraph_operator.h b/paddle/fluid/framework/ngraph_operator.h new file mode 100644 index 0000000000000000000000000000000000000000..0f655cef1dde624bcf4944b5c096279097e1c8ae --- /dev/null +++ b/paddle/fluid/framework/ngraph_operator.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_NGRAPH + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/ngraph_bridge.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_kernel_type.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/variant.h" + +#include "ngraph/ngraph.hpp" + +namespace paddle { +namespace framework { + +class FusedOperator : public OperatorBase { + public: + static std::vector< + std::vector>::iterator>> + FusedOpIntervals( + std::vector>* ops); + + explicit FusedOperator( + const ProgramDesc& prog, size_t block_id, + std::vector>::iterator start, + std::vector>::iterator end, + const std::string& type = "fused_op", const VariableNameMap& inputs = {}, + const VariableNameMap& outputs = {}, const AttributeMap& attrs = {}); + + void RunImpl(const Scope& scope, const platform::Place& place) const final; + + private: + const ProgramDesc pdesc_; + size_t block_; + std::vector> fused_ops_; + std::unordered_map var_type_map_; + std::unordered_set persistables_; + std::unordered_set fetches_; + std::unordered_set post_op_inputs_; + bool is_full_ = false; + + void Process(); +}; +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index ec87793b442058ddfc9e22fee47fb0aa5f430b93..3400b5274679d8e859a008dcf47ac7122ace6b2d 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/fluid/memory/malloc.h" @@ -21,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/string/printf.h" DEFINE_bool(init_allocated_mem, false, "It is a mistake that the values of the memory allocated by " @@ -137,12 +139,18 @@ void* Alloc(platform::CUDAPlace place, size_t size) { platform::SetDeviceId(place.device); size_t avail, total; platform::GpuMemoryUsage(&avail, &total); - LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU " - << place.device << ", available " << avail << " bytes"; + LOG(WARNING) << "Cannot allocate " << string::HumanReadableSize(size) + << " in GPU " << place.device << ", available " + << string::HumanReadableSize(avail); LOG(WARNING) << "total " << total; - LOG(WARNING) << "GpuMinChunkSize " << buddy_allocator->GetMinChunkSize(); - LOG(WARNING) << "GpuMaxChunkSize " << buddy_allocator->GetMaxChunkSize(); - LOG(WARNING) << "GPU memory used: " << Used(place); + LOG(WARNING) << "GpuMinChunkSize " + << string::HumanReadableSize( + buddy_allocator->GetMinChunkSize()); + LOG(WARNING) << "GpuMaxChunkSize " + << string::HumanReadableSize( + buddy_allocator->GetMaxChunkSize()); + LOG(WARNING) << "GPU memory used: " + << string::HumanReadableSize(Used(place)); platform::SetDeviceId(cur_dev); } if (FLAGS_init_allocated_mem) { diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 3083e622c3066879e107f930a45bcec36d347f80..3a4086274d8a4bf6725df9f3195cec2446ceae6c 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -50,12 +50,18 @@ static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache"; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; -static constexpr size_t kNUM_CUDNN_FWD_ALGS = - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; +#if CUDNN_VERSION_MIN(6, 0, 5) +static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; +#else +// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc. +static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7; +static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4; +static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5; +#endif template class CUDNNConvOpKernel : public framework::OpKernel { diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 52b459a6a2e56b7c256efdb535b4652c64bae23c..61c3cb34a2472c0ba7d2a7ea5abf8e826a793951 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/lrn_op.h" #include +#include "paddle/fluid/operators/math/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -29,34 +30,43 @@ struct LRNFunctor { const framework::Tensor& input, framework::Tensor* out, framework::Tensor* mid, int N, int C, int H, int W, int n, T k, T alpha, T beta) { - auto x_v = framework::EigenVector::Flatten(input); - - const int start = -(n - 1) / 2; - const int end = start + n; - - auto e_mid = framework::EigenTensor::From(*mid); - e_mid = e_mid.constant(k); - - auto e_x = framework::EigenTensor::From(input); - for (int m = 0; m < N; m++) { - for (int i = 0; i < C; i++) { - for (int c = start; c < end; c++) { - int ch = i + c; - if (ch >= 0 && ch < C) { - auto s = e_mid.slice(Eigen::array({{m, i, 0, 0}}), - Eigen::array({{1, 1, H, W}})); - - auto r = e_x.slice(Eigen::array({{m, ch, 0, 0}}), - Eigen::array({{1, 1, H, W}})); - - s += alpha * r.square(); - } - } + const T* idata = input.data(); + auto place = ctx.GetPlace(); + auto blas = math::GetBlas(ctx); + T* odata = out->mutable_data(place); + T* mdata = mid->mutable_data(place); + Tensor squared; + T* sdata = squared.mutable_data({1, C + n - 1, H, W}, place); + std::memset(sdata, 0, sizeof(T) * squared.numel()); + for (int i = 0; i < mid->numel(); ++i) { + mdata[i] = k; + } + int img_size = H * W; + int fea_size = C * img_size; + int pre_pad = (n - 1) / 2; + // compute batches one by one + for (int i = 0; i < N; ++i) { + blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size); + // init the first channel of mid + for (int c = 0; c < n; ++c) { + blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size); + } + for (int c = 1; c < C; ++c) { + // copy previous scale + int mid_offset = i * fea_size + c * img_size; + std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size, + img_size * sizeof(T)); + // add last + blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size, + mdata + mid_offset); + // sub rest + blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size, + mdata + mid_offset); } } - - auto out_e = framework::EigenVector::Flatten(*out); - out_e = x_v * e_mid.reshape(Eigen::DSizes(e_mid.size())).pow(-beta); + // compute the final output + blas.VPOW(mid->numel(), mdata, -beta, odata); + blas.VMUL(mid->numel(), odata, idata, odata); } }; template struct LRNFunctor; @@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel { auto x_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4."); + int n = ctx->Attrs().Get("n"); + PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value"); + ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); ctx->SetOutputDim("MidOut", x_dim); diff --git a/paddle/fluid/operators/lrn_op.h b/paddle/fluid/operators/lrn_op.h index 0fd3175e8579df9e61368cc151a94fa45e433884..12d39c3815395896343238b536110aecac66a376 100644 --- a/paddle/fluid/operators/lrn_op.h +++ b/paddle/fluid/operators/lrn_op.h @@ -60,7 +60,6 @@ class LRNKernel : public framework::OpKernel { T beta = ctx.Attr("beta"); T k = ctx.Attr("k"); - PADDLE_ENFORCE(n > 0, "n should >= 0"); PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0"); PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0"); PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0"); diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index da185d93c09f9b06bd5968b9c8e93176f9ef014b..5d0d562030d2a20e4a1cefd3c36c6533fd35dc96 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -152,6 +152,12 @@ class Blas { template void VEXP(int n, const T* x, T* y) const; + template + void VSQR(int n, const T* x, T* y) const; + + template + void VPOW(int n, const T* x, T alpha, T* y) const; + template void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; @@ -238,6 +244,16 @@ class BlasT : private Blas { Base()->template VEXP(args...); } + template + void VSQR(ARGS... args) const { + Base()->template VSQR(args...); + } + + template + void VPOW(ARGS... args) const { + Base()->template VPOW(args...); + } + template void GEMV(ARGS... args) const { Base()->template GEMV(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index e1df78d11e41c5f74e244643f40c6d0581fa6a4a..59454669be9e0f92a6fc0db52445307d88e1c7d8 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include #include #include #include "paddle/fluid/operators/math/math_function.h" @@ -102,6 +103,16 @@ struct CBlas { static void VEXP(ARGS... args) { platform::dynload::vsExp(args...); } + + template + static void VSQR(ARGS... args) { + platform::dynload::vsSqr(args...); + } + + template + static void VPOW(ARGS... args) { + platform::dynload::vsPowx(args...); + } }; template <> @@ -182,6 +193,16 @@ struct CBlas { static void VEXP(ARGS... args) { platform::dynload::vdExp(args...); } + + template + static void VSQR(ARGS... args) { + platform::dynload::vdSqr(args...); + } + + template + static void VPOW(ARGS... args) { + platform::dynload::vdPowx(args...); + } }; #else @@ -241,6 +262,8 @@ struct CBlas { } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } + static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); } + static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML @@ -398,6 +421,31 @@ void Blas::VEXP(int n, const T *x, T *y) const { #endif } +template <> +template +void Blas::VSQR(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSQR(n, x, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::sqrt(x[i]); + } +#endif +} + +template <> +template +void Blas::VPOW(int n, const T *x, T a, + T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VPOW(n, x, a, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::pow(x[i], a); + } +#endif +} + template <> template T Blas::DOT(int n, const T *x, const T *y) const { diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index aa20553ceffceded09447693c6e92f55fb48702d..9273e9b1e72f0ad7abd6c20d4a34283fbe24378a 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -76,6 +76,10 @@ extern void* mklml_dso_handle; __macro(vdMul); \ __macro(vsExp); \ __macro(vdExp); \ + __macro(vsSqr); \ + __macro(vdSqr); \ + __macro(vsPowx); \ + __macro(vdPowx); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); diff --git a/paddle/fluid/string/printf.h b/paddle/fluid/string/printf.h index 47de23377398423dabf3b0ed5b670e564f57cdfb..a2eec6e3c48dd126614bbff0227145537b678ac4 100644 --- a/paddle/fluid/string/printf.h +++ b/paddle/fluid/string/printf.h @@ -72,6 +72,7 @@ #include #include #include +#include #include "tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat @@ -102,5 +103,22 @@ void Printf(const char* fmt, const Args&... args) { Fprintf(std::cout, fmt, args...); } +template +std::string HumanReadableSize(T size) { + size_t i = 0; + double f_size = static_cast(size); + double orig = f_size; + const std::vector units( + {"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}); + while (f_size > 1024) { + f_size /= 1024; + i++; + } + if (i >= units.size()) { + return Sprintf("%fB", orig); + } + return Sprintf("%f%s", f_size, units[i]); +} + } // namespace string } // namespace paddle diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index a51c9becd416af243cb473c8856141db8d9f3bf0..32f9bca645d80a11274d128b6615a73ffa224705 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -156,6 +156,8 @@ function cmake_gen() { -DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} -DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} + -DANAKIN_BUILD_FAT_BIN=${ANAKIN_BUILD_FAT_BIN:OFF} + -DANAKIN_BUILD_CROSS_PLANTFORM=${ANAKIN_BUILD_CROSS_PLANTFORM:ON} -DPY_VERSION=${PY_VERSION:-2.7} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} ======================================== @@ -188,6 +190,8 @@ EOF -DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} \ -DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} \ -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} \ + -DANAKIN_BUILD_FAT_BIN=${ANAKIN_BUILD_FAT_BIN:OFF}\ + -DANAKIN_BUILD_CROSS_PLANTFORM=${ANAKIN_BUILD_CROSS_PLANTFORM:ON}\ -DPY_VERSION=${PY_VERSION:-2.7} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} @@ -777,6 +781,17 @@ function main() { test_fluid_lib assert_api_spec_approvals ;; + assert_api) + assert_api_not_changed ${PYTHON_ABI:-""} + ;; + test_inference) + gen_capi_package + gen_fluid_lib + test_fluid_lib + ;; + assert_api_approvals) + assert_api_spec_approvals + ;; maccheck) cmake_gen ${PYTHON_ABI:-""} build_mac diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2e1b4b2ead3e98f84ec87c7801d13c00a57f85e0..9a375d37e66332a55b00516e8476b0fe446402a2 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -112,9 +112,10 @@ def __bootstrap__(): os.environ['OMP_NUM_THREADS'] = str(num_threads) read_env_flags = [ - 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'eager_delete_scope', - 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem', - 'free_idle_memory', 'paddle_num_threads', 'dist_threadpool_size', + 'use_pinned_memory', 'check_nan_inf', 'benchmark', + 'eager_delete_scope', 'use_mkldnn', 'use_ngraph', + 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', + 'paddle_num_threads', 'dist_threadpool_size', 'eager_delete_tensor_gb', 'reader_queue_speed_test_mode' ] if os.name != 'nt': diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 87847a0a4929f8563517d10e06072c90d9f6f28d..1b5009e76126e8b0fbe2805cc85f4f133a70c1ae 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6835,7 +6835,7 @@ def prelu(x, mode, param_attr=None, name=None): alpha_shape = x.shape dtype = helper.input_dtype(input_param_name='x') alpha = helper.create_parameter( - attr=param_attr, + attr=helper.param_attr, shape=alpha_shape, dtype='float32', is_bias=False,