diff --git a/cmake/external/ngraph.cmake b/cmake/external/ngraph.cmake index 2e335579f32df4f146c8d88e05e684a9a8105e20..e66459fa3a1508fe4a3687f07bbe18f2a5421296 100644 --- a/cmake/external/ngraph.cmake +++ b/cmake/external/ngraph.cmake @@ -32,6 +32,8 @@ IF(NOT ${WITH_NGRAPH}) return() ENDIF() +INCLUDE(GNUInstallDirs) + INCLUDE(ExternalProject) SET(NGRAPH_PROJECT "extern_ngraph") @@ -40,10 +42,14 @@ SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0") SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph) SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph) SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include) +SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}) SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION}) SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so) SET(NGRAPH_TBB_LIB_NAME libtbb.so.2) SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git") +SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME}) +SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME}) +SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME}) ExternalProject_Add( ${NGRAPH_PROJECT} @@ -63,18 +69,6 @@ ExternalProject_Add( CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib ) -if(UNIX AND NOT APPLE) - include(GNUInstallDirs) - SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}) -else() - SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/lib) -endif() -MESSAGE(STATUS "nGraph lib will be installed at: ${NGRAPH_LIB_DIR}") - -SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME}) -SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME}) -SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME}) - # Workaround for nGraph expecting mklml to be in mkldnn install directory. ExternalProject_Add_Step( ${NGRAPH_PROJECT} diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 0b95a780721b0771d55c4dbb2ddce33418612018..c679d8507d8a9d3bce48b7f38491dadd9f2fb7f6 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -129,6 +129,15 @@ if (WITH_MKLDNN) ) endif () +if (WITH_NGRAPH) + set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/ngraph") + copy(ngraph_lib + SRCS ${NGRAPH_INC_DIR} ${NGRAPH_LIB_DIR} + DSTS ${dst_dir} ${dst_dir} + DEPS ngraph + ) +endif () + if (NOT WIN32) if (NOT MOBILE_INFERENCE AND NOT RPI) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy") diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 8f6797429c9dc21b4848961fcc9814f39b503308..14996e3d86a0908cbef3aa0b8b378abb4be7bed2 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -182,7 +182,7 @@ paddle.fluid.layers.clip ArgSpec(args=['x', 'min', 'max', 'name'], varargs=None, paddle.fluid.layers.clip_by_norm ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) -paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name'], varargs=None, keywords=None, defaults=(-100, None)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], 
varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)) @@ -299,6 +299,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 9f5631b87cba62aa984f27b13418d61e12e86c8a..c701a2ad63048f69e8443fee7434928698aa20cc 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -127,8 +127,9 @@ cc_library(version SRCS version.cc) cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) -cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto) + if(NOT WIN32) +cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog shape_inference data_transform lod_tensor profiler) endif(NOT WIN32) diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 1ce18c3d6b26c541beed668a113b7a4de7f0e79e..eea7e712f8f6e187cdceedce77cc76d1d4ca2101 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -32,9 +32,7 @@ enum OpInfoFillType { kOpProtoAndCheckerMaker = 1, kGradOpDescMaker = 2, kVarTypeInference = 3, - kShapeInference = 4, - kEstimateFlops = 5, - kUnknown = -1 + kShapeInference = 4 }; template @@ -50,10 +48,8 @@ struct OpInfoFillTypeID { ? kVarTypeInference : (std::is_base_of::value ? kShapeInference - : (std::is_base_of::value - ? 
kEstimateFlops - : kUnknown))))); + : static_cast( + -1))))); } }; @@ -143,16 +139,6 @@ struct OpInfoFiller { } }; -template -struct OpInfoFiller { - void operator()(const char* op_tpe, OpInfo* info) const { - info->estimate_flops_ = [](InferShapeContext* ctx) { - T estimate_flops; - return estimate_flops(ctx); - }; - } -}; - } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 292f232ffce48593e1827fe2dfe1b8472360054e..6d8f020918d4e56fa7f125a659f7f8511ca067ca 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -38,7 +38,7 @@ std::unique_ptr IsTestPass::ApplyImpl( for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); - if (op->HasAttr("is_test")) { + if (n->RuntimeHasAttr("is_test")) { op->SetAttr("is_test", true); } else if (std::find(begin(op_list), end(op_list), op->Type()) != end(op_list)) { diff --git a/paddle/fluid/framework/ir/is_test_pass_tester.cc b/paddle/fluid/framework/ir/is_test_pass_tester.cc index 9696441a21661db89146c448742a992d1f7df022..d9a68c7f1dd2a0dca5204719c4ce6cd9d68292a2 100644 --- a/paddle/fluid/framework/ir/is_test_pass_tester.cc +++ b/paddle/fluid/framework/ir/is_test_pass_tester.cc @@ -104,9 +104,9 @@ TEST(IsTestPass, basic) { auto* op = node->Op(); auto op_name = boost::get(op->GetAttr("name")); if (op_name == "conv3") { - ASSERT_FALSE(op->HasAttr("is_test")); + ASSERT_FALSE(node->RuntimeHasAttr("is_test")); } else { - ASSERT_TRUE(op->HasAttr("is_test")); + ASSERT_TRUE(node->RuntimeHasAttr("is_test")); EXPECT_TRUE(boost::get(op->GetAttr("is_test"))); } } diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc index 65be69b7f5b5e363d5d0753c45f9ff9e3f329fbe..1cf1315d3d3059261d84d0e8795a75df4deae005 100644 --- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc @@ -22,7 +22,7 @@ std::unique_ptr MKLDNNPlacementPass::ApplyImpl( std::unique_ptr graph) const { VLOG(3) << "Aplies MKL-DNN placement strategy."; for (const Node* n : graph->Nodes()) { - if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) { + if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) { n->Op()->SetAttr("use_mkldnn", true); } } diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 50d9113088903aa7681d6c6af5cc65f846d32787..7a88cb2b681c1aa5e1b75481b1aba66a125a1d9c 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_info.h" namespace paddle { namespace framework { @@ -24,10 +25,33 @@ constexpr char Node::kControlDepVarName[]; const char Node::kControlDepVarName[] = "__control_var"; #endif -std::unique_ptr CreateNodeForTest(const std::string& name, +std::unique_ptr CreateNodeForTest(const std::string &name, Node::Type type) { return std::unique_ptr(new Node(name, type)); } + +bool Node::RuntimeHasAttr(const std::string &name) const { + if (Op()->HasAttr(name)) { + return true; + } else { + auto &op_info = OpInfoMap::Instance(); + auto op_type = Op()->Type(); + if (op_info.Has(op_type)) { + auto op_info_ptr = op_info.Get(op_type); + if (op_info_ptr.HasOpProtoAndChecker()) { + const proto::OpProto &proto = op_info_ptr.Proto(); + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + return true; + } + } + } + } + } + return false; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d2a393b3f19e9aab79098757dae663d030b0fa2b..1044a96430f060b750580ea0b225787ba6ebd2a0 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -108,6 +108,18 @@ class Node { Name().find(ir::Node::kControlDepVarName) != std::string::npos; } + // RuntimeHasAttr is different with HasAttr now. + // 1. For Op()->HasAttr(), it judges whether a stored program_desc_ has attr, + // thus, if stored program_desc_ are old which don't have an attr, a new + // library which adds the attr already will fail on this function. + // Details: + // https://github.com/PaddlePaddle/Paddle/pull/14608#issuecomment-442309087 + // 2. For Op()->RuntimeHasAttr, it judges the attr in runtime to avoid above + // problem. + // TODO(luotao): Maybe we should enhance HasAttr later, instead of adding + // RuntimeHasAttr. + bool RuntimeHasAttr(const std::string& name) const; + std::vector inputs; std::vector outputs; diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index 8177436d0bd90c3bcf8f91d5c55b66be188b19f9..e22c29037718a60ff7f24404d7749600e2edb80b 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -15,23 +15,105 @@ limitations under the License. 
*/ #ifdef PADDLE_WITH_NGRAPH #include #include +#include #include "paddle/fluid/framework/ngraph_bridge.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" #include "ngraph/ngraph.hpp" namespace paddle { namespace framework { +static std::shared_ptr GetNode( + const std::shared_ptr& op, const std::string prm, + const VariableNameMap& var_map, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto& var_names = var_map.at(prm); + PADDLE_ENFORCE_EQ(var_names.size(), 1, + "op %s prm %s expects one associated var", op->Type(), prm); + if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) { + return (*ngb_node_map)[var_names[0]]; + } else { + return nullptr; + } +} + +static std::shared_ptr GetInputNode( + const std::shared_ptr& op, const std::string prm, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + return GetNode(op, prm, op->Inputs(), ngb_node_map); +} + +static std::shared_ptr GetOutputNode( + const std::shared_ptr& op, const std::string prm, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + return GetNode(op, prm, op->Outputs(), ngb_node_map); +} + +static void SetOutputNode( + const std::shared_ptr& op, const std::string prm, + std::shared_ptr node, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto& var_names = op->Outputs().at(prm); + if (var_names.size() == 1) { + (*ngb_node_map)[var_names[0]] = node; + } else if (var_names.size() == 0) { + (*ngb_node_map)[""] = node; + } else { + PADDLE_THROW("prm %s has more than 1 var_names.", prm); + } +} + +static bool HasOutput(const std::shared_ptr& op, + const std::string prm) { + auto& outputs = op->Outputs(); + if (outputs.find(prm) == outputs.end()) return false; + return outputs.at(prm).size() > 0; +} + +template +static void BuildBinaryNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto x = GetInputNode(op, "X", ngb_node_map); + auto y = GetInputNode(op, "Y", ngb_node_map); + auto out = std::make_shared(x, y); + SetOutputNode(op, "Out", out, ngb_node_map); +} + +template +static void BuildUnaryNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = GetInputNode(op, "X", ngb_node_map); + auto out = std::make_shared(input); + SetOutputNode(op, "Out", out, ngb_node_map); +} + std::map&, std::shared_ptr>>)>> - NgraphBridge::NG_NODE_MAP = {}; + NgraphBridge::NG_NODE_MAP = {{"relu", BuildUnaryNode}, + {"tanh", BuildUnaryNode}}; -void NgraphBridge::build_graph(const std::shared_ptr& op) { +void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { auto& op_type = op->Type(); - NG_NODE_MAP[op_type](op, ngb_node_map); + NG_NODE_MAP[op_type](op, ngb_node_map_); } } // namespace framework diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/framework/ngraph_bridge.h index 55bf0d21f3471013b1fb780e852d813313345f03..9ed6b9510942136a61faa5755fd8fa74286939a8 100644 --- a/paddle/fluid/framework/ngraph_bridge.h +++ b/paddle/fluid/framework/ngraph_bridge.h @@ -20,16 +20,14 @@ limitations under the License. 
*/ #include #include #include -#include -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/enforce.h" - -#include "ngraph/ngraph.hpp" +#include "ngraph/node.hpp" namespace paddle { namespace framework { +class OperatorBase; + class NgraphBridge { public: static std::map< @@ -43,14 +41,14 @@ class NgraphBridge { std::shared_ptr< std::unordered_map>> var_node_map) - : ngb_node_map(var_node_map) {} + : ngb_node_map_(var_node_map) {} - void build_graph(const std::shared_ptr& op); + void BuildNgNode(const std::shared_ptr& op); private: std::shared_ptr< std::unordered_map>> - ngb_node_map; + ngb_node_map_; }; } // namespace framework diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index d967b2780c21713a2f9a73a3402964103f44269e..3fea753f0659019395c9b214e52a7912058c501c 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -19,14 +19,29 @@ limitations under the License. */ #include #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/ngraph_bridge.h" #include "paddle/fluid/framework/ngraph_operator.h" -#include "paddle/fluid/framework/shape_inference.h" +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_type.h" +#include "ngraph/ngraph.hpp" + namespace paddle { namespace framework { +static ngraph::Shape Ddim2Shape(const DDim& dims) { + ngraph::Shape sp; + for (int i = 0; i < dims.size(); ++i) { + int k = dims[i]; + k = k == 0 ? 1 : k; + sp.push_back(k); + } + return sp; +} + static std::map pd2ng_type_map = { {proto::VarType::FP32, ngraph::element::f32}, {proto::VarType::FP64, ngraph::element::f64}, @@ -42,6 +57,7 @@ typedef enum { /* nGraph support state on ops */ PARTIAL_TEST /* Support partial list of ops for test */ } op_state; +// perform graph build through bridge and execute computation class NgraphOperator { public: explicit NgraphOperator(const Scope& scope, const platform::Place& place, @@ -59,13 +75,23 @@ class NgraphOperator { persistables_(persist), fetches_(fetches), post_op_inputs_(post_op_inputs), - ng_op_state_(ng_op_state) {} + ng_op_state_(ng_op_state) { + var_in_node_map_ = std::make_shared< + std::unordered_map>>(); + + var_node_map_ = std::make_shared< + std::unordered_map>>(); + + BuildNgIO(); + + GetNgFunction(); + } void Run(const Scope& scope, const platform::Place& place) const; private: static std::unordered_map> - func_cache; + func_cache_; const Scope& scope_; const platform::Place& place_; std::vector> fused_ops_; @@ -74,6 +100,35 @@ class NgraphOperator { std::unordered_set fetches_; std::unordered_set post_op_inputs_; op_state ng_op_state_; + + // ngraph backend eg. 
CPU + static std::shared_ptr backend_; + // ngraph function to call and execute + std::shared_ptr ngraph_function_; + // var_name of inputs + std::vector var_in_; + // var_name of outputs from fetch in order + std::vector var_out_; + // map input vars to nodes + std::shared_ptr< + std::unordered_map>> + var_in_node_map_; + // map each var name with a ngraph node + std::shared_ptr< + std::unordered_map>> + var_node_map_; + // cache key to check if function is cached + std::shared_ptr GetCacheKey(); + // get ngraph input and define ngraph input parameters + void GetNgInputShape(std::shared_ptr op); + // Call ngraph bridge to map ops + void BuildNgNodes(); + // get the ngraph input and output var list + void BuildNgIO(); + // build ngraph function call + void BuildNgFunction(); + // Check cache for ngraph function or otherwise build the function + void GetNgFunction(); }; std::vector>::iterator>> @@ -86,7 +141,7 @@ FusedOperator::FusedOpIntervals( } size_t size = ops->size(); size_t left = 0; - while (left < size && ops.at(left)->Type() != kFeedOpType) { + while (left < size && ops->at(left)->Type() != kFeedOpType) { ++left; } if (left == size) { @@ -116,7 +171,7 @@ FusedOperator::FusedOpIntervals( size_t start = pivot, end = start; while (pivot < right && (paddle::framework::NgraphBridge::NG_NODE_MAP.find( - ops.at(pivot)->Type()) != + ops->at(pivot)->Type()) != paddle::framework::NgraphBridge::NG_NODE_MAP.end())) { ++pivot; ++end; @@ -136,7 +191,9 @@ FusedOperator::FusedOperator( std::vector>::iterator end, const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) { + : OperatorBase(type, inputs, outputs, attrs), + pdesc_(prog), + block_(block_id) { for (std::vector>::iterator it = start; it != end; ++it) { fused_ops_.push_back(std::move(*it)); @@ -152,7 +209,7 @@ FusedOperator::FusedOperator( } if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) { - is_complete = true; + is_full_ = true; } Process(); @@ -205,7 +262,7 @@ void FusedOperator::RunImpl(const Scope& scope, } } - if (is_full) { + if (is_full_) { ng_op_state = ng_op_state == PARTIAL_TEST ? 
FULL_TEST : FULL_TRAIN; } @@ -215,6 +272,280 @@ void FusedOperator::RunImpl(const Scope& scope, ngraph_op.Run(scope, place); } +std::unordered_map> + NgraphOperator::func_cache_ = {}; + +std::shared_ptr NgraphOperator::backend_ = + ngraph::runtime::Backend::create("CPU"); + +void NgraphOperator::GetNgInputShape(std::shared_ptr op) { + op->RuntimeInferShape(scope_, place_); + for (auto& var_name_item : op->Inputs()) { + for (auto& var_name : var_name_item.second) { + auto* var = scope_.FindVar(var_name); + if (var && var->IsType()) { + auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); + auto sp = Ddim2Shape(tensor_pd->dims()); + if (std::find(var_in_.begin(), var_in_.end(), var_name) != + var_in_.end()) { + if (var_node_map_->find(var_name) == var_node_map_->end()) { + auto ng_type = var_type_map_.at(var_name); + auto prm = + std::make_shared(ng_type, sp, true); + (*var_node_map_)[var_name] = prm; + (*var_in_node_map_)[var_name] = prm; + } + } + } + } + } +} + +void NgraphOperator::BuildNgNodes() { + for (auto& var_name : var_out_) { + if (var_node_map_->find(var_name) == var_node_map_->end()) { + auto* var = scope_.FindVar(var_name); + if (var && var->IsType()) { + auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); + auto& ddim = tensor_pd->dims(); + auto ng_shape = Ddim2Shape(ddim); + auto ng_type = var_type_map_.at(var_name); + auto prm = + std::make_shared(ng_type, ng_shape, true); + (*var_node_map_)[var_name] = prm; + } + } + } + + paddle::framework::NgraphBridge ngb(var_node_map_); + for (auto& op : fused_ops_) { + ngb.BuildNgNode(op); + } +} + +void NgraphOperator::BuildNgIO() { + std::unordered_set inputs; + std::unordered_set outputs; + + for (auto& op : fused_ops_) { + for (auto& var_name_item : op->Inputs()) { + for (auto& var_name : var_name_item.second) { + inputs.insert(var_name); + const bool is_output = outputs.find(var_name) != outputs.end(); + if (!is_output && + std::find(var_in_.begin(), var_in_.end(), var_name) == + var_in_.end()) { + // fill var_in here to keep lhs and rhs order + var_in_.push_back(var_name); + } + } + } + + if (op->Type() != "fill_constant") { + GetNgInputShape(op); + } + + for (auto& var_name_item : op->Outputs()) { + PADDLE_ENFORCE_LE(var_name_item.second.size(), 1, + "op %s has more than 1 output - Not handling yet", + op->Type()); + for (auto& var_name : var_name_item.second) { + outputs.insert(var_name); + } + } + } + + // var_out.clear(); + for (auto& op : fused_ops_) { + for (auto& var_name_item : op->Outputs()) { + PADDLE_ENFORCE_LE(var_name_item.second.size(), 1, + "op %s has more than 1 output - Not handling yet", + op->Type()); + for (auto& var_name : var_name_item.second) { + switch (ng_op_state_) { + case PARTIAL_TEST: + if (post_op_inputs_.find(var_name) != post_op_inputs_.end() || + fetches_.find(var_name) != fetches_.end()) { + var_out_.push_back(var_name); + } + break; + case FULL_TEST: + if (fetches_.find(var_name) != fetches_.end()) { + var_out_.push_back(var_name); + } + break; + case PARTIAL_TRAIN: + if (fetches_.find(var_name) != fetches_.end() || + post_op_inputs_.find(var_name) != post_op_inputs_.end() || + persistables_.find(var_name) != persistables_.end()) { + var_out_.push_back(var_name); + } + break; + case FULL_TRAIN: + if (fetches_.find(var_name) != fetches_.end() || + persistables_.find(var_name) != persistables_.end()) { + var_out_.push_back(var_name); + } + break; + default: + var_out_.push_back(var_name); + } + } + } + } +} + +void NgraphOperator::BuildNgFunction() { + BuildNgNodes(); + 
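// [Editor's sketch, not part of this diff] BuildNgFunction below follows the
// standard nGraph usage pattern: collect result nodes and Parameter nodes,
// wrap them in an ngraph::Function, and hand that function to a runtime
// Backend. A minimal example of the same pattern, using APIs that already
// appear in this patch (ngraph::op::Parameter, ngraph::Function,
// ngraph::NodeVector, ngraph::op::ParameterVector, ngraph::runtime::Backend)
// plus ngraph::op::Add, which is assumed here purely for illustration:
//
//   auto a = std::make_shared<ngraph::op::Parameter>(
//       ngraph::element::f32, ngraph::Shape{2, 2});
//   auto b = std::make_shared<ngraph::op::Parameter>(
//       ngraph::element::f32, ngraph::Shape{2, 2});
//   auto sum = std::make_shared<ngraph::op::Add>(a, b);
//   auto func = std::make_shared<ngraph::Function>(
//       ngraph::NodeVector{sum}, ngraph::op::ParameterVector{a, b});
//   auto backend = ngraph::runtime::Backend::create("CPU");
//   // backend->call(func, {out_tensor}, {a_tensor, b_tensor});
//
// The code below does the same, with var_out_ and var_in_ fixing the result
// and parameter ordering so that cached functions can be matched reliably.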
ngraph_function_ = nullptr; + ngraph::NodeVector func_outputs; + ngraph::op::ParameterVector func_inputs; + + for (auto& vo : var_out_) { + func_outputs.push_back(var_node_map_->at(vo)); + } + + for (auto& vi : var_in_) { + std::shared_ptr prm = + std::dynamic_pointer_cast( + var_in_node_map_->at(vi)); + func_inputs.push_back(prm); + } + + ngraph_function_ = + std::make_shared(func_outputs, func_inputs); +} + +std::shared_ptr NgraphOperator::GetCacheKey() { + auto cache_key = std::make_shared(""); + *cache_key += std::to_string(fused_ops_.size()); + for (auto& op : fused_ops_) { + *cache_key += op->Type(); + } + for (auto& var_name : var_in_) { + auto shape = var_node_map_->at(var_name)->get_shape(); + *cache_key += var_name; + *cache_key += var_type_map_.at(var_name).c_type_string(); + for (size_t i = 0; i < shape.size(); ++i) { + *cache_key += std::to_string(shape.at(i)); + } + } + + for (auto& var_name : var_out_) { + auto* var = scope_.FindVar(var_name); + if (var && var->IsType()) { + auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); + auto& ddim = tensor_pd->dims(); + for (int i = 0; i < ddim.size(); ++i) { + *cache_key += std::to_string(ddim[i]); + } + } + } + return cache_key; +} + +void NgraphOperator::GetNgFunction() { + bool cache_on = true; + if (cache_on) { + std::string cache_key_val = *GetCacheKey(); + if (func_cache_.find(cache_key_val) != func_cache_.end()) { + ngraph_function_ = func_cache_.at(cache_key_val); + } else { + BuildNgFunction(); + func_cache_[cache_key_val] = ngraph_function_; + } + } else { + BuildNgFunction(); + } +} + +void NgraphOperator::Run(const Scope& scope, + const platform::Place& place) const { + std::vector> t_in; + std::vector> t_out; + + for (size_t i = 0; i < var_in_.size(); ++i) { + auto vi = var_in_.at(i); + auto sp = var_node_map_->at(vi)->get_shape(); + std::shared_ptr ti; + auto* var = scope.FindVar(vi); + if (var && var->IsType()) { + auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); + PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()), + "Ensure ngraph tensor layout align with paddle tensor"); + if (tensor_pd->type().hash_code() == + typeid(float).hash_code()) { // NOLINT + const float* arr = tensor_pd->data(); + ti = backend_->create_tensor(ngraph::element::f32, sp, + const_cast(arr)); + } else if (tensor_pd->type().hash_code() == + typeid(int).hash_code()) { // NOLINT + const int* arr = tensor_pd->data(); + ti = backend_->create_tensor(ngraph::element::i32, sp, + const_cast(arr)); + } else if (tensor_pd->type().hash_code() == typeid(int64_t).hash_code()) { + const int64_t* arr = tensor_pd->data(); + ti = backend_->create_tensor(ngraph::element::i64, sp, + const_cast(arr)); + } else if (tensor_pd->type().hash_code() == + typeid(double).hash_code()) { // NOLINT + const double* arr = tensor_pd->data(); + ti = backend_->create_tensor(ngraph::element::f64, sp, + const_cast(arr)); + } else if (tensor_pd->type().hash_code() == + typeid(bool).hash_code()) { // NOLINT + const bool* arr = tensor_pd->data(); + ti = backend_->create_tensor(ngraph::element::boolean, sp, + const_cast(arr)); + } else { + PADDLE_THROW("Data type not handling for var %s", vi); + } + } else { + PADDLE_THROW("Cannot find var or tensor with var name %s", vi); + } + bool is_test = (ng_op_state_ == PARTIAL_TEST || ng_op_state_ == FULL_TEST) + ? true + : false; + bool is_persistable = + (persistables_.find(vi) != persistables_.end()) ? 
true : false; + if (is_test && is_persistable) { + ti->set_stale(false); + } + t_in.push_back(ti); + } + + for (size_t i = 0; i < var_out_.size(); ++i) { + auto var_name = var_out_[i]; + auto* var = scope.FindVar(var_name); + std::shared_ptr to; + if (var && var->IsType()) { + auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); + auto dd = tensor_pd->dims(); + ngraph::Shape sp = Ddim2Shape(dd); + auto ng_type = var_type_map_.at(var_name); + if (ng_type == ngraph::element::f32) { + auto pd_arr = tensor_pd->mutable_data(place); + to = backend_->create_tensor(ngraph::element::f32, sp, pd_arr); + } else if (ng_type == ngraph::element::i64) { + auto pd_arr = tensor_pd->mutable_data(place); + to = backend_->create_tensor(ngraph::element::i64, sp, pd_arr); + } else if (ng_type == ngraph::element::f64) { + auto pd_arr = tensor_pd->mutable_data(place); + to = backend_->create_tensor(ngraph::element::f64, sp, pd_arr); + } else if (ng_type == ngraph::element::boolean) { + auto pd_arr = tensor_pd->mutable_data(place); + to = backend_->create_tensor(ngraph::element::boolean, sp, pd_arr); + } else { + PADDLE_THROW("Data type not handled in for var %s", var_name); + } + t_out.push_back(to); + } else { + PADDLE_THROW("Cannot find var or tensor with var name %s", var_name); + } + } + + backend_->call(ngraph_function_, t_out, t_in); +} // NgraphOperator::RunImpl } // namespace framework } // namespace paddle #endif diff --git a/paddle/fluid/framework/ngraph_operator.h b/paddle/fluid/framework/ngraph_operator.h index 0f655cef1dde624bcf4944b5c096279097e1c8ae..3ca023e11111c5b447b2cabbfb8bb29877297f65 100644 --- a/paddle/fluid/framework/ngraph_operator.h +++ b/paddle/fluid/framework/ngraph_operator.h @@ -17,24 +17,19 @@ limitations under the License. */ #ifdef PADDLE_WITH_NGRAPH #include -#include #include #include #include #include "paddle/fluid/framework/attribute.h" -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/ngraph_bridge.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/variant.h" -#include "ngraph/ngraph.hpp" +#include "ngraph/type/element_type.hpp" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index e0bf5ed999f580f217af285bf97d0bc0232f1ded..19e5c2c73eac74dee030a4f7820531800f737e4e 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -31,12 +31,6 @@ class InferShapeBase { virtual void operator()(InferShapeContext*) const = 0; }; -class EstimateFlopsBase { - public: - virtual ~EstimateFlopsBase() = default; - virtual size_t operator()(InferShapeContext*) const = 0; -}; - struct OpInfo { OpCreator creator_; GradOpMakerFN grad_op_maker_; @@ -44,7 +38,6 @@ struct OpInfo { OpAttrChecker* checker_{nullptr}; InferVarTypeFN infer_var_type_; InferShapeFN infer_shape_; - EstimateFlopsFN estimate_flops_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 8bfdf3891203823826fd5bf919c176011f22213c..c6f3254e9f7cedcf47be8ce8c3eecf4aa1b57add 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ 
-695,6 +695,12 @@ static void CheckTensorNANOrInf(const std::string& name, "Tensor %s contains NAN", name); } +void OperatorWithKernel::RuntimeInferShape(const Scope& scope, + const platform::Place& place) const { + RuntimeInferShapeContext infer_shape_ctx(*this, scope); + this->InferShape(&infer_shape_ctx); +} + void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5bd68f9ac2e1b30bc6ce3094960bb89842b99e01..0a6a28a5bce01d71cf56f25f5556033db94452c2 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -128,6 +128,8 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } + virtual void RuntimeInferShape(const Scope& scope, + const platform::Place& place) const {} protected: std::string type_; @@ -348,6 +350,9 @@ class OperatorWithKernel : public OperatorBase { OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); } + void RuntimeInferShape(const Scope& scope, + const platform::Place& place) const override; + protected: virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index cdc5fa6862e3b2a2784151302f15540a0e9db8ff..2de6233a9e0d320ec9a06d547db3575eb61925c0 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -54,7 +54,5 @@ using InferVarTypeFN = using InferShapeFN = std::function; -using EstimateFlopsFN = std::function; - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 8fb464c0f5443f116815b14324f6cbc966dc6482..ec93729cd2b379dc2ac39b51df6799b74c8529b6 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -79,6 +79,16 @@ link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/paddle/lib") +if (NOT WIN32) + set(NGRAPH_PATH "${PADDLE_LIB}/third_party/install/ngraph") + if(EXISTS ${NGRAPH_PATH}) + include(GNUInstallDirs) + include_directories("${NGRAPH_PATH}/include") + link_directories("${NGRAPH_PATH}/${CMAKE_INSTALL_LIBDIR}") + set(NGRAPH_LIB ${NGRAPH_PATH}/${CMAKE_INSTALL_LIBDIR}/libngraph${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() +endif() + add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) @@ -106,7 +116,7 @@ endif() if (NOT WIN32) set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(DEPS ${DEPS} - ${MATH_LIB} ${MKLDNN_LIB} + ${MATH_LIB} ${MKLDNN_LIB} ${NGRAPH_LIB} glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) else() diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h index 18acb735cecabd1e01f7821c880fd8ed5e52971f..8fceed3558b4357b7863368c18add329ea9922b3 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h @@ -36,12 +36,10 @@ class SequenceMaskOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); - auto maxlen = ctx->Attrs().Get("maxlen"); 
- if (maxlen > 0) { // We can only infershape when maxlen > 0 - auto dim = framework::vectorize2int(ctx->GetInputDim("X")); - dim.push_back(maxlen); - ctx->SetOutputDim("Y", framework::make_ddim(dim)); - } + int maxlen = ctx->Attrs().Get("maxlen"); + auto dim = framework::vectorize2int(ctx->GetInputDim("X")); + dim.push_back(maxlen > 0 ? maxlen : -1); + ctx->SetOutputDim("Y", framework::make_ddim(dim)); } }; diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc index 193de05422bb78572c0e5eaf4cd46744c3bcb113..14746fa95159d707be7c10c69a4ffc2211e17a93 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { using framework::Tensor; +const int kIgnoreIndex = -100; class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { public: @@ -100,6 +101,11 @@ class SigmoidCrossEntropyWithLogitsOpMaker AddOutput("Out", "(Tensor, default Tensor), a 2-D tensor with shape N x D " " of elementwise logistic losses."); + AddAttr("ignore_index", + "(int, default kIgnoreIndex), Specifies a target value that " + "is ignored and" + "does not contribute to the input gradient.") + .SetDefault(kIgnoreIndex); AddComment(R"DOC( SigmoidCrossEntropyWithLogits Operator. diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h index faef72866eb491887bbf221d32a8121b21fc3c66..b8731c232753074fa9e76b028485d3598c9a7295 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h @@ -15,33 +15,72 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/hostdevice.h" +#include "paddle/legacy/utils/Logging.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; +template +using EigenMatrix = framework::EigenMatrix; + +template +struct SigmoidCrossEntropyWithLogitsForward { + HOSTDEVICE SigmoidCrossEntropyWithLogitsForward(const int &ignore_index) + : ignore_index(ignore_index) {} + + HOSTDEVICE T operator()(const T &x, const T &label) const { + if (static_cast(label) == ignore_index) { + return static_cast(0.); + } + T term1 = (x > 0) ? 
x : 0;
+    T term2 = x * label;
+    T term3 = std::log(static_cast<T>(1) + std::exp(-(std::abs(x))));
+    return term1 - term2 + term3;
+  }
+
+  int ignore_index;
+};
+
+template <typename T>
+struct SigmoidCrossEntropyWithLogitsBackward {
+  HOSTDEVICE SigmoidCrossEntropyWithLogitsBackward(const int &ignore_index)
+      : ignore_index(ignore_index) {}
+
+  HOSTDEVICE T operator()(const T &x, const T &label) const {
+    if (static_cast<int>(label) == ignore_index) {
+      return static_cast<T>(0.);
+    }
+    T sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
+    return sigmoid_x - label;
+  }
+
+  int ignore_index;
+};
+
 // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
 template <typename DeviceContext, typename T>
 class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    const framework::Tensor *X = context.Input<framework::Tensor>("X");
-    const framework::Tensor *Labels = context.Input<framework::Tensor>("Label");
-    framework::Tensor *Out = context.Output<framework::Tensor>("Out");
+    const Tensor *X = context.Input<Tensor>("X");
+    const Tensor *Labels = context.Input<Tensor>("Label");
+    Tensor *Out = context.Output<Tensor>("Out");
     Out->mutable_data<T>(context.GetPlace());
+    int ignore_index = context.Attr<int>("ignore_index");

-    auto x = framework::EigenVector<T>::Flatten(*X);
-    auto labels = framework::EigenVector<T>::Flatten(*Labels);
-    auto out = framework::EigenVector<T>::Flatten(*Out);
+    auto x = EigenVector<T>::Flatten(*X);
+    auto labels = EigenVector<T>::Flatten(*Labels);
+    auto out = EigenVector<T>::Flatten(*Out);
     auto &place = *context.device_context<DeviceContext>().eigen_device();

-    // term1 = max(x, 0)
-    auto term1 = x.cwiseMax(static_cast<T>(0));
-    // term2 = x * labels
-    auto term2 = x * labels;
-    // term3 = log(1 + exp(-abs(x)))
-    auto term3 = (static_cast<T>(1) + (-(x.abs())).exp()).log();
-
-    out.device(place) = term1 - term2 + term3;
+    out.device(place) = x.binaryExpr(
+        labels, SigmoidCrossEntropyWithLogitsForward<T>(ignore_index));
   }
 };

@@ -50,23 +89,23 @@
 template <typename DeviceContext, typename T>
 class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    const framework::Tensor *X = context.Input<framework::Tensor>("X");
-    const framework::Tensor *Labels = context.Input<framework::Tensor>("Label");
-    const framework::Tensor *dOut =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    framework::Tensor *dX =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    const Tensor *X = context.Input<Tensor>("X");
+    const Tensor *Labels = context.Input<Tensor>("Label");
+    const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
     dX->mutable_data<T>(context.GetPlace());

-    auto x = framework::EigenVector<T>::Flatten(*X);
-    auto labels = framework::EigenVector<T>::Flatten(*Labels);
-    auto dout = framework::EigenVector<T>::Flatten(*dOut);
-    auto dx = framework::EigenVector<T>::Flatten(*dX);
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto x = EigenVector<T>::Flatten(*X);
+    auto labels = EigenVector<T>::Flatten(*Labels);
+    auto dout = EigenVector<T>::Flatten(*dOut);
+    auto dx = EigenVector<T>::Flatten(*dX);
     auto &place =
         *context.template device_context<DeviceContext>().eigen_device();

-    auto sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
-    dx.device(place) = dout * (sigmoid_x - labels);
+    auto diff = x.binaryExpr(labels, SigmoidCrossEntropyWithLogitsBackward<T>(
+                                         static_cast<int>(ignore_index)));
+    dx.device(place) = dout * diff;
   }
 };
diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc
new file mode 100644
index
0000000000000000000000000000000000000000..e7597f732430038a4a180297e730340d1bc47b8c
--- /dev/null
+++ b/paddle/fluid/operators/yolov3_loss_op.cc
@@ -0,0 +1,221 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/yolov3_loss_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class Yolov3LossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of Yolov3LossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("GTBox"),
+                   "Input(GTBox) of Yolov3LossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("GTLabel"),
+                   "Input(GTLabel) of Yolov3LossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
+                   "Output(Loss) of Yolov3LossOp should not be null.");
+
+    auto dim_x = ctx->GetInputDim("X");
+    auto dim_gtbox = ctx->GetInputDim("GTBox");
+    auto dim_gtlabel = ctx->GetInputDim("GTLabel");
+    auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
+    auto class_num = ctx->Attrs().Get<int>("class_num");
+    PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
+    PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
+                      "Input(X) dim[2] and dim[3] should be equal.");
+    PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
+                      "Input(X) dim[1] should be equal to (anchor_number * (5 "
+                      "+ class_num)).");
+    PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
+                      "Input(GTBox) should be a 3-D tensor");
+    PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 4");
+    PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2,
+                      "Input(GTLabel) should be a 2-D tensor");
+    PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0],
+                      "Input(GTBox) and Input(GTLabel) dim[0] should be same");
+    PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
+                      "Input(GTBox) and Input(GTLabel) dim[1] should be same");
+    PADDLE_ENFORCE_GT(anchors.size(), 0,
+                      "Attr(anchors) length should be greater than 0.");
+    PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
+                      "Attr(anchors) length should be an even integer.");
+    PADDLE_ENFORCE_GT(class_num, 0,
+                      "Attr(class_num) should be an integer greater than 0.");
+
+    std::vector<int64_t> dim_out({1});
+    ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input tensor of YOLO v3 loss operator, "
+             "This is a 4-D tensor with shape of [N, C, H, W]."
+ "H and W should be same, and the second dimention(C) stores" + "box locations, confidence score and classification one-hot" + "key of each anchor box"); + AddInput("GTBox", + "The input tensor of ground truth boxes, " + "This is a 3-D tensor with shape of [N, max_box_num, 5], " + "max_box_num is the max number of boxes in each image, " + "In the third dimention, stores x, y, w, h coordinates, " + "x, y is the center cordinate of boxes and w, h is the " + "width and height and x, y, w, h should be divided by " + "input image height to scale to [0, 1]."); + AddInput("GTLabel", + "The input tensor of ground truth label, " + "This is a 2-D tensor with shape of [N, max_box_num], " + "and each element shoudl be an integer to indicate the " + "box class id."); + AddOutput("Loss", + "The output yolov3 loss tensor, " + "This is a 1-D tensor with shape of [1]"); + + AddAttr("class_num", "The number of classes to predict."); + AddAttr>("anchors", + "The anchor width and height, " + "it will be parsed pair by pair."); + AddAttr("ignore_thresh", + "The ignore threshold to ignore confidence loss."); + AddAttr("loss_weight_xy", "The weight of x, y location loss.") + .SetDefault(1.0); + AddAttr("loss_weight_wh", "The weight of w, h location loss.") + .SetDefault(1.0); + AddAttr( + "loss_weight_conf_target", + "The weight of confidence score loss in locations with target object.") + .SetDefault(1.0); + AddAttr("loss_weight_conf_notarget", + "The weight of confidence score loss in locations without " + "target object.") + .SetDefault(1.0); + AddAttr("loss_weight_class", "The weight of classification loss.") + .SetDefault(1.0); + AddComment(R"DOC( + This operator generate yolov3 loss by given predict result and ground + truth boxes. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, specify the grid size, each grid point predict given + number boxes, this given number is specified by anchors, it should be + half anchors length, which following will be represented as S. In the + second dimention(the channel dimention), C should be S * (class_num + 5), + class_num is the box categoriy number of source dataset(such as coco), + so in the second dimention, stores 4 box location coordinates x, y, w, h + and confidence score of the box and class one-hot key of each anchor box. + + While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions + correspnd to: + + $$ + b_x = \sigma(t_x) + c_x + b_y = \sigma(t_y) + c_y + b_w = p_w e^{t_w} + b_h = p_h e^{t_h} + $$ + + While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$ + is specified by anchors. + + As for confidence score, it is the logistic regression value of IoU between + anchor boxes and ground truth boxes, the score of the anchor box which has + the max IoU should be 1, and if the anchor box has IoU bigger then ignore + thresh, the confidence score loss of this anchor box will be ignored. + + Therefore, the yolov3 loss consist of three major parts, box location loss, + confidence score loss, and classification loss. The MSE loss is used for + box location, and binary cross entropy loss is used for confidence score + loss and classification loss. + + Final loss will be represented as follow. 
+
+         $$
+         loss = loss\_weight\_{xy} * loss_{xy} + loss\_weight\_{wh} * loss_{wh}
+              + loss\_weight\_{conf\_target} * loss_{conf\_target}
+              + loss\_weight\_{conf\_notarget} * loss_{conf\_notarget}
+              + loss\_weight\_{class} * loss_{class}
+         $$
+         )DOC");
+  }
+};
+
+class Yolov3LossOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
+                   "Input(Loss@GRAD) should not be null");
+    auto dim_x = ctx->GetInputDim("X");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("yolov3_loss_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("GTBox", Input("GTBox"));
+    op->SetInput("GTLabel", Input("GTLabel"));
+    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("GTBox"), {});
+    op->SetOutput(framework::GradVarName("GTLabel"), {});
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
+                  ops::Yolov3LossGradMaker);
+REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
+REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
+                       ops::Yolov3LossKernel<double>);
+REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
+                       ops::Yolov3LossGradKernel<double>);
diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0bb285722ddedf721d98237760ec9868e2134442
--- /dev/null
+++ b/paddle/fluid/operators/yolov3_loss_op.h
@@ -0,0 +1,483 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
*/ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenTensor = framework::EigenTensor; +template +using EigenVector = framework::EigenVector; + +using Array5 = Eigen::DSizes; + +template +static inline bool isZero(T x) { + return fabs(x) < 1e-6; +} + +template +static inline T sigmoid(T x) { + return 1.0 / (exp(-1.0 * x) + 1.0); +} + +template +static inline T CalcMaskPointNum(const Tensor& mask) { + auto mask_t = EigenVector::Flatten(mask); + T count = 0.0; + for (int i = 0; i < mask_t.dimensions()[0]; i++) { + if (mask_t(i)) { + count += 1.0; + } + } + return count; +} + +template +static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + T error_sum = 0.0; + T points = 0.0; + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + error_sum += pow(x_t(i) - y_t(i), 2); + points += 1; + } + } + return (error_sum / points); +} + +template +static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, + const Tensor& mask, T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; + } + } +} + +template +static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + T error_sum = 0.0; + T points = 0.0; + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + error_sum += + -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i))); + points += 1; + } + } + return (error_sum / points); +} + +template +static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& mask, + T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; + } + } +} + +template +static void CalcPredResult(const Tensor& input, Tensor* pred_conf, + Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, const int anchor_num, + const int class_num) { + const int n = input.dims()[0]; + const int h = input.dims()[2]; + const int w = input.dims()[3]; + const int box_attr_num = 5 + class_num; + + auto input_t = EigenTensor::From(input); + auto pred_conf_t = EigenTensor::From(*pred_conf); + auto pred_class_t = EigenTensor::From(*pred_class); + auto pred_x_t = EigenTensor::From(*pred_x); + auto pred_y_t = EigenTensor::From(*pred_y); + auto pred_w_t = EigenTensor::From(*pred_w); + auto pred_h_t = EigenTensor::From(*pred_h); + + for (int i = 0; i < n; i++) { + for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + for (int j = 0; j < h; j++) { + for (int k = 0; k < w; k++) { + pred_x_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx, j, k)); + pred_y_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); + pred_w_t(i, an_idx, 
j, k) = + input_t(i, box_attr_num * an_idx + 2, j, k); + pred_h_t(i, an_idx, j, k) = + input_t(i, box_attr_num * an_idx + 3, j, k); + + pred_conf_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); + + for (int c = 0; c < class_num; c++) { + pred_class_t(i, an_idx, j, k, c) = + sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + } + } + } + } + } +} + +template +static T CalcBoxIoU(std::vector box1, std::vector box2) { + T b1_x1 = box1[0] - box1[2] / 2; + T b1_x2 = box1[0] + box1[2] / 2; + T b1_y1 = box1[1] - box1[3] / 2; + T b1_y2 = box1[1] + box1[3] / 2; + T b2_x1 = box2[0] - box2[2] / 2; + T b2_x2 = box2[0] + box2[2] / 2; + T b2_y1 = box2[1] - box2[3] / 2; + T b2_y2 = box2[1] + box2[3] / 2; + + T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1); + T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1); + + T inter_rect_x1 = std::max(b1_x1, b2_x1); + T inter_rect_y1 = std::max(b1_y1, b2_y1); + T inter_rect_x2 = std::min(b1_x2, b2_x2); + T inter_rect_y2 = std::min(b1_y2, b2_y2); + T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast(0.0)) * + std::max(inter_rect_y2 - inter_rect_y1, static_cast(0.0)); + + return inter_area / (b1_area + b2_area - inter_area); +} + +template +static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, + const float ignore_thresh, std::vector anchors, + const int grid_size, Tensor* obj_mask, + Tensor* noobj_mask, Tensor* tx, Tensor* ty, + Tensor* tw, Tensor* th, Tensor* tconf, + Tensor* tclass) { + const int n = gt_box.dims()[0]; + const int b = gt_box.dims()[1]; + const int anchor_num = anchors.size() / 2; + auto gt_box_t = EigenTensor::From(gt_box); + auto gt_label_t = EigenTensor::From(gt_label); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); + auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto tx_t = EigenTensor::From(*tx).setConstant(0.0); + auto ty_t = EigenTensor::From(*ty).setConstant(0.0); + auto tw_t = EigenTensor::From(*tw).setConstant(0.0); + auto th_t = EigenTensor::From(*th).setConstant(0.0); + auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); + auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < b; j++) { + if (isZero(gt_box_t(i, j, 0)) && isZero(gt_box_t(i, j, 1)) && + isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + continue; + } + + int cur_label = gt_label_t(i, j); + T gx = gt_box_t(i, j, 0) * grid_size; + T gy = gt_box_t(i, j, 1) * grid_size; + T gw = gt_box_t(i, j, 2) * grid_size; + T gh = gt_box_t(i, j, 3) * grid_size; + int gi = static_cast(gx); + int gj = static_cast(gy); + + T max_iou = static_cast(0); + T iou; + int best_an_index = -1; + std::vector gt_box_shape({0, 0, gw, gh}); + for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), + static_cast(anchors[2 * an_idx + 1])}); + iou = CalcBoxIoU(gt_box_shape, anchor_shape); + if (iou > max_iou) { + max_iou = iou; + best_an_index = an_idx; + } + if (iou > ignore_thresh) { + noobj_mask_t(i, an_idx, gj, gi) = 0; + } + } + obj_mask_t(i, best_an_index, gj, gi) = 1; + noobj_mask_t(i, best_an_index, gj, gi) = 0; + tx_t(i, best_an_index, gj, gi) = gx - gi; + ty_t(i, best_an_index, gj, gi) = gy - gj; + tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); + th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); + tclass_t(i, best_an_index, gj, gi, cur_label) = 1; + tconf_t(i, best_an_index, gj, gi) = 1; + } + } +} + 
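// [Editor's note, worked example; not part of this diff] In the anchor
// matching above, ground truth boxes and anchors are compared as centered
// (0, 0, w, h) shapes, so only width and height drive the IoU. For a ground
// truth box of gw = 3, gh = 2 grid units against a 4 x 2 anchor:
//
//   std::vector<float> gt_shape{0.0f, 0.0f, 3.0f, 2.0f};
//   std::vector<float> anchor_shape{0.0f, 0.0f, 4.0f, 2.0f};
//   float iou = CalcBoxIoU<float>(gt_shape, anchor_shape);
//   // intersection = 3 * 2 = 6, union = 6 + 8 - 6 = 8, so iou = 0.75
//
// The best-matching anchor becomes responsible for the ground truth box
// (obj_mask set to 1), while every anchor whose IoU exceeds ignore_thresh is
// excluded from the no-object confidence loss (noobj_mask set to 0).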
+static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, + const Tensor& obj_mask) { + const int n = obj_mask_expand->dims()[0]; + const int an_num = obj_mask_expand->dims()[1]; + const int h = obj_mask_expand->dims()[2]; + const int w = obj_mask_expand->dims()[3]; + const int class_num = obj_mask_expand->dims()[4]; + auto obj_mask_expand_t = EigenTensor::From(*obj_mask_expand); + auto obj_mask_t = EigenTensor::From(obj_mask); + + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); +} + +template +static void AddAllGradToInputGrad( + Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, + const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, + const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, + const Tensor& grad_conf_target, const Tensor& grad_conf_notarget, + const Tensor& grad_class, const int class_num, const float loss_weight_xy, + const float loss_weight_wh, const float loss_weight_conf_target, + const float loss_weight_conf_notarget, const float loss_weight_class) { + const int n = pred_x.dims()[0]; + const int an_num = pred_x.dims()[1]; + const int h = pred_x.dims()[2]; + const int w = pred_x.dims()[3]; + const int attr_num = class_num + 5; + auto grad_t = EigenTensor::From(*grad).setConstant(0.0); + auto pred_x_t = EigenTensor::From(pred_x); + auto pred_y_t = EigenTensor::From(pred_y); + auto pred_conf_t = EigenTensor::From(pred_conf); + auto pred_class_t = EigenTensor::From(pred_class); + auto grad_x_t = EigenTensor::From(grad_x); + auto grad_y_t = EigenTensor::From(grad_y); + auto grad_w_t = EigenTensor::From(grad_w); + auto grad_h_t = EigenTensor::From(grad_h); + auto grad_conf_target_t = EigenTensor::From(grad_conf_target); + auto grad_conf_notarget_t = EigenTensor::From(grad_conf_notarget); + auto grad_class_t = EigenTensor::From(grad_class); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + grad_t(i, j * attr_num, k, l) = + grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * + (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy; + grad_t(i, j * attr_num + 1, k, l) = + grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * + (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy; + grad_t(i, j * attr_num + 2, k, l) = + grad_w_t(i, j, k, l) * loss * loss_weight_wh; + grad_t(i, j * attr_num + 3, k, l) = + grad_h_t(i, j, k, l) * loss * loss_weight_wh; + grad_t(i, j * attr_num + 4, k, l) = + grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target; + grad_t(i, j * attr_num + 4, k, l) += + grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss * + loss_weight_conf_notarget; + + for (int c = 0; c < class_num; c++) { + grad_t(i, j * attr_num + 5 + c, k, l) = + grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * + (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class; + } + } + } + } + } +} + +template +class Yolov3LossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* gt_box = ctx.Input("GTBox"); + auto* gt_label = ctx.Input("GTLabel"); + auto* loss = ctx.Output("Loss"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float ignore_thresh = ctx.Attr("ignore_thresh"); + float loss_weight_xy = ctx.Attr("loss_weight_xy"); + float 
+template <typename T>
+class Yolov3LossKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* gt_box = ctx.Input<Tensor>("GTBox");
+    auto* gt_label = ctx.Input<Tensor>("GTLabel");
+    auto* loss = ctx.Output<Tensor>("Loss");
+    auto anchors = ctx.Attr<std::vector<int>>("anchors");
+    int class_num = ctx.Attr<int>("class_num");
+    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
+    float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
+    float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
+    float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
+    float loss_weight_conf_notarget =
+        ctx.Attr<float>("loss_weight_conf_notarget");
+    float loss_weight_class = ctx.Attr<float>("loss_weight_class");
+
+    const int n = input->dims()[0];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int an_num = anchors.size() / 2;
+
+    Tensor pred_x, pred_y, pred_w, pred_h;
+    Tensor pred_conf, pred_class;
+    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
+                      &pred_w, &pred_h, an_num, class_num);
+
+    Tensor obj_mask, noobj_mask;
+    Tensor tx, ty, tw, th, tconf, tclass;
+    obj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    noobj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, h,
+                       &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf,
+                       &tclass);
+
+    Tensor obj_mask_expand;
+    obj_mask_expand.mutable_data<T>({n, an_num, h, w, class_num},
+                                    ctx.GetPlace());
+    ExpandObjMaskByClassNum<T>(&obj_mask_expand, obj_mask);
+
+    T loss_x = CalcMSEWithMask<T>(pred_x, tx, obj_mask);
+    T loss_y = CalcMSEWithMask<T>(pred_y, ty, obj_mask);
+    T loss_w = CalcMSEWithMask<T>(pred_w, tw, obj_mask);
+    T loss_h = CalcMSEWithMask<T>(pred_h, th, obj_mask);
+    T loss_conf_target = CalcBCEWithMask<T>(pred_conf, tconf, obj_mask);
+    T loss_conf_notarget = CalcBCEWithMask<T>(pred_conf, tconf, noobj_mask);
+    T loss_class = CalcBCEWithMask<T>(pred_class, tclass, obj_mask_expand);
+
+    auto* loss_data = loss->mutable_data<T>({1}, ctx.GetPlace());
+    loss_data[0] = loss_weight_xy * (loss_x + loss_y) +
+                   loss_weight_wh * (loss_w + loss_h) +
+                   loss_weight_conf_target * loss_conf_target +
+                   loss_weight_conf_notarget * loss_conf_notarget +
+                   loss_weight_class * loss_class;
+  }
+};
+
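+// Yolov3LossKernel above combines the masked terms into one scalar:
+//   loss = loss_weight_xy * (loss_x + loss_y)
+//        + loss_weight_wh * (loss_w + loss_h)
+//        + loss_weight_conf_target * loss_conf_target
+//        + loss_weight_conf_notarget * loss_conf_notarget
+//        + loss_weight_class * loss_class
+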
+template <typename T>
+class Yolov3LossGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* gt_box = ctx.Input<Tensor>("GTBox");
+    auto* gt_label = ctx.Input<Tensor>("GTLabel");
+    auto anchors = ctx.Attr<std::vector<int>>("anchors");
+    int class_num = ctx.Attr<int>("class_num");
+    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    const T loss = output_grad->data<T>()[0];
+    float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
+    float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
+    float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
+    float loss_weight_conf_notarget =
+        ctx.Attr<float>("loss_weight_conf_notarget");
+    float loss_weight_class = ctx.Attr<float>("loss_weight_class");
+
+    const int n = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int an_num = anchors.size() / 2;
+
+    Tensor pred_x, pred_y, pred_w, pred_h;
+    Tensor pred_conf, pred_class;
+    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
+                      &pred_w, &pred_h, an_num, class_num);
+
+    Tensor obj_mask, noobj_mask;
+    Tensor tx, ty, tw, th, tconf, tclass;
+    obj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    noobj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, h,
+                       &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf,
+                       &tclass);
+
+    Tensor obj_mask_expand;
+    obj_mask_expand.mutable_data<T>({n, an_num, h, w, class_num},
+                                    ctx.GetPlace());
+    ExpandObjMaskByClassNum<T>(&obj_mask_expand, obj_mask);
+
+    Tensor grad_x, grad_y, grad_w, grad_h;
+    Tensor grad_conf_target, grad_conf_notarget, grad_class;
+    grad_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_conf_target.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_conf_notarget.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    T obj_mf = CalcMaskPointNum<T>(obj_mask);
+    T noobj_mf = CalcMaskPointNum<T>(noobj_mask);
+    T obj_expand_mf = CalcMaskPointNum<T>(obj_mask_expand);
+    CalcMSEGradWithMask<T>(&grad_x, pred_x, tx, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_y, pred_y, ty, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_w, pred_w, tw, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_h, pred_h, th, obj_mask, obj_mf);
+    CalcBCEGradWithMask<T>(&grad_conf_target, pred_conf, tconf, obj_mask,
+                           obj_mf);
+    CalcBCEGradWithMask<T>(&grad_conf_notarget, pred_conf, tconf, noobj_mask,
+                           noobj_mf);
+    CalcBCEGradWithMask<T>(&grad_class, pred_class, tclass, obj_mask_expand,
+                           obj_expand_mf);
+
+    input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    AddAllGradToInputGrad<T>(
+        input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x,
+        grad_y, grad_w, grad_h, grad_conf_target, grad_conf_notarget,
+        grad_class, class_num, loss_weight_xy, loss_weight_wh,
+        loss_weight_conf_target, loss_weight_conf_notarget, loss_weight_class);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
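For reference, the channel layout decoded by CalcPredResult above (and re-implemented by the Python reference in test_yolov3_loss_op.py later in this patch) can be sketched in NumPy; this is an illustration only, and decode_predictions is an illustrative name, not part of the patch:

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def decode_predictions(x, an_num, class_num):
        # x: [N, an_num * (5 + class_num), H, W], the op's expected input.
        n, _, h, w = x.shape
        x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose(
            (0, 1, 3, 4, 2))
        pred_x = sigmoid(x[..., 0])     # center offsets go through a sigmoid
        pred_y = sigmoid(x[..., 1])
        pred_w = x[..., 2]              # width/height stay in log space
        pred_h = x[..., 3]
        pred_conf = sigmoid(x[..., 4])  # objectness confidence
        pred_cls = sigmoid(x[..., 5:])  # per-class probabilities
        return pred_x, pred_y, pred_w, pred_h, pred_conf, pred_cls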
"cp36-cp36m" ]; then @@ -449,7 +451,7 @@ EOF elif [ "$1" == "cp37-cp37m" ]; then pip3.7 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl fi - + if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]] ; then paddle version fi diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 1738afe93e99f1de28bec2fb23be8e1a309d9288..5b8ff0514e7c87df786e71b5cea2d6ebfbaad6d4 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -134,12 +134,12 @@ class GradientClipByValue(BaseGradientClipAttr): Examples: .. code-block:: python - w_param_attrs = ParamAttr(name=None, - initializer=UniformInitializer(low=-1.0, high=1.0, seed=0), + w_param_attrs = fluid.ParamAttr(name=None, + initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0), learning_rate=1.0, - regularizer=L1Decay(1.0), + regularizer=fluid.regularizer.L1Decay(1.0), trainable=True, - clip=GradientClipByValue(-1.0, 1.0)) + clip=fluid.clip.GradientClipByValue(-1.0, 1.0)) y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs) """ @@ -185,12 +185,12 @@ class GradientClipByNorm(BaseGradientClipAttr): Examples: .. code-block:: python - w_param_attrs = ParamAttr(name=None, - initializer=UniformInitializer(low=-1.0, high=1.0, seed=0), + w_param_attrs = flui.ParamAttr(name=None, + initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0), learning_rate=1.0, - regularizer=L1Decay(1.0), + regularizer=fluid.regularizer.L1Decay(1.0), trainable=True, - clip=GradientClipByNorm(clip_norm=2.0)) + clip=fluid.clip.GradientClipByNorm(clip_norm=2.0)) y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs) """ diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 42c2484b284844a1f1acf53f79296e13da72676a..f2886090d75f87654b33cf7aa6f98ebf6f2e27d1 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -20,7 +20,7 @@ import six from .framework import Program, default_main_program, Variable from . import core -__all__ = ['Executor', 'global_scope', 'scope_guard', '_switch_scope'] +__all__ = ['Executor', 'global_scope', 'scope_guard'] g_scope = core.Scope() @@ -407,16 +407,17 @@ class Executor(object): Examples: - >>> data = layers.data(name='X', shape=[1], dtype='float32') - >>> hidden = layers.fc(input=data, size=10) - >>> layers.assign(hidden, out) - >>> loss = layers.mean(out) + >>> data = fluid.layers.data(name='X', shape=[1], dtype='float32') + >>> out = fluid.layers.create_tensor(dtype='float32') + >>> hidden = fluid.layers.fc(input=data, size=10) + >>> fluid.layers.assign(hidden,out) + >>> loss = fluid.layers.mean(out) >>> adam = fluid.optimizer.Adam() - >>> adam.minimize(loss) + >>> adam.minimize(loss) >>> cpu = core.CPUPlace() - >>> exe = Executor(cpu) - >>> exe.run(default_startup_program()) + >>> exe = fluid.Executor(cpu) + >>> exe.run(fluid.default_startup_program()) >>> x = numpy.random.random(size=(10, 1)).astype('float32') >>> outs = exe.run( diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b991187d424108db176ebd6996d7d161f11dcd3d..b156db53d2928daefed0959fc3e0731709855343 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -89,12 +89,13 @@ def name_scope(prefix=None): Examples: .. code-block:: python + with name_scope("encoder"): ... with name_scope("decoder"): ... - with name_scope("attention"): - ... + with name_scope("attention"): + ... """ # TODO(panyx0718): Only [0-9a-z]. 
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 4843af8340310e0f47964d41708b13216fcd2161..ce731f39ea099a4d8948812989ad19b3cce119ff 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 from .layer_function_generator import generate_layer_fn
 from .layer_function_generator import autodoc, templatedoc
 from ..layer_helper import LayerHelper
+from ..framework import Variable
 from . import tensor
 from . import nn
 from . import ops
@@ -46,6 +47,7 @@ __all__ = [
     'iou_similarity',
     'box_coder',
     'polygon_box_transform',
+    'yolov3_loss',
 ]
 
 
@@ -401,6 +403,113 @@ def polygon_box_transform(input, name=None):
     return output
 
 
+@templatedoc(op_type="yolov3_loss")
+def yolov3_loss(x,
+                gtbox,
+                gtlabel,
+                anchors,
+                class_num,
+                ignore_thresh,
+                loss_weight_xy=None,
+                loss_weight_wh=None,
+                loss_weight_conf_target=None,
+                loss_weight_conf_notarget=None,
+                loss_weight_class=None,
+                name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Variable): ${x_comment}
+        gtbox (Variable): ground truth boxes, should be in shape of [N, B, 4],
+                          in the third dimension, x, y, w, h should be stored,
+                          and x, y, w, h should be relative values of the
+                          input image. N is the batch number and B is the max
+                          box number in an image.
+        gtlabel (Variable): class id of ground truth boxes, should be in shape
+                            of [N, B].
+        anchors (list|tuple): ${anchors_comment}
+        class_num (int): ${class_num_comment}
+        ignore_thresh (float): ${ignore_thresh_comment}
+        loss_weight_xy (float|None): ${loss_weight_xy_comment}
+        loss_weight_wh (float|None): ${loss_weight_wh_comment}
+        loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment}
+        loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment}
+        loss_weight_class (float|None): ${loss_weight_class_comment}
+        name (string): the name of yolov3 loss
+
+    Returns:
+        Variable: A 1-D tensor with shape [1], the value of yolov3 loss
+
+    Raises:
+        TypeError: Input x of yolov3_loss must be Variable
+        TypeError: Input gtbox of yolov3_loss must be Variable
+        TypeError: Input gtlabel of yolov3_loss must be Variable
+        TypeError: Attr anchors of yolov3_loss must be list or tuple
+        TypeError: Attr class_num of yolov3_loss must be an integer
+        TypeError: Attr ignore_thresh of yolov3_loss must be a float number
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
+            gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')
+            gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')
+            anchors = [10, 13, 16, 30, 33, 23]
+            loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel,
+                                            class_num=80, anchors=anchors,
+                                            ignore_thresh=0.5)
+    """
+    helper = LayerHelper('yolov3_loss', **locals())
+
+    if not isinstance(x, Variable):
+        raise TypeError("Input x of yolov3_loss must be Variable")
+    if not isinstance(gtbox, Variable):
+        raise TypeError("Input gtbox of yolov3_loss must be Variable")
+    if not isinstance(gtlabel, Variable):
+        raise TypeError("Input gtlabel of yolov3_loss must be Variable")
+    if not isinstance(anchors, list) and not isinstance(anchors, tuple):
+        raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
+    if not isinstance(class_num, int):
+        raise TypeError("Attr class_num of yolov3_loss must be an integer")
+    if not isinstance(ignore_thresh, float):
+        raise TypeError(
+            "Attr ignore_thresh of yolov3_loss must be a float number")
+
+    if name is None:
+        loss = helper.create_variable_for_type_inference(dtype=x.dtype)
+    else:
+        loss = helper.create_variable(
+            name=name, dtype=x.dtype, persistable=False)
+
+    attrs = {
+        "anchors": anchors,
+        "class_num": class_num,
+        "ignore_thresh": ignore_thresh,
+    }
+
+    if loss_weight_xy is not None and isinstance(loss_weight_xy, float):
+        attrs['loss_weight_xy'] = loss_weight_xy
+    if loss_weight_wh is not None and isinstance(loss_weight_wh, float):
+        attrs['loss_weight_wh'] = loss_weight_wh
+    if loss_weight_conf_target is not None and isinstance(
+            loss_weight_conf_target, float):
+        attrs['loss_weight_conf_target'] = loss_weight_conf_target
+    if loss_weight_conf_notarget is not None and isinstance(
+            loss_weight_conf_notarget, float):
+        attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget
+    if loss_weight_class is not None and isinstance(loss_weight_class, float):
+        attrs['loss_weight_class'] = loss_weight_class
+
+    helper.append_op(
+        type='yolov3_loss',
+        inputs={"X": x,
+                "GTBox": gtbox,
+                "GTLabel": gtlabel},
+        outputs={'Loss': loss},
+        attrs=attrs)
+    return loss
+
+
 @templatedoc()
 def detection_map(detect_res,
                   label,
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 3f47053961bcc41b82f1b6776e9365166e78ddbf..42f4959a83fe113d6cbbe0db355249a9c203d602 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -943,7 +943,18 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
 
 def shuffle(reader, buffer_size):
     """
-    Shuffle the reader.
+    Creates a data reader whose data output is shuffled.
+    Output from the iterator created by the original reader will be
+    buffered into a shuffle buffer, and then shuffled. The size of the
+    shuffle buffer is determined by the argument buffer_size.
+
+    Args:
+        param reader: the original reader whose output will be shuffled.
+        type reader: callable
+        param buffer_size: shuffle buffer size.
+        type buffer_size: int
+        return: the new reader whose output is shuffled.
+ rtype: callable """ return __create_unshared_decorated_reader__( 'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)}) diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index 149224bb68ac869dec14ac9f953f0072bd24c7e2..dde05189722fef77e03a1c2d8f3cbae44a3e8245 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -308,13 +308,9 @@ def piecewise_decay(boundaries, values): def append_LARS(params_grads, learning_rate, weight_decay): - """Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for - each layer. - - ```python - learning_rate *= local_gw_ratio * sqrt(sumsq(param)) - / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param))) - ``` + """ + Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for + each layer. Args: learning_rate: A learning rate Variable. This @@ -323,6 +319,11 @@ def append_LARS(params_grads, learning_rate, weight_decay): Returns: The decayed learning rate + Examples: + .. code-block:: python + + learning_rate *= local_gw_ratio * sqrt(sumsq(param)) + / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param))) """ def _balanced_weight(param_norm, grad_norm): diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index dbc39afccbbbefa88873329e7c6790fe26dcc11e..28b8ae895afb57e0b3461829e6aea175c7c90e9d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -172,6 +172,8 @@ __all__ = [ 'lstm', ] +kIgnoreIndex = -100 + def fc(input, size, @@ -926,7 +928,7 @@ def dynamic_gru(input, emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) hidden_dim = 512 x = fluid.layers.fc(input=emb, size=hidden_dim * 3) - hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim) + hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim) """ helper = LayerHelper('gru', **locals()) @@ -1267,7 +1269,7 @@ def dropout(x, return out -def cross_entropy(input, label, soft_label=False, ignore_index=-100): +def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): """ **Cross Entropy Layer** @@ -1314,7 +1316,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=-100): labels. Default: `False`. ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. Only valid - if soft_label is set to False. Default: -100 + if soft_label is set to False. Default: kIgnoreIndex Returns: A 2-D tensor with shape [N x 1], the cross entropy loss. @@ -3584,6 +3586,7 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None): Examples: .. code-block:: python + # Suppose `ids` and `scores` are LodTensorArray variables reserving # the selected ids and scores of all steps finished_ids, finished_scores = layers.beam_search_decode( @@ -5081,7 +5084,7 @@ def im2sequence(input, output.lod = [[4, 4]] - Examples: + Examples: .. code-block:: python @@ -5185,7 +5188,7 @@ def multiplex(inputs, index): def softmax_with_cross_entropy(logits, label, soft_label=False, - ignore_index=-100, + ignore_index=kIgnoreIndex, numeric_stable_mode=False, return_softmax=False): """ @@ -5243,7 +5246,7 @@ def softmax_with_cross_entropy(logits, labels as soft labels. By default, `soft_label` is set to False. ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. Only valid - if soft_label is set to False. 
Default: -100 + if soft_label is set to False. Default: kIgnoreIndex numeric_stable_mode (bool): A flag to indicate whether to use a more numerically stable algorithm. Only valid when soft_label is False and GPU is used. @@ -5868,24 +5871,23 @@ def pad_constant_like(x, y, pad_value=0., name=None): [[38, 39, 40]], [[41, 42, 43]]]] Y.shape = (1, 3, 1, 3) + And + pad_value = -1, - And - pad_value = -1, - - Return: - Out = [[[[35, 36, 37], - [-1, -1, -1]], - [[38, 39, 40], - [-1, -1, -1]], - [[41, 42, 43], - [-1, -1, -1]]], - [[[-1, -1, -1], - [-1, -1, -1]], - [[-1, -1, -1], - [-1, -1, -1]], - [[-1, -1, -1], - [-1, -1, -1]]]] - Out.shape = (2, 3, 2, 3) + Return: + Out = [[[[35, 36, 37], + [-1, -1, -1]], + [[38, 39, 40], + [-1, -1, -1]], + [[41, 42, 43], + [-1, -1, -1]]], + [[[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]], + [[-1, -1, -1], + [-1, -1, -1]]]] + Out.shape = (2, 3, 2, 3) Args: x (Variable): The input tensor variable. @@ -6124,6 +6126,7 @@ def image_resize(input, Supporting resample methods: 'BILINEAR' : Bilinear interpolation + 'NEAREST' : Nearest neighbor interpolation Args: @@ -6779,7 +6782,7 @@ def crop(x, shape=None, offsets=None, name=None): # or z = fluid.layers.data(name="z", shape=[3, 5], dtype="float32") - crop = fluid.layers.crop(z, shape=[2, 3]) + crop = fluid.layers.crop(z, shape=[-1, 2, 3]) """ helper = LayerHelper('crop', **locals()) @@ -7060,39 +7063,40 @@ def pad2d(input, than height-1. And the width dimension has the same condition. Example: + .. code-block:: text - Given that X is a channel of image from input: + Given that X is a channel of image from input: - X = [[1, 2, 3], - [4, 5, 6]] + X = [[1, 2, 3], + [4, 5, 6]] - Case 0: + Case 0: - paddings = [0, 1, 2, 3], - mode = 'constant' - pad_value = 0 + paddings = [0, 1, 2, 3], + mode = 'constant' + pad_value = 0 - Out = [[0, 0, 1, 2, 3, 0, 0, 0] - [0, 0, 4, 5, 6, 0, 0, 0] - [0, 0, 0, 0, 0, 0, 0, 0]] + Out = [[0, 0, 1, 2, 3, 0, 0, 0] + [0, 0, 4, 5, 6, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0]] - Case 1: + Case 1: - paddings = [0, 1, 2, 1], - mode = 'reflect' + paddings = [0, 1, 2, 1], + mode = 'reflect' - Out = [[3, 2, 1, 2, 3, 2] - [6, 5, 4, 5, 6, 5] - [3, 2, 1, 2, 3, 2]] + Out = [[3, 2, 1, 2, 3, 2] + [6, 5, 4, 5, 6, 5] + [3, 2, 1, 2, 3, 2]] - Case 2: + Case 2: - paddings = [0, 1, 2, 1], - mode = 'edge' + paddings = [0, 1, 2, 1], + mode = 'edge' - Out = [[1, 1, 1, 2, 3, 3] - [4, 4, 4, 5, 6, 6] - [4, 4, 4, 5, 6, 6]] + Out = [[1, 1, 1, 2, 3, 3] + [4, 4, 4, 5, 6, 6] + [4, 4, 4, 5, 6, 6]] Args: @@ -7330,13 +7334,13 @@ def prelu(x, mode, param_attr=None, name=None): Args: x (Variable): The input tensor. param_attr(ParamAttr|None): The parameter attribute for the learnable - weight (alpha). + weight (alpha). mode (string): The mode for weight sharing. It supports all, channel - and element. all: all elements share same weight - channel:elements in a channel share same weight - element:each element has a weight + and element. all: all elements share same weight + channel:elements in a channel share same weight + element:each element has a weight name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + will be named automatically. Returns: Variable: The output tensor with the same shape as input. 
@@ -8415,13 +8419,17 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
 
 
 @templatedoc()
-def sigmoid_cross_entropy_with_logits(x, label, name=None):
+def sigmoid_cross_entropy_with_logits(x,
+                                      label,
+                                      ignore_index=kIgnoreIndex,
+                                      name=None):
     """
     ${comment}
 
     Args:
         x(${x_type}): ${x_comment}
         label(${label_type}): ${label_comment}
+        ignore_index(int): ${ignore_index_comment}
         name(basestring|None): Name of the output.
 
     Returns:
@@ -8440,7 +8448,7 @@ def sigmoid_cross_entropy_with_logits(x, label, name=None):
         type="sigmoid_cross_entropy_with_logits",
         inputs={"X": x,
                 "Label": label},
-        attrs={},
+        attrs={"ignore_index": ignore_index},
         outputs={"Out": out})
     return out
 
diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py
index 829154f1b23d6e49bf963762be6b6488c98ec94a..85af8fea13d5b9a1e22014fbd727e1baed3247be 100644
--- a/python/paddle/fluid/metrics.py
+++ b/python/paddle/fluid/metrics.py
@@ -222,13 +222,13 @@ class Precision(MetricBase):
     Examples:
         .. code-block:: python
 
-        metric = fluid.metrics.Precision()
-        for pass in range(PASSES):
-            metric.reset()
-            for data in train_reader():
-                loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
-                metric.update(preds=preds, labels=labels)
-            numpy_precision = metric.eval()
+            metric = fluid.metrics.Precision()
+            for pass_id in range(PASSES):
+                metric.reset()
+                for data in train_reader():
+                    loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
+                    metric.update(preds=preds, labels=labels)
+                numpy_precision = metric.eval()
     """
 
     def __init__(self, name=None):
@@ -267,13 +267,13 @@ class Recall(MetricBase):
     Examples:
         .. code-block:: python
 
-        metric = fluid.metrics.Recall()
-        for pass in range(PASSES):
-            metric.reset()
-            for data in train_reader():
-                loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
-                metric.update(preds=preds, labels=labels)
-            numpy_recall = metric.eval()
+            metric = fluid.metrics.Recall()
+            for pass_id in range(PASSES):
+                metric.reset()
+                for data in train_reader():
+                    loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
+                    metric.update(preds=preds, labels=labels)
+                numpy_recall = metric.eval()
     """
 
     def __init__(self, name=None):
@@ -449,8 +449,9 @@ class EditDistance(MetricBase):
             distance_evaluator.update(distances, seq_num)
             distance, instance_error = distance_evaluator.eval()
 
-    In the above example:
+        In the above example:
             'distance' is the average of the edit distance in a pass.
+            'instance_error' is the instance error rate in a pass.
""" diff --git a/python/paddle/fluid/param_attr.py b/python/paddle/fluid/param_attr.py index a51607bfdb1dde3d25f490770cc2ba368ceb27ff..38ddf93198d7c58382e36a5b7af488f56e6f9878 100644 --- a/python/paddle/fluid/param_attr.py +++ b/python/paddle/fluid/param_attr.py @@ -50,8 +50,9 @@ class ParamAttr(object): w_param_attrs = fluid.ParamAttr(name="fc_weight", learning_rate=0.5, - regularizer=fluid.L2Decay(1.0), + regularizer=fluid.regularizer.L2Decay(1.0), trainable=True) + x = fluid.layers.data(name='X', shape=[1], dtype='float32') y_predict = fluid.layers.fc(input=x, size=10, param_attr=w_param_attrs) """ diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index a2eca5541a152ca99804a7f87c9b0bc3d12d4eee..d99eaa0634f93dcd16dd80ae172f11e8090a2623 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -388,5 +388,18 @@ class TestGenerateProposals(unittest.TestCase): print(rpn_rois.shape) +class TestYoloDetection(unittest.TestCase): + def test_yolov3_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') + gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') + loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, + 0.5) + + self.assertIsNotNone(loss) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 7f477031ddef6b459dcda74a18d0831763f4aeca..be51fb06a37a376f6f410336184c95981ded35dc 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -170,9 +170,10 @@ class TestBook(unittest.TestCase): with program_guard(program): dat = layers.data(name='data', shape=[10], dtype='float32') lbl = layers.data(name='label', shape=[10], dtype='float32') + ignore_index = -1 self.assertIsNotNone( layers.sigmoid_cross_entropy_with_logits( - x=dat, label=lbl)) + x=dat, label=lbl, ignore_index=ignore_index)) print(str(program)) def test_hsigmoid(self): diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py index 97ff203499c0bf223930c904de46e1abdd902799..41797a241cab9f2b3bc4b492a1c4b6db89ac2948 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py @@ -56,6 +56,40 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): """Test sigmoid_cross_entropy_with_logit_op with probabalistic label """ + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = 64 + num_classes = 20 + ignore_index = -1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32")), + 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) + .astype("float32") + } + self.attrs = {'ignore_index': ignore_index, } + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == 
ignore_index)] = 0
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
+    """Test sigmoid_cross_entropy_with_logit_op with probabilistic label
+    """
+
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
         batch_size = 64
diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..544fe4b4f81909b69a05d9751316e3d3137fdc45
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+from paddle.fluid import core
+
+
+def sigmoid(x):
+    return 1.0 / (1.0 + np.exp(-1.0 * x))
+
+
+def mse(x, y, num):
+    return ((y - x)**2).sum() / num
+
+
+def bce(x, y, mask):
+    x = x.reshape((-1))
+    y = y.reshape((-1))
+    mask = mask.reshape((-1))
+
+    error_sum = 0.0
+    count = 0
+    for i in range(x.shape[0]):
+        if mask[i] > 0:
+            error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i])
+            count += 1
+    return error_sum / (-1.0 * count)
+
+
+def box_iou(box1, box2):
+    b1_x1 = box1[0] - box1[2] / 2
+    b1_x2 = box1[0] + box1[2] / 2
+    b1_y1 = box1[1] - box1[3] / 2
+    b1_y2 = box1[1] + box1[3] / 2
+    b2_x1 = box2[0] - box2[2] / 2
+    b2_x2 = box2[0] + box2[2] / 2
+    b2_y1 = box2[1] - box2[3] / 2
+    b2_y2 = box2[1] + box2[3] / 2
+
+    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+
+    inter_rect_x1 = max(b1_x1, b2_x1)
+    inter_rect_y1 = max(b1_y1, b2_y1)
+    inter_rect_x2 = min(b1_x2, b2_x2)
+    inter_rect_y2 = min(b1_y2, b2_y2)
+    inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max(
+        inter_rect_y2 - inter_rect_y1, 0)
+
+    # IoU = intersection / union, union = area1 + area2 - intersection,
+    # matching CalcBoxIoU in the C++ kernel.
+    return inter_area / (b1_area + b2_area - inter_area)
+
+
+def build_target(gtboxs, gtlabel, attrs, grid_size):
+    n, b, _ = gtboxs.shape
+    ignore_thresh = attrs["ignore_thresh"]
+    anchors = attrs["anchors"]
+    class_num = attrs["class_num"]
+    an_num = len(anchors) // 2
+    obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
+    tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    tcls = np.zeros(
+        (n, an_num, grid_size, grid_size, class_num)).astype('float32')
+
+    for i in range(n):
+        for j in range(b):
+            if gtboxs[i, j, :].sum() == 0:
+                continue
+
+            gt_label = gtlabel[i, j]
+            gx = gtboxs[i, j, 0] * grid_size
+            gy = gtboxs[i, j, 1] *
grid_size + gw = gtboxs[i, j, 2] * grid_size + gh = gtboxs[i, j, 3] * grid_size + + gi = int(gx) + gj = int(gy) + + gtbox = [0, 0, gw, gh] + max_iou = 0 + for k in range(an_num): + anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]] + iou = box_iou(gtbox, anchor_box) + if iou > max_iou: + max_iou = iou + best_an_index = k + if iou > ignore_thresh: + noobj_mask[i, best_an_index, gj, gi] = 0 + + obj_mask[i, best_an_index, gj, gi] = 1 + noobj_mask[i, best_an_index, gj, gi] = 0 + tx[i, best_an_index, gj, gi] = gx - gi + ty[i, best_an_index, gj, gi] = gy - gj + tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * + best_an_index]) + th[i, best_an_index, gj, gi] = np.log( + gh / anchors[2 * best_an_index + 1]) + tconf[i, best_an_index, gj, gi] = 1 + tcls[i, best_an_index, gj, gi, gt_label] = 1 + + return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + + +def YoloV3Loss(x, gtbox, gtlabel, attrs): + n, c, h, w = x.shape + an_num = len(attrs['anchors']) // 2 + class_num = attrs["class_num"] + x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + pred_x = sigmoid(x[:, :, :, :, 0]) + pred_y = sigmoid(x[:, :, :, :, 1]) + pred_w = x[:, :, :, :, 2] + pred_h = x[:, :, :, :, 3] + pred_conf = sigmoid(x[:, :, :, :, 4]) + pred_cls = sigmoid(x[:, :, :, :, 5:]) + + tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + gtbox, gtlabel, attrs, x.shape[2]) + + obj_mask_expand = np.tile( + np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) + loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) + loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) + loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) + loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) + loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) + loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask, + noobj_mask) + loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, + obj_mask_expand) + + return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + + attrs['loss_weight_wh'] * (loss_w + loss_h) \ + + attrs['loss_weight_conf_target'] * loss_conf_target \ + + attrs['loss_weight_conf_notarget'] * loss_conf_notarget \ + + attrs['loss_weight_class'] * loss_class + + +class TestYolov3LossOp(OpTest): + def setUp(self): + self.loss_weight_xy = 1.0 + self.loss_weight_wh = 1.0 + self.loss_weight_conf_target = 1.0 + self.loss_weight_conf_notarget = 1.0 + self.loss_weight_class = 1.0 + self.initTestCase() + self.op_type = 'yolov3_loss' + x = np.random.random(size=self.x_shape).astype('float32') + gtbox = np.random.random(size=self.gtbox_shape).astype('float32') + gtlabel = np.random.randint(0, self.class_num, + self.gtbox_shape[:2]).astype('int32') + + self.attrs = { + "anchors": self.anchors, + "class_num": self.class_num, + "ignore_thresh": self.ignore_thresh, + "loss_weight_xy": self.loss_weight_xy, + "loss_weight_wh": self.loss_weight_wh, + "loss_weight_conf_target": self.loss_weight_conf_target, + "loss_weight_conf_notarget": self.loss_weight_conf_notarget, + "loss_weight_class": self.loss_weight_class, + } + + self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} + self.outputs = { + 'Loss': np.array( + [YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32') + } + + def test_check_output(self): + place = core.CPUPlace() + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad_ignore_gtbox(self): + place = core.CPUPlace() + self.check_grad_with_place( + place, ['X'], + 'Loss', + 
no_grad_set=set(["GTBox", "GTLabel"]), + max_relative_error=0.06) + + def initTestCase(self): + self.anchors = [10, 13, 12, 12] + self.class_num = 10 + self.ignore_thresh = 0.5 + self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) + self.gtbox_shape = (5, 10, 4) + self.loss_weight_xy = 2.5 + self.loss_weight_wh = 0.8 + self.loss_weight_conf_target = 1.5 + self.loss_weight_conf_notarget = 0.5 + self.loss_weight_class = 1.2 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 5d348f0995fbff7bbefa3324caffb448c98f552f..1d867d9194347cf55fd1bd8b1962856d599be7ec 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -125,13 +125,14 @@ def slice_variable(var_list, slice_count, min_block_size): class DistributeTranspilerConfig(object): """ - slice_var_up (bool): Do Tensor slice for pservers, default is True. - split_method (PSDispatcher): RoundRobin or HashName can be used - try to choose the best method to balance loads for pservers. - min_block_size (int): Minimum splitted element number in block. - According:https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156 - We can use bandwidth effiently when data size is larger than 2MB.If you - want to change it, please be sure you see the slice_variable function. + Args: + slice_var_up (bool): Do Tensor slice for pservers, default is True. + split_method (PSDispatcher): RoundRobin or HashName can be used + try to choose the best method to balance loads for pservers. + min_block_size (int): Minimum splitted element number in block. + According:https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156 + We can use bandwidth effiently when data size is larger than 2MB.If you + want to change it, please be sure you see the slice_variable function. """ slice_var_up = True @@ -163,35 +164,35 @@ class DistributeTranspiler(object): Examples: .. 
-            # for pserver mode
-            pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
-            trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
-            current_endpoint = "192.168.0.1:6174"
-            trainer_id = 0
-            trainers = 4
-            role = os.getenv("PADDLE_TRAINING_ROLE")
-
-            t = fluid.DistributeTranspiler()
-            t.transpile(
-                trainer_id, pservers=pserver_endpoints, trainers=trainers)
-            if role == "PSERVER":
-                pserver_program = t.get_pserver_program(current_endpoint)
-                pserver_startup_program = t.get_startup_program(current_endpoint,
+        # for pserver mode
+        pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
+        trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
+        current_endpoint = "192.168.0.1:6174"
+        trainer_id = 0
+        trainers = 4
+        role = os.getenv("PADDLE_TRAINING_ROLE")
+
+        t = fluid.DistributeTranspiler()
+        t.transpile(
+            trainer_id, pservers=pserver_endpoints, trainers=trainers)
+        if role == "PSERVER":
+            pserver_program = t.get_pserver_program(current_endpoint)
+            pserver_startup_program = t.get_startup_program(current_endpoint,
                                                             pserver_program)
-            elif role == "TRAINER":
-                trainer_program = t.get_trainer_program()
-
-            # for nccl2 mode
-            config = fluid.DistributeTranspilerConfig()
-            config.mode = "nccl2"
-            t = fluid.DistributeTranspiler(config=config)
-            t.transpile(trainer_id, workers=workers, current_endpoint=curr_ep)
-            exe = fluid.ParallelExecutor(
-                use_cuda,
-                loss_name=loss_var.name,
-                num_trainers=len(trainers.split(",)),
-                trainer_id=trainer_id
-            )
+        elif role == "TRAINER":
+            trainer_program = t.get_trainer_program()
+
+        # for nccl2 mode
+        config = fluid.DistributeTranspilerConfig()
+        config.mode = "nccl2"
+        t = fluid.DistributeTranspiler(config=config)
+        t.transpile(trainer_id, workers=workers, current_endpoint=curr_ep)
+        exe = fluid.ParallelExecutor(
+            use_cuda,
+            loss_name=loss_var.name,
+            num_trainers=len(trainers.split(",")),
+            trainer_id=trainer_id
+        )
     """
 
     def __init__(self, config=None):
diff --git a/python/setup.py.in b/python/setup.py.in
index 200b96ec54ee5daeb905e155d0b7b57ab7740250..5aee26b63832889272cde09c553b4615efb8872a 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -165,9 +165,9 @@ if '${WITH_MKL}' == 'ON':
     shutil.copy('${MKLML_LIB}', libs_path)
     shutil.copy('${MKLML_IOMP_LIB}', libs_path)
     package_data['paddle.libs']+=['libmklml_intel' + ext_name,'libiomp5' + ext_name]
-if '${CMAKE_BUILD_TYPE}' == 'Release':
-    # only change rpath in Release mode.
-    if '${WITH_MKLDNN}' == 'ON':
+if '${WITH_MKLDNN}' == 'ON':
+    if '${CMAKE_BUILD_TYPE}' == 'Release':
+        # only change rpath in Release mode.
         # TODO(typhoonzero): use install_name_tool to patch mkl libs once
         # we can support mkl on mac.
         #
@@ -177,14 +177,19 @@ if '${CMAKE_BUILD_TYPE}' == 'Release':
         command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}"
         if os.system(command) != 0:
             raise Exception("patch libmkldnn.so failed, command: %s" % command)
-    package_data['paddle.libs']+=['libmkldnn.so.0']
-    shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
+        package_data['paddle.libs']+=['libmkldnn.so.0']
+        shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
 if '${WITH_NGRAPH}' == 'ON':
+    # only change rpath in Release mode,
+    # since in Debug mode, nGraph lib may be too large to be changed.
    if '${CMAKE_BUILD_TYPE}' == 'Release':
-        # only change rpath in Release mode.
- command = "patchelf --set-rpath '$ORIGIN/' ${NGRAPH_SHARED_LIB}" - if os.system(command) != 0: - raise Exception("patch ${NGRAPH_SHARED_LIB_NAME} failed, command: %s" % command) + if os.name != 'nt': + if "@APPLE@" == "1": + command = "install_name_tool -id \"@loader_path/\" ${NGRAPH_SHARED_LIB}" + else: + command = "patchelf --set-rpath '$ORIGIN/' ${NGRAPH_SHARED_LIB}" + if os.system(command) != 0: + raise Exception("patch ${NGRAPH_SHARED_LIB_NAME} failed, command: %s" % command) shutil.copy('${NGRAPH_SHARED_LIB}', libs_path) shutil.copy('${NGRAPH_CPU_LIB}', libs_path) shutil.copy('${NGRAPH_TBB_LIB}', libs_path)