diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d27289ed6d14a08895e33675781bcff986a3686b..4a5c8d93ff49cbc12b2ca90f9ea53095d78678f8 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -73,7 +73,6 @@ paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)) paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)) -paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)) paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)) paddle.fluid.initializer.NormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)) @@ -161,6 +160,12 @@ paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.elu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) +paddle.fluid.layers.relu6 ArgSpec(args=['x', 'threshold', 'name'], varargs=None, keywords=None, defaults=(6.0, None)) +paddle.fluid.layers.pow ArgSpec(args=['x', 'factor', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) +paddle.fluid.layers.stanh ArgSpec(args=['x', 'scale_a', 'scale_b', 'name'], varargs=None, keywords=None, defaults=(0.6666666666666666, 1.7159, None)) +paddle.fluid.layers.hard_sigmoid ArgSpec(args=['x', 'slope', 'offset', 'name'], varargs=None, keywords=None, defaults=(0.2, 0.5, None)) +paddle.fluid.layers.swish ArgSpec(args=['x', 'beta', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)) @@ -277,12 +282,6 @@ paddle.fluid.layers.softsign ArgSpec(args=[], varargs='args', keywords='kwargs', paddle.fluid.layers.brelu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.leaky_relu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.soft_relu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.relu6 ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.pow ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.stanh ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.hard_sigmoid ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.swish ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.uniform_random ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None)) @@ -296,6 +295,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', ' paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')) paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)) paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)) +paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)) paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) diff --git a/paddle/fluid/framework/details/cow_ptr.h b/paddle/fluid/framework/details/cow_ptr.h index 4fb015b0ffe27e6fa91d4eaf373fca4feca66361..21f75957be5f33f3dfc09c41fa9a1e1ca590f99e 100644 --- a/paddle/fluid/framework/details/cow_ptr.h +++ b/paddle/fluid/framework/details/cow_ptr.h @@ -20,41 +20,79 @@ namespace paddle { namespace framework { namespace details { -template -class COWPtr { +// Change it to thread safe flags if needed. +class ThreadUnsafeOwnershipFlags { public: - typedef std::shared_ptr RefPtr; + explicit ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} - private: - RefPtr m_sp; + ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete; + ThreadUnsafeOwnershipFlags& operator=( + const ThreadUnsafeOwnershipFlags& other) = delete; + ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default; - void detach() { - T* tmp = m_sp.get(); - if (!(tmp == nullptr || m_sp.unique())) { - m_sp = RefPtr(new T(*tmp)); + void SetOwnership(bool flag) { flag_ = flag; } + + // Invoke the callback if it is not owned. + template + void AcquireOwnershipOnce(Callback acquire) { + if (!flag_) { + acquire(); + flag_ = true; } } - public: - COWPtr() : m_sp(nullptr) {} - explicit COWPtr(T* t) : m_sp(t) {} - explicit COWPtr(const RefPtr& refptr) : m_sp(refptr) {} + private: + bool flag_; +}; - const T& Data() const { return operator*(); } +// Copy-On-Write pointer. +// It will hold a T* pointer, and only copy once when `MutableData` is invoked. +// +// The template parameter OwnershipFlags should have: +// * a constructor takes a bool. True if own. +// * SetOwnership(bool flag). +// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not +// owned. +// +// https://en.wikipedia.org/wiki/Copy-on-write +template +class COWPtr { + public: + // Ctor from raw pointer. + explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {} - T* MutableData() { return operator->(); } + // Move methods. Steal ownership from origin + COWPtr(COWPtr&& other) + : payload_(other.payload_), ownership_{std::move(other.ownership_)} {} + COWPtr& operator=(COWPtr&& origin) = default; - const T& operator*() const { return *m_sp; } - T& operator*() { - detach(); - return *m_sp; + // Copy methods. Not own payload + COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {} + COWPtr& operator=(const COWPtr& other) { + payload_ = other.payload_; + ownership_.SetOwnership(false); + return *this; } - const T* operator->() const { return m_sp.operator->(); } - T* operator->() { - detach(); - return m_sp.operator->(); + + // Access read only data. + const T& Data() const { return *payload_; } + + // Access mutable data. If the data is not owned, the data will be copied + // before. + T* MutableData() { + ownership_.AcquireOwnershipOnce( + [this] { payload_.reset(new T(*payload_)); }); + return payload_.get(); } + + private: + // Actual data pointer. + std::shared_ptr payload_; + + // Ownership flag. + OwnershipFlags ownership_; }; + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/cow_ptr_test.cc b/paddle/fluid/framework/details/cow_ptr_test.cc index 5b055d7cb4d127dc20f2cf70869134f24a93d429..d2142af277c0b356d83941b3baab1947cce31dac 100644 --- a/paddle/fluid/framework/details/cow_ptr_test.cc +++ b/paddle/fluid/framework/details/cow_ptr_test.cc @@ -30,14 +30,6 @@ TEST(COWPtr, all) { ASSERT_EQ(ptr2.Data(), 10); } -TEST(COWPtr, change_old) { - COWPtr ptr(new int{0}); - COWPtr ptr2 = ptr; - *ptr.MutableData() = 10; - ASSERT_EQ(ptr2.Data(), 0); - ASSERT_EQ(ptr.Data(), 10); -} - } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 8f319116ab80b75c624f35b0e1315e7362e88d9a..134fcee826715672a6e021e9bf694bb771ebb830 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -210,43 +210,6 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( return recv_vars; } -bool MultiDevSSAGraphBuilder::IsDistTrainOp( - ir::Node *node, const std::vector &send_vars, - const std::vector &recv_vars) const { - if (send_vars.size() == 0 || recv_vars.size() == 0) { - return false; - } - - /** - * Check any of opvars contains `.block` and in sendvars - */ - auto checker = [](const std::vector &opvars, - const std::vector &rpc_vars) -> bool { - for (auto &var : opvars) { - // a variable name with the suffix `.block` means it's a splited - // variable by (DistributeTranspiler) - // [python/paddle/fluid/transpiler/distribute_transpiler.py] - if (var.find(".block") != std::string::npos && - std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { - return true; - } - } - return false; - }; - - std::vector input_var_names; - std::vector output_var_names; - for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Name()); - } - for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Name()); - } - - return checker(output_var_names, send_vars) || - checker(input_var_names, recv_vars); -} - size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( const std::vector &var_names) const { int64_t numel_sum = 0; @@ -370,7 +333,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( } } is_dist_train = true; - } else if (IsDistTrainOp(node, send_vars, recv_vars)) { + } else if (boost::get(node->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) == + static_cast(OpRole::kDist)) { int op_dev_id = CreateDistTrainOp(&result, node); if (node->Op()->Type() == "concat") { auto origin_param_name = node->Op()->OutputArgumentNames()[0]; @@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, .emplace(varname, op_dev_id); } } else { + LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); PADDLE_THROW( "the distribute training related op should be in [split_byref, " "concat]."); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 47aaa80f4d66a48b729d0638badcab885a50585c..cdf9f13cde608b546d17a1e53e0f6acea9e12566 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { int CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; - /** - * Is this operator as the end-point operator before/after send operator. - */ - bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, - const std::vector &recv_vars) const; - std::vector FindDistTrainSendVars( const std::vector &nodes) const; diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h index 71db8d952f4c205b875ad254dc19c0c1f74e61b3..fc479a4c4a1e7d5c824d3c202e0cccf743dd52c9 100644 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" namespace paddle { @@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase { const std::vector &var_names, GarbageCollector *gc, AtomicReferenceCountMap *ref_cnts) - : OpHandleBase(node), - scope_(scope), - var_names_(var_names), - gc_(gc), - ref_cnts_(ref_cnts) { + : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) { dev_ctx_ = static_cast( platform::DeviceContextPool::Instance().Get(place)); if (IsStreamGarabageCollector()) { PADDLE_ENFORCE(cudaSetDevice(place.device)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } + + for (auto &name : var_names) AddVar(name); } ~ReferenceCountOpHandle() { @@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase { std::string Name() const override { return "reference_count"; } + void AddVar(const std::string &name) { + auto it = var_names_.find(name); + if (it != var_names_.end()) + ++(it->second); + else + var_names_[name] = 1; + } + protected: void RunImpl() override { auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); - std::vector tensors; - for (auto &name : var_names_) { + std::vector tensors; + for (auto &pair : var_names_) { + auto &name = pair.first; auto it = ref_cnts_->find(name); if (it == ref_cnts_->end()) continue; auto *var = exec_scope->FindVar(name); - if (var == nullptr || !var->IsType()) continue; - - if (it->second.fetch_sub(1) <= 1) { - tensors.emplace_back(var->GetMutable()); + if (var == nullptr) continue; + + if (var->IsType()) { + if (it->second.fetch_sub(pair.second) <= pair.second) { + tensors.emplace_back(var->GetMutable()); + } + } else if (var->IsType()) { + if (it->second.fetch_sub(pair.second) <= pair.second) { + tensors.emplace_back( + var->GetMutable()->mutable_value()); + } } } @@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase { } private: - void ClearTensors(const std::vector &tensors) { + void ClearTensors(const std::vector &tensors) { auto *gc = dynamic_cast *>(gc_); if (gc != nullptr) { auto compute_stream = dev_ctx_->stream(); @@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase { const Scope *scope_; platform::CUDADeviceContext *dev_ctx_; - std::vector var_names_; + std::unordered_map var_names_; GarbageCollector *gc_; // not own AtomicReferenceCountMap *ref_cnts_; // not own cudaEvent_t event_; diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 344754d5a1e119c04cae08ad50126924b5824315..b1ce551ce73de33bcede187c72feebad6e2fa1a5 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -23,6 +24,25 @@ namespace paddle { namespace framework { namespace details { +static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { + std::queue queue; + queue.push(var_in); + do { + auto *var = queue.front(); + queue.pop(); + for (auto *op : var->PendingOps()) { + auto *compute_op = dynamic_cast(op); + if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) { + return compute_op; + } + for (auto *out_var : op->Outputs()) { + queue.push(out_var); + } + } + } while (!queue.empty()); + return nullptr; +} + std::unique_ptr ReferenceCountPass::ApplyImpl( std::unique_ptr graph) const { auto &ref_cnts = Get(kGlobalReferenceCount); @@ -34,6 +54,9 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( // Step 2: Find all variables in non-computation ops which refers to variables // in computation ops std::unordered_set names; + std::unordered_map> + compute_ref_cnt_map; + auto get_ref_cnts_from_compute_op = [&]( const std::unique_ptr &op, const std::vector &vars) { @@ -54,15 +77,18 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( VarDesc *var_desc = var_handle->Node()->Var(); auto var_name = var_handle->Node()->Name(); - // This is wierd but there is really some variables without var_desc + // This is weird but there is really some variables without var_desc // in computation_op if (var_desc == nullptr) { if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) continue; } else { - if (var_desc->Persistable() || - var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR) + if (var_desc->Persistable()) continue; + auto var_type = var_desc->Proto()->type().type(); + if (var_type != proto::VarType::LOD_TENSOR && + var_type != proto::VarType::SELECTED_ROWS) { continue; + } } // compute op only runs in one device @@ -93,12 +119,33 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( if (ref_cnts.count(place.device) && ref_cnts[place.device]->count(var_name)) { ++(*ref_cnts[place.device])[var_name]; + + auto *next_compute_op = FindNextComputationOpHandle(var_handle); + if (next_compute_op != nullptr) { + if (compute_ref_cnt_map.count(next_compute_op)) { + compute_ref_cnt_map[next_compute_op]->AddVar(var_name); + VLOG(5) << "Add reference count of " << var_name << " to Operator " + << next_compute_op->Name(); + } else { + // Create new reference_count_op_handle + ir::Node *ref_cnt_node = graph->CreateEmptyNode( + "reference_count", ir::Node::Type::kOperation); + auto *ref_cnt_handle = new ReferenceCountOpHandle( + ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, + gcs[place.device].get(), cur_ref_cnts[place.device].get()); + if (next_compute_op->Outputs().empty()) { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + next_compute_op->AddOutput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + } + ref_cnt_handle->AddInput(next_compute_op->Outputs().front()); + compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); + } + } } } }; - std::unordered_map - compute_ref_cnt_map; auto &all_ops = graph->Get(kGraphOps); for (auto &op : all_ops) { auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); @@ -113,11 +160,13 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( auto *ref_cnt_handle = new ReferenceCountOpHandle( ref_cnt_node, compute_op->GetScope(), place, in_var_names, gcs[place.device].get(), cur_ref_cnts[place.device].get()); - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - compute_op->AddOutput(dep_var); - ref_cnt_handle->AddInput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - compute_ref_cnt_map[compute_op] = ref_cnt_handle; + if (compute_op->Outputs().empty()) { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + compute_op->AddOutput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + } + ref_cnt_handle->AddInput(compute_op->Outputs().front()); + compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); } for (auto &op : all_ops) { @@ -131,7 +180,11 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( new_all_ops.emplace_back(std::move(op)); auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); if (it != compute_ref_cnt_map.end()) { - new_all_ops.emplace_back(it->second); + // Add LeafNode to ReferenceCountOpHandle + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dummy_leaf); + it->second->AddOutput(dummy_leaf); + new_all_ops.emplace_back(std::move(it->second)); } } diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index ba2c41eb8968e30bc8a801f4d8d239da8c522de9..7836ecb1272a07a79a70c9cb040335f9a42e5684 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -17,12 +17,10 @@ #include #include #include -#include #include -#include "paddle/fluid/framework/details/cow_ptr.h" + #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/memory/memcpy.h" #include "glog/logging.h" @@ -30,401 +28,206 @@ namespace paddle { namespace framework { #if defined(PADDLE_WITH_CUDA) -namespace details { -struct CUDABuffer { - void *data_{nullptr}; - size_t size_{0}; - platform::CUDAPlace place_; - - CUDABuffer() {} - CUDABuffer(platform::Place place, size_t size) - : size_(size), place_(boost::get(place)) { - data_ = memory::Alloc(place_, size); - } - - ~CUDABuffer() { ClearMemory(); } - - CUDABuffer(const CUDABuffer &o) = delete; - CUDABuffer &operator=(const CUDABuffer &o) = delete; - - void Resize(platform::Place place, size_t size) { - ClearMemory(); - place_ = boost::get(place); - data_ = memory::Alloc(place_, size); - size_ = size; - } - - void Swap(CUDABuffer &o) { - std::swap(data_, o.data_); - std::swap(place_, o.place_); - std::swap(size_, o.size_); - } - - private: - void ClearMemory() const { - if (data_) { - memory::Free(place_, data_); - } - } -}; -} // namespace details - // Vector implements the std::vector interface, and can get Data or // MutableData from any place. The data will be synced implicitly inside. template class Vector { public: using value_type = T; - using iterator = typename std::vector::iterator; - using const_iterator = typename std::vector::const_iterator; - - private: - // The actual class to implement vector logic - class VectorData { - public: - VectorData() : flag_(kDataInCPU) {} - VectorData(size_t count, const T &value) - : cpu_(count, value), flag_(kDataInCPU) {} - VectorData(std::initializer_list init) : cpu_(init), flag_(kDataInCPU) {} - template - explicit VectorData(const std::vector &dat) - : cpu_(dat), flag_(kDataInCPU) {} - - VectorData(const VectorData &o) { - o.ImmutableCPU(); - cpu_ = o.cpu_; - flag_ = kDataInCPU; - } - - VectorData &operator=(const VectorData &o) { - o.ImmutableCPU(); - cpu_ = o.cpu_; - flag_ = kDataInCPU; - details::CUDABuffer null; - gpu_.Swap(null); - return *this; - } - - T &operator[](size_t i) { - MutableCPU(); - return cpu_[i]; - } - - const T &operator[](size_t i) const { - ImmutableCPU(); - return cpu_[i]; - } - - size_t size() const { return cpu_.size(); } - - iterator begin() { - MutableCPU(); - return cpu_.begin(); - } - - iterator end() { - MutableCPU(); - return cpu_.end(); - } - - T &front() { - MutableCPU(); - return cpu_.front(); - } - - T &back() { - MutableCPU(); - return cpu_.back(); - } - - const_iterator begin() const { - ImmutableCPU(); - return cpu_.begin(); - } - - const_iterator end() const { - ImmutableCPU(); - return cpu_.end(); - } - - const T &back() const { - ImmutableCPU(); - return cpu_.back(); - } - - T *data() { return &(*this)[0]; } - - const T *data() const { return &(*this)[0]; } - - const T &front() const { - ImmutableCPU(); - return cpu_.front(); - } - - // assign this from iterator. - // NOTE: the iterator must support `end-begin` - template - void assign(Iter begin, Iter end) { - MutableCPU(); - cpu_.assign(begin, end); - } - - // push_back. If the previous capacity is not enough, the memory will - // double. - void push_back(T elem) { - MutableCPU(); - cpu_.push_back(elem); - } - - // extend a vector by iterator. - // NOTE: the iterator must support end-begin - template - void Extend(It begin, It end) { - MutableCPU(); - auto out_it = std::back_inserter>(this->cpu_); - std::copy(begin, end, out_it); - } - - // resize the vector - void resize(size_t size) { - MutableCPU(); - cpu_.resize(size); - } - - // get cuda ptr. immutable - const T *CUDAData(platform::Place place) const { - PADDLE_ENFORCE(platform::is_gpu_place(place), - "CUDA Data must on CUDA place"); - ImmutableCUDA(place); - return reinterpret_cast(gpu_.data_); - } - - // get cuda ptr. mutable - T *CUDAMutableData(platform::Place place) { - const T *ptr = CUDAData(place); - flag_ = kDirty | kDataInCUDA; - return const_cast(ptr); - } - - // clear - void clear() { - cpu_.clear(); - flag_ = kDirty | kDataInCPU; - } - - size_t capacity() const { return cpu_.capacity(); } - - // reserve data - void reserve(size_t size) { cpu_.reserve(size); } - - // implicit cast operator. Vector can be cast to std::vector implicitly. - operator std::vector() const { - ImmutableCPU(); - return cpu_; - } - - bool operator==(const VectorData &other) const { - ImmutableCPU(); - other.ImmutableCPU(); - return cpu_ == other.cpu_; - } - - private: - enum DataFlag { - kDataInCPU = 0x01, - kDataInCUDA = 0x02, - // kDirty means the data has been changed in one device. - kDirty = 0x10 - }; - - void CopyToCPU() const { - // COPY GPU Data To CPU - void *src = gpu_.data_; - void *dst = cpu_.data(); - memory::Copy(platform::CPUPlace(), dst, gpu_.place_, src, gpu_.size_, - nullptr); - } - - void MutableCPU() { - if (IsInCUDA() && IsDirty()) { - CopyToCPU(); - } - flag_ = kDirty | kDataInCPU; - } - - void ImmutableCUDA(platform::Place place) const { - if (IsDirty()) { - if (IsInCPU()) { - CopyCPUDataToCUDA(place); - UnsetFlag(kDirty); - SetFlag(kDataInCUDA); - } else if (IsInCUDA() && - !(boost::get(place) == gpu_.place_)) { - CopyCUDADataToAnotherPlace(place); - // Still dirty - } else { - // Dirty && DataInCUDA && Device is same - // Do nothing - } - } else { - if (!IsInCUDA()) { - // Even data is not dirty. However, data is not in CUDA. Copy data. - CopyCPUDataToCUDA(place); - SetFlag(kDataInCUDA); - } else if (!(boost::get(place) == gpu_.place_)) { - CopyCUDADataToAnotherPlace(place); - } else { - // Not Dirty && DataInCUDA && Device is same - // Do nothing. - } - } - } - void CopyCUDADataToAnotherPlace(const platform::Place &place) const { - details::CUDABuffer tmp(place, gpu_.size_); - const void *src = gpu_.data_; - void *dst = tmp.data_; - - memory::Copy(tmp.place_, dst, gpu_.place_, src, gpu_.size_, nullptr); - gpu_.Swap(tmp); - } - void CopyCPUDataToCUDA(const platform::Place &place) const { - void *src = cpu_.data(); - gpu_.Resize(place, cpu_.size() * sizeof(T)); - void *dst = gpu_.data_; - auto stream = static_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - memory::Copy(gpu_.place_, dst, platform::CPUPlace(), src, gpu_.size_, - stream); - } - - void ImmutableCPU() const { - if (IsDirty() && !IsInCPU()) { // If data has been changed in CUDA, or - // CPU has no data. - CopyToCPU(); - UnsetFlag(kDirty); - } - SetFlag(kDataInCPU); - } - - void UnsetFlag(int flag) const { flag_ &= ~flag; } - void SetFlag(int flag) const { flag_ |= flag; } - - bool IsDirty() const { return flag_ & kDirty; } - - bool IsInCUDA() const { return flag_ & kDataInCUDA; } - bool IsInCPU() const { return flag_ & kDataInCPU; } - - mutable std::vector cpu_; - mutable details::CUDABuffer gpu_; - mutable int flag_; - }; - - public: // Default ctor. Create empty Vector - Vector() : m_(new VectorData()) {} + Vector() { InitEmpty(); } // Fill vector with value. The vector size is `count`. - explicit Vector(size_t count, const T &value = T()) - : m_(new VectorData(count, value)) {} + explicit Vector(size_t count, const T &value = T()) { + InitEmpty(); + if (count != 0) { + resize(count); + T *ptr = begin(); + for (size_t i = 0; i < count; ++i) { + ptr[i] = value; + } + } + } // Ctor with init_list - Vector(std::initializer_list init) : m_(new VectorData(init)) {} + Vector(std::initializer_list init) { + if (init.size() == 0) { + InitEmpty(); + } else { + InitByIter(init.size(), init.begin(), init.end()); + } + } // implicit cast from std::vector. template - Vector(const std::vector &dat) : m_(new VectorData(dat)) { // NOLINT + Vector(const std::vector &dat) { // NOLINT + if (dat.size() == 0) { + InitEmpty(); + } else { + InitByIter(dat.size(), dat.begin(), dat.end()); + } } // Copy ctor - Vector(const Vector &other) { m_ = other.m_; } + Vector(const Vector &other) { this->operator=(other); } // Copy operator Vector &operator=(const Vector &other) { - m_ = other.m_; + if (other.size() != 0) { + this->InitByIter(other.size(), other.begin(), other.end()); + } else { + InitEmpty(); + } return *this; } // Move ctor - Vector(Vector &&other) { m_ = std::move(other.m_); } + Vector(Vector &&other) { + this->size_ = other.size_; + this->flag_ = other.flag_; + if (other.cuda_vec_.memory_size()) { + this->cuda_vec_.ShareDataWith(other.cuda_vec_); + } + if (other.cpu_vec_.memory_size()) { + this->cpu_vec_.ShareDataWith(other.cpu_vec_); + } + } // CPU data access method. Mutable. - T &operator[](size_t i) { return (*m_)[i]; } + T &operator[](size_t i) { + MutableCPU(); + return const_cast(cpu_vec_.data())[i]; + } // CPU data access method. Immutable. - const T &operator[](size_t i) const { return (*m_)[i]; } + const T &operator[](size_t i) const { + ImmutableCPU(); + return cpu_vec_.data()[i]; + } // std::vector iterator methods. Based on CPU data access method - size_t size() const { return m_->size(); } + size_t size() const { return size_; } - iterator begin() { return m_->begin(); } + T *begin() { return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); } - iterator end() { return m_->end(); } + T *end() { + return capacity() == 0 ? &EmptyDummy() : &this->operator[](size()); + } - T &front() { return m_->front(); } + T &front() { return *begin(); } - T &back() { return m_->back(); } + T &back() { + auto it = end(); + --it; + return *it; + } - const_iterator begin() const { return m_->begin(); } + const T *begin() const { + return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); + } - const_iterator end() const { return m_->end(); } + const T *end() const { + return capacity() == 0 ? &EmptyDummy() : &this->operator[](size()); + } - const_iterator cbegin() const { return begin(); } + const T *cbegin() const { return begin(); } - const_iterator cend() const { return end(); } + const T *cend() const { return end(); } - const T &back() const { return m_->back(); } + const T &back() const { + auto it = end(); + --it; + return *it; + } - T *data() { return m_->data(); } + T *data() { return begin(); } - const T *data() const { return m_->data(); } + const T *data() const { return begin(); } - const T &front() const { return m_->front(); } + const T &front() const { return *begin(); } // end of std::vector iterator methods // assign this from iterator. // NOTE: the iterator must support `end-begin` template void assign(Iter begin, Iter end) { - m_->assign(begin, end); + InitByIter(end - begin, begin, end); } // push_back. If the previous capacity is not enough, the memory will // double. - void push_back(T elem) { m_->push_back(elem); } + void push_back(T elem) { + if (size_ + 1 > capacity()) { + reserve((size_ + 1) << 1); + } + *end() = elem; + ++size_; + } // extend a vector by iterator. // NOTE: the iterator must support end-begin template void Extend(It begin, It end) { - m_->Extend(begin, end); + size_t pre_size = size_; + resize(pre_size + (end - begin)); + T *ptr = this->begin() + pre_size; + for (; begin < end; ++begin, ++ptr) { + *ptr = *begin; + } } // resize the vector void resize(size_t size) { - if (m_.Data().size() != size) { - m_->resize(size); + if (size + 1 <= capacity()) { + size_ = size; + } else { + MutableCPU(); + Tensor cpu_tensor; + platform::Place cpu = platform::CPUPlace(); + T *ptr = cpu_tensor.mutable_data( + framework::make_ddim({static_cast(size)}), cpu); + const T *old_ptr = + cpu_vec_.memory_size() == 0 ? nullptr : cpu_vec_.data(); + if (old_ptr != nullptr) { + std::copy(old_ptr, old_ptr + size_, ptr); + } + size_ = size; + cpu_vec_.ShareDataWith(cpu_tensor); } } // get cuda ptr. immutable const T *CUDAData(platform::Place place) const { - return m_.Data().CUDAData(place); + PADDLE_ENFORCE(platform::is_gpu_place(place), + "CUDA Data must on CUDA place"); + ImmutableCUDA(place); + return cuda_vec_.data(); } // get cuda ptr. mutable T *CUDAMutableData(platform::Place place) { - return m_->CUDAMutableData(place); + const T *ptr = CUDAData(place); + flag_ = kDirty | kDataInCUDA; + return const_cast(ptr); } // clear - void clear() { m_->clear(); } + void clear() { + size_ = 0; + flag_ = kDirty | kDataInCPU; + } - size_t capacity() const { return m_->capacity(); } + size_t capacity() const { + return cpu_vec_.memory_size() / SizeOfType(typeid(T)); + } // reserve data - void reserve(size_t size) { m_->reserve(size); } + void reserve(size_t size) { + size_t pre_size = size_; + resize(size); + resize(pre_size); + } // the unify method to access CPU or CUDA data. immutable. const T *Data(platform::Place place) const { @@ -445,7 +248,12 @@ class Vector { } // implicit cast operator. Vector can be cast to std::vector implicitly. - operator std::vector() const { return *m_; } + operator std::vector() const { + std::vector result; + result.resize(size()); + std::copy(begin(), end(), result.begin()); + return result; + } bool operator==(const Vector &other) const { if (size() != other.size()) return false; @@ -459,11 +267,118 @@ class Vector { return true; } - const void *Handle() const { return &m_.Data(); } - private: - // Vector is an COW object. - details::COWPtr m_; + void InitEmpty() { + size_ = 0; + flag_ = kDataInCPU; + } + + template + void InitByIter(size_t size, Iter begin, Iter end) { + platform::Place cpu = platform::CPUPlace(); + T *ptr = this->cpu_vec_.template mutable_data( + framework::make_ddim({static_cast(size)}), cpu); + for (size_t i = 0; i < size; ++i) { + *ptr++ = *begin++; + } + flag_ = kDataInCPU | kDirty; + size_ = size; + } + + enum DataFlag { + kDataInCPU = 0x01, + kDataInCUDA = 0x02, + // kDirty means the data has been changed in one device. + kDirty = 0x10 + }; + + void CopyToCPU() const { + // COPY GPU Data To CPU + TensorCopy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); + WaitPlace(cuda_vec_.place()); + } + + void MutableCPU() { + if (IsInCUDA() && IsDirty()) { + CopyToCPU(); + } + flag_ = kDirty | kDataInCPU; + } + + void ImmutableCUDA(platform::Place place) const { + if (IsDirty()) { + if (IsInCPU()) { + TensorCopy(cpu_vec_, boost::get(place), + &cuda_vec_); + WaitPlace(place); + UnsetFlag(kDirty); + SetFlag(kDataInCUDA); + } else if (IsInCUDA() && !(place == cuda_vec_.place())) { + framework::Tensor tmp; + TensorCopy(cuda_vec_, boost::get(place), &tmp); + WaitPlace(cuda_vec_.place()); + cuda_vec_.ShareDataWith(tmp); + // Still dirty + } else { + // Dirty && DataInCUDA && Device is same + // Do nothing + } + } else { + if (!IsInCUDA()) { + // Even data is not dirty. However, data is not in CUDA. Copy data. + TensorCopy(cpu_vec_, boost::get(place), + &cuda_vec_); + WaitPlace(place); + SetFlag(kDataInCUDA); + } else if (!(place == cuda_vec_.place())) { + framework::Tensor tmp; + WaitPlace(cuda_vec_.place()); + TensorCopy(cuda_vec_, boost::get(place), &tmp); + WaitPlace(cuda_vec_.place()); + WaitPlace(place); + cuda_vec_.ShareDataWith(tmp); + } else { + // Not Dirty && DataInCUDA && Device is same + // Do nothing. + } + } + } + + void ImmutableCPU() const { + if (IsDirty() && + !IsInCPU()) { // If data has been changed in CUDA, or CPU has no data. + CopyToCPU(); + UnsetFlag(kDirty); + } + SetFlag(kDataInCPU); + } + + void UnsetFlag(int flag) const { flag_ &= ~flag; } + void SetFlag(int flag) const { flag_ |= flag; } + + bool IsDirty() const { return flag_ & kDirty; } + + bool IsInCUDA() const { return flag_ & kDataInCUDA; } + + bool IsInCPU() const { return flag_ & kDataInCPU; } + + static void WaitPlace(const platform::Place place) { + if (platform::is_gpu_place(place)) { + platform::DeviceContextPool::Instance() + .Get(boost::get(place)) + ->Wait(); + } + } + + static T &EmptyDummy() { + static T dummy = T(); + return dummy; + } + + mutable int flag_; + mutable Tensor cpu_vec_; + mutable Tensor cuda_vec_; + size_t size_; }; #else // PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 4fa047bf3ee3d06ac4aec5d2cc6a355965836d42..df2a7a27ca4a6011b214202ac9bf4f30dc482ece 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, {static_cast(OpRole::kForward), static_cast(OpRole::kBackward), static_cast(OpRole::kOptimize), static_cast(OpRole::kRPC), + static_cast(OpRole::kDist), static_cast(OpRole::kLRSched), static_cast(OpRole::kLoss) | static_cast(OpRole::kForward), static_cast(OpRole::kLoss) | static_cast(OpRole::kBackward), diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 18827385ad659922230ff68709a2926a8c9013ac..4ed3cc45d66849267ef4945a03da1db76b53e4ea 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -26,7 +26,13 @@ enum class OpRole { kForward = 0x0000, kBackward = 0x0001, kOptimize = 0x0002, + // RPC role is for send/recv releated op kRPC = 0x0003, + // Dist role is for split_byref/split_selected_rows/concat + // used for distributed training. + kDist = 0x0004, + // Tag all learning rate scheduler operators. + kLRSched = 0x0005, kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h index 5b27068c9e805146b8bce03f4f676ef0d4d16c53..4cb1f3a80e95bdda79e6451dc3cc87e899b11779 100644 --- a/paddle/fluid/operators/adam_op.h +++ b/paddle/fluid/operators/adam_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include // for sqrt in CPU and CUDA #include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" @@ -306,26 +307,43 @@ class AdamOpKernel : public framework::OpKernel { VLOG(3) << "grad row size is 0!!"; return; } - // merge duplicated rows if any. - // The rows of grad_merge have been sorted inside MergeAdd functor - scatter::MergeAdd merge_func; - auto& grad_merge = *(ctx.scope() - .NewScope() - .Var("sparse_adam_grad_merge") - ->GetMutable()); - merge_func(ctx.template device_context(), grad, - &grad_merge); + + std::vector cpu_rows(grad.rows().begin(), grad.rows().end()); + bool is_strict_sorted = true; + for (size_t i = 1; i < cpu_rows.size(); ++i) { + if (cpu_rows[i - 1] >= cpu_rows[i]) { + is_strict_sorted = false; + break; + } + } + + const framework::SelectedRows* grad_merge_ptr; + if (is_strict_sorted) { + grad_merge_ptr = &grad; + } else { + // merge duplicated rows if any. + // The rows of grad_merge have been sorted inside MergeAdd functor + scatter::MergeAdd merge_func; + auto* grad_merge_var = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + merge_func(ctx.template device_context(), grad, + grad_merge_var); + grad_merge_ptr = grad_merge_var; + } + + auto& grad_merge = *grad_merge_ptr; auto& grad_tensor = grad_merge.value(); const T* grad_data = grad_tensor.template data(); - int64_t* rows = nullptr; -// When compiled without CUDA, the CUDAMutableData() interface should not be + const int64_t* rows = nullptr; +// When compiled without CUDA, the CUDAData() interface should not be // provided. #if defined(PADDLE_WITH_CUDA) if (platform::is_gpu_place(ctx.GetPlace())) { - rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace()); + rows = grad_merge.rows().CUDAData(ctx.GetPlace()); } else { #endif - rows = grad_merge.mutable_rows()->data(); + rows = grad_merge.rows().data(); #if defined(PADDLE_WITH_CUDA) } diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index f4983c65432991a45f226d97f0fb05b08a30ca89..5a058ddbc59c6135bacf7c2dc4b5c8b687f9b2b1 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -31,5 +31,6 @@ polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) detection_library(generate_proposals_op SRCS generate_proposals_op.cc) +detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu) #Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b98190d40a2afa684cfd29cc52fc29fac851cca7 --- /dev/null +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -0,0 +1,587 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +static constexpr int kROISize = 4; + +template +bool GT_E(T a, T b) { + return (a > b) || fabs(a - b) < 1e-4; +} + +template +bool LT_E(T a, T b) { + return (a < b) || fabs(a - b) < 1e-4; +} + +template +bool GT(T a, T b) { + return (a - b) > 1e-4; +} + +/* +*check if (x, y) is in the boundary of roi +*/ +template +bool in_quad(T x, T y, T roi_x[], T roi_y[]) { + for (int i = 0; i < 4; i++) { + T xs = roi_x[i]; + T ys = roi_y[i]; + T xe = roi_x[(i + 1) % 4]; + T ye = roi_y[(i + 1) % 4]; + if (fabs(ys - ye) < 1e-4) { + if (fabs(y - ys) < 1e-4 && fabs(y - ye) < 1e-4 && + GT_E(x, std::min(xs, xe)) && LT_E(x, std::max(xs, xe))) { + return true; + } + } else { + T intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs; + if (fabs(intersec_x - x) < 1e-4 && GT_E(y, std::min(ys, ye)) && + LT_E(y, std::max(ys, ye))) { + return true; + } + } + } + + int n_cross = 0; + for (int i = 0; i < 4; i++) { + T xs = roi_x[i]; + T ys = roi_y[i]; + T xe = roi_x[(i + 1) % 4]; + T ye = roi_y[(i + 1) % 4]; + if (fabs(ys - ye) < 1e-4) { + continue; + } + if (LT_E(y, std::min(ys, ye)) || GT(y, std::max(ys, ye))) { + continue; + } + T intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs; + if (fabs(intersec_x - x) < 1e-4) { + return true; + } + if (GT(intersec_x, x)) { + n_cross++; + } + } + return (n_cross % 2 == 1); +} + +/** + * Get the matrix of perspective transform. + * + * dx1 = x1 - x2 + * dx2 = x3 - x2 + * dx3 = x0 - x1 + x2 - x3 + * dy1 = y1 - y2 + * dy2 = y3 - y2 + * dy3 = y0 - y1 + y2 - y3 + * + * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1) + * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1) + * a13 = x0 + * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1) + * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1) + * a23 = y0 + * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1) + * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1) + * a33 = 1 + * + */ +template +void get_transform_matrix(const int transformed_width, + const int transformed_height, T roi_x[], T roi_y[], + T matrix[]) { + T x0 = roi_x[0]; + T x1 = roi_x[1]; + T x2 = roi_x[2]; + T x3 = roi_x[3]; + T y0 = roi_y[0]; + T y1 = roi_y[1]; + T y2 = roi_y[2]; + T y3 = roi_y[3]; + + // Estimate the height and width of RoI + T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1)); + T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)); + T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3)); + T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0)); + T estimated_height = (len2 + len4) / 2.0; + T estimated_width = (len1 + len3) / 2.0; + + // Get the normalized height and normalized width + int normalized_height = transformed_height; + int normalized_width = + std::round(estimated_width * (normalized_height - 1) / estimated_height) + + 1; + normalized_width = std::min(normalized_width, transformed_width); + + T dx1 = x1 - x2; + T dx2 = x3 - x2; + T dx3 = x0 - x1 + x2 - x3; + T dy1 = y1 - y2; + T dy2 = y3 - y2; + T dy3 = y0 - y1 + y2 - y3; + + matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / + (normalized_width - 1); + matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / + (normalized_height - 1); + matrix[8] = 1; + + matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) / + (normalized_width - 1); + matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) / + (normalized_height - 1); + matrix[5] = y0; + + matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) / + (normalized_width - 1); + matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) / + (normalized_height - 1); + matrix[2] = x0; +} + +/** + * Get the source coordinates in the input feature map. + * + * (u, v, w)^matrix = matrix * (out_w, out_h, 1)^matrix + * + * in_w = u / w + * in_h = v / w + * + */ +template +void get_source_coords(T matrix[], int out_w, int out_h, T* in_w, T* in_h) { + T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2]; + T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5]; + T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8]; + + in_w[0] = u / w; + in_h[0] = v / w; +} + +/** + * Perform bilinear interpolation in the input feature map. + */ +template +void bilinear_interpolate(const T* in_data, const int channels, const int width, + const int height, int in_n, int in_c, T in_w, T in_h, + T* val) { + // Deal with cases that source coords are out of feature map boundary + if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || + GT(in_h, height - 0.5)) { + // empty + val[0] = 0.0; + return; + } + + if (GT(0, in_w)) { + in_w = 0; + } + if (GT(0, in_h)) { + in_h = 0; + } + + int in_w_floor = floor(in_w); + int in_h_floor = floor(in_h); + int in_w_ceil; + int in_h_ceil; + + if (GT_E(in_w_floor, width - 1)) { + in_w_ceil = in_w_floor = width - 1; + in_w = static_cast(in_w_floor); + } else { + in_w_ceil = in_w_floor + 1; + } + + if (GT_E(in_h_floor, height - 1)) { + in_h_ceil = in_h_floor = height - 1; + in_h = static_cast(in_h_floor); + } else { + in_h_ceil = in_h_floor + 1; + } + T w_floor = in_w - in_w_floor; + T h_floor = in_h - in_h_floor; + T w_ceil = 1 - w_floor; + T h_ceil = 1 - h_floor; + const T* data = in_data + (in_n * channels + in_c) * height * width; + // Do bilinear interpolation + T v1 = data[in_h_floor * width + in_w_floor]; + T v2 = data[in_h_ceil * width + in_w_floor]; + T v3 = data[in_h_ceil * width + in_w_ceil]; + T v4 = data[in_h_floor * width + in_w_ceil]; + T w1 = w_ceil * h_ceil; + T w2 = w_ceil * h_floor; + T w3 = w_floor * h_floor; + T w4 = w_floor * h_ceil; + val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +template +class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + const T* input_data = in->data(); + + framework::Tensor roi2image; + roi2image.Resize({rois_num}); + int* roi2image_data = roi2image.mutable_data(ctx.GetPlace()); + auto lod = rois->lod().back(); + for (int i = 0; i < lod.size() - 1; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + roi2image_data[j] = i; + } + } + + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* rois_data = rois->data(); + + for (int n = 0; n < rois_num; ++n) { + const T* n_rois = rois_data + n * 8; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = n_rois[2 * k] * spatial_scale; + roi_y[k] = n_rois[2 * k + 1] * spatial_scale; + } + int image_id = roi2image_data[n]; + // Get transform matrix + T transform_matrix[9]; + get_transform_matrix(transformed_width, transformed_height, roi_x, + roi_y, transform_matrix); + + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < transformed_height; ++out_h) { + for (int out_w = 0; out_w < transformed_width; ++out_w) { + int out_index = + n * channels * transformed_height * transformed_width + + c * transformed_height * transformed_width + + out_h * transformed_width + out_w; + T in_w, in_h; + get_source_coords(transform_matrix, out_w, out_h, &in_w, &in_h); + if (in_quad(in_w, in_h, roi_x, roi_y)) { + if (GT(-0.5, in_w) || + GT(in_w, static_cast(in_width - 0.5)) || + GT(-0.5, in_h) || + GT(in_h, static_cast(in_height - 0.5))) { + output_data[out_index] = 0.0; + } else { + bilinear_interpolate(input_data, channels, in_width, in_height, + image_id, c, in_w, in_h, + output_data + out_index); + } + } else { + output_data[out_index] = 0.0; + } + } + } + } + } + } +}; + +template +T get_feature_gradient(T xs, T ys, int w, int h, const int width, + const int height) { + if (GT(-0.5, xs) || GT(xs, width - 0.5) || GT(-0.5, ys) || + GT(ys, height - 0.5)) { + return 0; + } + + if (GT(0, xs)) { + xs = 0; + } + if (GT(0, ys)) { + ys = 0; + } + + int xs_floor = floor(xs); + int ys_floor = floor(ys); + int xs_ceil; + int ys_ceil; + + if (GT_E(xs_floor, width - 1)) { + xs_ceil = xs_floor = width - 1; + xs = static_cast(xs_floor); + } else { + xs_ceil = xs_floor + 1; + } + + if (GT_E(ys_floor, height - 1)) { + ys_ceil = ys_floor = height - 1; + ys = static_cast(ys_floor); + } else { + ys_ceil = ys_floor + 1; + } + + T weight = 0; + if (w == xs_floor) { + if (h == ys_floor) { + weight = (w + 1 - xs) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (w + 1 - xs) * (ys + 1 - h); + } + } else if (w == xs_ceil) { + if (h == ys_floor) { + weight = (xs + 1 - w) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (xs + 1 - w) * (ys + 1 - h); + } + } + return weight; +} + +template +class CPUROIPerspectiveTransformGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + const T* out_grad_data = out_grad->data(); + const T* rois_data = rois->data(); + + framework::Tensor roi2image; + roi2image.Resize({rois_num}); + int* roi2image_data = roi2image.mutable_data(ctx.GetPlace()); + auto lod = rois->lod().back(); + for (int i = 0; i < lod.size() - 1; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + roi2image_data[j] = i; + } + } + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int in_h = 0; in_h < in_height; ++in_h) { + for (int in_w = 0; in_w < in_width; ++in_w) { + T gradient = 0.0; + for (int roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) { + const T* rois = rois_data + roi_idx * 8; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = rois[2 * k] * spatial_scale; + roi_y[k] = rois[2 * k + 1] * spatial_scale; + } + + // Get transform matrix + T matrix[9]; + get_transform_matrix(transformed_width, transformed_height, + roi_x, roi_y, matrix); + const T* out_grad_ptr = out_grad_data + + (roi_idx * channels + c) * + transformed_height * + transformed_width; + for (int out_h = 0; out_h < transformed_height; ++out_h) { + for (int out_w = 0; out_w < transformed_width; ++out_w) { + T src_w; + T src_h; + get_source_coords(matrix, out_w, out_h, &src_w, &src_h); + if (in_quad(src_w, src_h, roi_x, roi_y)) { + if (GT(-0.5, src_w) || + GT(src_w, static_cast(in_width - 0.5)) || + GT(-0.5, src_h) || + GT(src_h, static_cast(in_height - 0.5))) { + continue; + } + T weight = get_feature_gradient(src_w, src_h, in_w, in_h, + in_width, in_height); + gradient += + out_grad_ptr[out_h * transformed_width + out_w] * + weight; + } + } + } + } + int out_idx = (n * channels + c) * in_height * in_width + + in_h * in_width + in_w; + in_grad_data[out_idx] = gradient; + } + } + } + } + } +}; + +class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ROIPerspectiveTransformOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("ROIs"), + "Input(ROIs) of ROIPerspectiveTransformOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of ROIPerspectiveTransformOp should not be null."); + auto input_dims = ctx->GetInputDim("X"); + auto rois_dims = ctx->GetInputDim("ROIs"); + + PADDLE_ENFORCE(input_dims.size() == 4, + "The format of input tensor is NCHW."); + PADDLE_ENFORCE(rois_dims.size() == 2, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]"); + PADDLE_ENFORCE(rois_dims[1] == 8, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]."); + + int transformed_height = ctx->Attrs().Get("transformed_height"); + int transformed_width = ctx->Attrs().Get("transformed_width"); + float spatial_scale = ctx->Attrs().Get("spatial_scale"); + + PADDLE_ENFORCE_GT(transformed_height, 0, + "The transformed output height must greater than 0"); + PADDLE_ENFORCE_GT(transformed_width, 0, + "The transformed output width must greater than 0"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0"); + std::vector out_dims_v({rois_dims[0], // num_rois + input_dims[1], // channels + static_cast(transformed_height), + static_cast(transformed_width)}); + auto out_dims = framework::make_ddim(out_dims_v); + + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), + "The gradient of X should not be null."); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIPerspectiveTransformOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), " + "the input of ROIPerspectiveTransformOp. " + "The format of input tensor is NCHW. Where N is batch size, " + "C is the number of input channels, " + "H is the height of the feature, and " + "W is the width of the feature."); + AddInput("ROIs", + "(LoDTensor), " + "ROIs (Regions of Interest) to be transformed. " + "should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x1, y1, x2, y2, x3, y3, x4, y4], ...]." + "(x1, y1) is the top left coordinates, and " + "(x2, y2) is the top right coordinates, and" + "(x3, y3) is the bottom right coordinates, and" + "(x4, y4) is the bottom left coordinates."); + AddOutput( + "Out", + "(Tensor), " + "The output of ROIPerspectiveTransformOp is a 4-D tensor with shape " + "(num_rois, channels, transformed_h, transformed_w)."); + AddAttr("spatial_scale", + "(float, default 1.0), " + "Spatial scale factor to scale ROI coords.") + .SetDefault(1.0); + AddAttr("transformed_height", + "(int, default 1), " + "The height of transformed output.") + .SetDefault(1); + AddAttr("transformed_width", + "(int, default 1), " + "The width of transformed output.") + .SetDefault(1); + AddComment(R"DOC( +**ROIPerspectiveTransform Operator** + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(roi_perspective_transform, ops::ROIPerspectiveTransformOp, + ops::ROIPerspectiveTransformOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(roi_perspective_transform_grad, + ops::ROIPerspectiveTransformGradOp); +REGISTER_OP_CPU_KERNEL(roi_perspective_transform, + ops::CPUROIPerspectiveTransformOpKernel); +REGISTER_OP_CPU_KERNEL(roi_perspective_transform_grad, + ops::CPUROIPerspectiveTransformGradOpKernel); diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..b683b7573db747bde5f57e530ec53760db099843 --- /dev/null +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -0,0 +1,523 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +// CUDA: index helpers +#define idx4_4(index, d1, d2, d3, d4) (index % d4) +#define idx4_3(index, d1, d2, d3, d4) ((index / d4) % d3) +#define idx4_2(index, d1, d2, d3, d4) ((index / d4 / d3) % d2) +#define idx4_1(index, d1, d2, d3, d4) ((index / d4 / d3 / d2) % d1) + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ bool GT_E(T a, T b) { + return (a > b) || fabs(a - b) < 1e-4; +} + +template +__device__ bool LT_E(T a, T b) { + return (a < b) || fabs(a - b) < 1e-4; +} + +template +__device__ bool GT(T a, T b) { + return (a - b) > 1e-4; +} + +template +__device__ T max(T a, T b) { + return a > b ? a : b; +} + +template +__device__ T min(T a, T b) { + return a < b ? a : b; +} + +/* +* check if (x, y) is in the boundary of roi +*/ +template +__device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) { + for (int i = 0; i < 4; i++) { + T start_w = roi_x[i]; + T start_h = roi_y[i]; + T end_w = roi_x[(i + 1) % 4]; + T end_h = roi_y[(i + 1) % 4]; + if (fabs(start_h - end_h) < 1e-4) { + if (fabs(y - start_h) < 1e-4 && fabs(y - end_h) < 1e-4 && + GT_E(x, min(start_w, end_w)) && + LT_E(x, max(start_w, end_w))) { + return true; + } + } else { + T intersec_x = + (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w; + if (fabs(intersec_x - x) < 1e-4 && GT_E(y, min(start_h, end_h)) && + LT_E(y, max(start_h, end_h))) { + return true; + } + } + } + + int n_cross = 0; + for (int i = 0; i < 4; i++) { + T start_w = roi_x[i]; + T start_h = roi_y[i]; + T end_w = roi_x[(i + 1) % 4]; + T end_h = roi_y[(i + 1) % 4]; + if (fabs(start_h - end_h) < 1e-4) { + continue; + } + if (LT_E(y, min(start_h, end_h)) || + GT(y, max(start_h, end_h))) { + continue; + } + T intersec_x = + (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w; + if (fabs(intersec_x - x) < 1e-4) { + return true; + } + if (GT(intersec_x, x)) { + n_cross++; + } + } + return (n_cross % 2 == 1); +} + +/** + * Perform bilinear interpolation in the input feature map. + */ +template +__device__ void bilinear_interpolate(const T* in_data, const int channels, + const int width, const int height, + int in_n, int in_c, T in_w, T in_h, + T* val) { + // Deal with cases that source coords are out of feature map boundary + if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || + GT(in_h, height - 0.5)) { + val[0] = 0.0; + return; + } + + if (GT(0, in_w)) { + in_w = 0; + } + if (GT(0, in_h)) { + in_h = 0; + } + + int in_w_floor = floor(in_w); + int in_h_floor = floor(in_h); + int in_w_ceil; + int in_h_ceil; + + if (GT_E(in_w_floor, width - 1)) { + in_w_ceil = in_w_floor = width - 1; + in_w = static_cast(in_w_floor); + } else { + in_w_ceil = in_w_floor + 1; + } + + if (GT_E(in_h_floor, height - 1)) { + in_h_ceil = in_h_floor = height - 1; + in_h = static_cast(in_h_floor); + } else { + in_h_ceil = in_h_floor + 1; + } + + T w_floor = in_w - in_w_floor; + T h_floor = in_h - in_h_floor; + T w_ceil = 1 - w_floor; + T h_ceil = 1 - h_floor; + const T* data = in_data + (in_n * channels + in_c) * height * width; + // Do bilinear interpolation + T v1 = data[in_h_floor * width + in_w_floor]; + T v2 = data[in_h_ceil * width + in_w_floor]; + T v3 = data[in_h_ceil * width + in_w_ceil]; + T v4 = data[in_h_floor * width + in_w_ceil]; + T w1 = w_ceil * h_ceil; + T w2 = w_ceil * h_floor; + T w3 = w_floor * h_floor; + T w4 = w_floor * h_ceil; + val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +/** + * Get the source coordinates in the input feature map. + * + * (u, v, w)^matrix = T * (out_w, out_h, 1)^matrix + * + * in_w = u / w + * in_h = v / w + * + */ +template +__device__ void get_source_coords(T matrix[], int out_w, int out_h, T* in_w, + T* in_h) { + T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2]; + T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5]; + T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8]; + + in_w[0] = u / w; + in_h[0] = v / w; +} + +/** + * Get the matrix of perspective transform. + * + * dx1 = x1 - x2 + * dx2 = x3 - x2 + * dx3 = x0 - x1 + x2 - x3 + * dy1 = y1 - y2 + * dy2 = y3 - y2 + * dy3 = y0 - y1 + y2 - y3 + * + * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1) + * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1) + * a13 = x0 + * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1) + * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1) + * a23 = y0 + * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1) + * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1) + * a33 = 1 + * + */ +template +__device__ void get_transform_matrix(const int transformed_width, + const int transformed_height, T roi_x[], + T roi_y[], T matrix[]) { + T x0 = roi_x[0]; + T x1 = roi_x[1]; + T x2 = roi_x[2]; + T x3 = roi_x[3]; + T y0 = roi_y[0]; + T y1 = roi_y[1]; + T y2 = roi_y[2]; + T y3 = roi_y[3]; + + // Estimate the height and width of RoI + T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1)); + T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)); + T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3)); + T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0)); + T estimated_height = (len2 + len4) / 2.0; + T estimated_width = (len1 + len3) / 2.0; + + // Get the normalized height and normalized width + int normalized_height = transformed_height; + int normalized_width = + round(estimated_width * (normalized_height - 1) / estimated_height) + 1; + normalized_width = min(normalized_width, transformed_width); + + T dx1 = x1 - x2; + T dx2 = x3 - x2; + T dx3 = x0 - x1 + x2 - x3; + T dy1 = y1 - y2; + T dy2 = y3 - y2; + T dy3 = y0 - y1 + y2 - y3; + + matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / + (normalized_width - 1); + matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / + (normalized_height - 1); + matrix[8] = 1; + + matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) / + (normalized_width - 1); + matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) / + (normalized_height - 1); + matrix[5] = y0; + + matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) / + (normalized_width - 1); + matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) / + (normalized_height - 1); + matrix[2] = x0; +} + +template +__global__ void RoiTransformKernel(const float* input_data, + const float* rois_data, + const int* roi2image_data, int num_rois, + int in_height, int in_width, int channels, + int transformed_height, + int transformed_width, float spatial_scale, + T* output_data) { + int output_size = + num_rois * transformed_height * transformed_width * channels; + + CUDA_1D_KERNEL_LOOP(index, output_size) { + // (n, c, out_h, out_w) is an element in the transformed output + int out_w = idx4_4(index, num_rois, channels, transformed_height, + transformed_width); + int out_h = idx4_3(index, num_rois, channels, transformed_height, + transformed_width); + int c = idx4_2(index, num_rois, channels, transformed_height, + transformed_width); + int n = idx4_1(index, num_rois, channels, transformed_height, + transformed_width); + + auto bottom_rois = rois_data + n * 8; + int roi_batch_ind = bottom_rois[0]; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = bottom_rois[2 * k] * spatial_scale; + roi_y[k] = bottom_rois[2 * k + 1] * spatial_scale; + } + + // Get transform matrix + T matrix[9]; + get_transform_matrix(transformed_width, transformed_height, roi_x, roi_y, + matrix); + + // Get source coords + T in_w; + T in_h; + get_source_coords(matrix, out_w, out_h, &in_w, &in_h); + + if (in_quad(in_w, in_h, roi_x, roi_y)) { + if (GT(-0.5, in_w) || GT(in_w, static_cast(in_width - 0.5)) || + GT(-0.5, in_h) || GT(in_h, static_cast(in_height - 0.5))) { + // Skip if source coords is not in input image + output_data[index] = 0.0; + } else { + // Perform bilinear interpolation + int in_n = roi2image_data[n]; + bilinear_interpolate(input_data, channels, in_width, in_height, in_n, + c, in_w, in_h, output_data + index); + } + + } else { + // Skip if source coords is not in quad + output_data[index] = 0.0; + } + } +} + +template +class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + const T* input_data = in->data(); + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* rois_data = rois->data(); + + framework::Tensor roi2image; + framework::Tensor roi2image_dev; + roi2image.Resize({rois_num}); + int* roi2image_data = roi2image.mutable_data(platform::CPUPlace()); + auto lod = rois->lod().back(); + for (int i = 0; i < lod.size() - 1; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + roi2image_data[j] = i; + } + } + TensorCopySync(roi2image, ctx.GetPlace(), &roi2image_dev); + + int out_size = rois_num * transformed_height * transformed_width * channels; + auto stream = ctx.cuda_device_context().stream(); + int block = 512; + int grid = (out_size + block - 1) / block; + + RoiTransformKernel<<>>( + input_data, rois_data, roi2image_dev.data(), rois_num, in_height, + in_width, channels, transformed_height, transformed_width, + spatial_scale, output_data); + } +}; + +template +__device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width, + const int height) { + if (GT(-0.5, xs) || GT(xs, width - 0.5) || GT(-0.5, ys) || + GT(ys, height - 0.5)) { + return 0; + } + + if (GT(0, xs)) { + xs = 0; + } + if (GT(0, ys)) { + ys = 0; + } + + int xs_floor = floor(xs); + int ys_floor = floor(ys); + int xs_ceil; + int ys_ceil; + + if (GT_E(xs_floor, width - 1)) { + xs_ceil = xs_floor = width - 1; + xs = static_cast(xs_floor); + } else { + xs_ceil = xs_floor + 1; + } + + if (GT_E(ys_floor, height - 1)) { + ys_ceil = ys_floor = height - 1; + ys = static_cast(ys_floor); + } else { + ys_ceil = ys_floor + 1; + } + + T weight = 0; + if (w == xs_floor) { + if (h == ys_floor) { + weight = (w + 1 - xs) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (w + 1 - xs) * (ys + 1 - h); + } + } else if (w == xs_ceil) { + if (h == ys_floor) { + weight = (xs + 1 - w) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (xs + 1 - w) * (ys + 1 - h); + } + } + return weight; +} + +template +__global__ void RoiTransformGradKernel( + const size_t* lod, const T* rois_data, int batch_size, int num_rois, + int in_height, int in_width, int channels, int transformed_height, + int transformed_width, float spatial_scale, const T* out_grad_data, + T* in_grad_data) { + int input_size = batch_size * in_height * in_width * channels; + + CUDA_1D_KERNEL_LOOP(index, input_size) { + // (n, c, h, w) coords in input + int in_w = idx4_4(index, batch_size, channels, in_height, in_width); + int in_h = idx4_3(index, batch_size, channels, in_height, in_width); + int c = idx4_2(index, batch_size, channels, in_height, in_width); + int n = idx4_1(index, batch_size, channels, in_height, in_width); + + T gradient = 0.0; + // Accumulate gradient over all RoIs that interpolated this element + for (int roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) { + const T* rois = rois_data + roi_idx * 8; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = rois[2 * k] * spatial_scale; + roi_y[k] = rois[2 * k + 1] * spatial_scale; + } + + // Get transform matrix + T matrix[9]; + get_transform_matrix(transformed_width, transformed_height, roi_x, + roi_y, matrix); + + const T* out_grad_ptr = + out_grad_data + + (roi_idx * channels + c) * transformed_height * transformed_width; + for (int out_h = 0; out_h < transformed_height; ++out_h) { + for (int out_w = 0; out_w < transformed_width; ++out_w) { + T src_w; + T src_h; + get_source_coords(matrix, out_w, out_h, &src_w, &src_h); + if (in_quad(src_w, src_h, roi_x, roi_y)) { + if (GT(-0.5, src_w) || + GT(src_w, static_cast(in_width - 0.5)) || + GT(-0.5, src_h) || + GT(src_h, static_cast(in_height - 0.5))) { + continue; + } + T weight = get_feature_gradient(src_w, src_h, in_w, in_h, + in_width, in_height); + gradient += + out_grad_ptr[out_h * transformed_width + out_w] * weight; + } + } + } + } + in_grad_data[index] = gradient; + } +} + +template +class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + const T* out_grad_data = out_grad->data(); + const T* rois_data = rois->data(); + + auto lod = rois->lod().back(); + auto lod_data = lod.CUDAData(ctx.GetPlace()); + + int in_size = in->numel(); + auto stream = ctx.cuda_device_context().stream(); + int block = 512; + int grid = (in_size + block - 1) / block; + + RoiTransformGradKernel<<>>( + lod_data, rois_data, batch_size, rois_num, in_height, in_width, + channels, transformed_height, transformed_width, spatial_scale, + out_grad_data, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(roi_perspective_transform, + ops::CUDAROIPerspectiveTransformOpKernel); +REGISTER_OP_CUDA_KERNEL(roi_perspective_transform_grad, + ops::CUDAROIPerspectiveTransformGradOpKernel); diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h index dd5d138a1e979826d59c4731920379b030e3b492..dd1ab85fd8d0c8170afcd9dd2a49ee55c41dc8be 100644 --- a/paddle/fluid/operators/detection_map_op.h +++ b/paddle/fluid/operators/detection_map_op.h @@ -76,8 +76,8 @@ class DetectionMAPOpKernel : public framework::OpKernel { auto ap_type = GetAPType(ctx.Attr("ap_type")); int class_num = ctx.Attr("class_num"); - auto& label_lod = in_label->lod(); - auto& detect_lod = in_detect->lod(); + auto label_lod = in_label->lod(); + auto detect_lod = in_detect->lod(); PADDLE_ENFORCE_EQ(label_lod.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(), @@ -166,11 +166,11 @@ class DetectionMAPOpKernel : public framework::OpKernel { auto labels = framework::EigenTensor::From(input_label); auto detect = framework::EigenTensor::From(input_detect); - auto& label_lod = input_label.lod(); - auto& detect_lod = input_detect.lod(); + auto label_lod = input_label.lod(); + auto detect_lod = input_detect.lod(); int batch_size = label_lod[0].size() - 1; - auto& label_index = label_lod[0]; + auto label_index = label_lod[0]; for (int n = 0; n < batch_size; ++n) { std::map> boxes; @@ -274,6 +274,7 @@ class DetectionMAPOpKernel : public framework::OpKernel { output_true_pos->set_lod(true_pos_lod); output_false_pos->set_lod(false_pos_lod); + return; } void GetInputPos(const framework::Tensor& input_pos_count, @@ -291,7 +292,7 @@ class DetectionMAPOpKernel : public framework::OpKernel { auto SetData = [](const framework::LoDTensor& pos_tensor, std::map>>& pos) { const T* pos_data = pos_tensor.data(); - auto& pos_data_lod = pos_tensor.lod()[0]; + auto pos_data_lod = pos_tensor.lod()[0]; for (size_t i = 0; i < pos_data_lod.size() - 1; ++i) { for (size_t j = pos_data_lod[i]; j < pos_data_lod[i + 1]; ++j) { T score = pos_data[j * 2]; @@ -316,23 +317,20 @@ class DetectionMAPOpKernel : public framework::OpKernel { std::map>>* false_pos) const { int batch_size = gt_boxes.size(); for (int n = 0; n < batch_size; ++n) { - auto& image_gt_boxes = gt_boxes[n]; - for (auto& image_gt_box : image_gt_boxes) { + auto image_gt_boxes = gt_boxes[n]; + for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) { size_t count = 0; - auto& labeled_bboxes = image_gt_box.second; + auto labeled_bboxes = it->second; if (evaluate_difficult) { count = labeled_bboxes.size(); } else { - for (auto& box : labeled_bboxes) { - if (!box.is_difficult) { - ++count; - } - } + for (size_t i = 0; i < labeled_bboxes.size(); ++i) + if (!(labeled_bboxes[i].is_difficult)) ++count; } if (count == 0) { continue; } - int label = image_gt_box.first; + int label = it->first; if (label_pos_count->find(label) == label_pos_count->end()) { (*label_pos_count)[label] = count; } else { diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 1617cc1b95216b118cf2c2122dbe8b6c106554c3..c4854d50b6371064003a10e18efc9e5f160d9a42 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { + auto server_var = GetVar(); + if (!server_var) { + LOG(ERROR) << "recved var should not on current server: " + << meta_.varname(); + return false; + } auto* tensor = GetVar()->GetMutable(); tensor->Resize(dims); - framework::LoD lod; for (int i = 0; i < meta_.lod_level(); ++i) { framework::Vector v; @@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData( void* tensor_data = tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); - if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { return false; } diff --git a/paddle/fluid/operators/extract_rows_op.cc b/paddle/fluid/operators/extract_rows_op.cc index 3acae3bcdf4a509ab6e7e19f21c4b2ec4d72b7d7..9a297d03cfb041e584159a5fc5ba214f8ac404b4 100644 --- a/paddle/fluid/operators/extract_rows_op.cc +++ b/paddle/fluid/operators/extract_rows_op.cc @@ -50,7 +50,7 @@ class ExtractRowsOp : public framework::OperatorBase { auto &in = scope.FindVar(Input("X"))->Get(); auto out = scope.FindVar(Output("Out"))->GetMutable(); - auto &in_rows = in.rows(); + auto in_rows = in.rows(); auto out_dim = framework::make_ddim( std::vector{static_cast(in_rows.size()), 1}); auto dst_ptr = out->mutable_data(out_dim, in.place()); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index ae51a53a7197950338ef773d63103fa13bf0a5f5..b27880c232a51d32777569cf9ac67656ce02f232 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -60,9 +60,11 @@ struct SelectedRowsAdd { auto out_place = context.GetPlace(); PADDLE_ENFORCE(platform::is_gpu_place(out_place)); - memory::Copy(boost::get(out_place), out_data, - boost::get(in1_place), in1_data, - in1_value.numel() * sizeof(T), context.stream()); + memory::Copy( + boost::get(out_place), out_data, + boost::get(in1_place), in1_data, + in1_value.numel() * sizeof(T), + reinterpret_cast(context).stream()); auto* in2_data = in2_value.data(); memory::Copy(boost::get(out_place), @@ -107,7 +109,7 @@ struct SelectedRowsAddTensor { PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); auto& in1_value = input1.value(); - framework::Vector in1_rows(input1.rows()); + auto& in1_rows = input1.rows(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); @@ -146,7 +148,7 @@ struct SelectedRowsAddTo { auto in1_height = input1.height(); PADDLE_ENFORCE_EQ(in1_height, input2->height()); - auto& in1_rows = input1.rows(); + framework::Vector in1_rows(input1.rows()); auto& in2_rows = *(input2->mutable_rows()); auto& in1_value = input1.value(); @@ -206,7 +208,7 @@ struct SelectedRowsAddToTensor { PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); auto& in1_value = input1.value(); - framework::Vector in1_rows(input1.rows()); + auto& in1_rows = input1.rows(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu index e89b27855bdeba3a5189feff94eb063ddfb9b9b8..5fc50aba25d8e69480a17f0f80877b0d03e17276 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cu +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu @@ -20,7 +20,9 @@ limitations under the License. */ TEST(selected_rows_functor, gpu_add) { paddle::platform::CUDAPlace gpu_place(0); paddle::platform::CPUPlace cpu_place; - paddle::platform::CUDADeviceContext ctx(gpu_place); + paddle::platform::CUDADeviceContext& ctx = + *reinterpret_cast( + paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); paddle::operators::math::SetConstant functor; @@ -132,7 +134,9 @@ TEST(selected_rows_functor, gpu_add) { TEST(selected_rows_functor, gpu_add_to) { paddle::platform::CUDAPlace gpu_place(0); paddle::platform::CPUPlace cpu_place; - paddle::platform::CUDADeviceContext ctx(gpu_place); + paddle::platform::CUDADeviceContext& ctx = + *reinterpret_cast( + paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); paddle::operators::math::SetConstant functor; diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 2c4c2411259e9b7bfa223e4d65823ce4610596d0..6dffe527c1072ee97fcde1725bfc1a47ed1ad74a 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -123,6 +123,7 @@ class SumKernel : public framework::OpKernel { out_value->Resize(framework::make_ddim(in_dim)); out_value->mutable_data(context.GetPlace()); + // if all the input sparse vars are empty, no need to // merge these vars. if (first_dim == 0UL) { diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index f577068d1f39a3083a54f106d006f9982304411e..1f61a0e289f32196ead04d71d07b513cbe4655b1 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) { .value("Backward", framework::OpRole::kBackward) .value("Optimize", framework::OpRole::kOptimize) .value("Loss", framework::OpRole::kLoss) - .value("RPC", framework::OpRole::kRPC); + .value("RPC", framework::OpRole::kRPC) + .value("Dist", framework::OpRole::kDist) + .value("LRSched", framework::OpRole::kLRSched); op_proto_and_checker_maker.def( "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1ca2ac2ddc7daef3f4c0ea2004a62258ae4610ac..9e4a5ae8baaf7f2975c8060856f9eecab55f241c 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -46,7 +46,7 @@ from . import transpiler from .param_attr import ParamAttr, WeightNormParamAttr from .data_feeder import DataFeeder from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope -from .transpiler import DistributeTranspiler, InferenceTranspiler, \ +from .transpiler import DistributeTranspiler, \ memory_optimize, release_memory, DistributeTranspilerConfig from .lod_tensor import create_lod_tensor, create_random_int_lodtensor from . import clip diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 59ffc5c8a104473e30ea12337a3e7e7b4ceb5387..1d3c94229048ef568dfa651cc20731190beee3b8 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1510,6 +1510,30 @@ class Program(object): self._op_role_var = [] self._current_role = OpRole.Forward + @contextlib.contextmanager + def _lr_schedule_guard(self): + """ + A with guard to set :code:`LRSched` :code:`OpRole` and + :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is + set to the target learning rate. + + Notes: This is a very low level API. Users should not use it directly. + + + Examples: + + >>> p, g = backward(...) + >>> with program.lr_schedule_guard(): + >>> lr = lr * decay + """ + OpRole = core.op_proto_and_checker_maker.OpRole + self._current_role = OpRole.LRSched + # TODO(typhoonzero): how to set target learning rate var + self._op_role_var = [] + yield + self._op_role_var = [] + self._current_role = OpRole.Forward + def __str__(self): """ Get the protobuf debug string of this Program. diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 7a7a0078a557c47492a4396897aafabe6c9c5dcb..a26b8df5a240be8340597b9627866c323fa98a2d 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -74,7 +74,7 @@ class Initializer(object): directly, but need to use one of its implementations. """ - def __init_(self): + def __init__(self): pass def __call__(self, param, block): @@ -293,7 +293,7 @@ class TruncatedNormalInitializer(Initializer): assert loc is not None assert scale is not None assert seed is not None - super(NormalInitializer, self).__init__() + super(TruncatedNormalInitializer, self).__init__() self._mean = loc self._std_dev = scale self._seed = seed diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 78bb8a1a0a64631cbe2adc11b1494ceed6d14908..e703e5ac7943b006741f12886a14bf344a6b9b28 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -27,8 +27,7 @@ from . import core __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', - 'load_persistables', 'save_inference_model', 'load_inference_model', - 'get_inference_program' + 'load_persistables', 'save_inference_model', 'load_inference_model' ] @@ -504,23 +503,6 @@ def load_persistables(executor, dirname, main_program=None, filename=None): filename=filename) -def get_inference_program(target_vars, main_program=None): - if main_program is None: - main_program = default_main_program() - if not isinstance(target_vars, list): - target_vars = [target_vars] - vars = [] - for var in target_vars: - if isinstance(var, Evaluator): - vars.extend(var.states) - vars.extend(var.metrics) - else: - vars.append(var) - pruned_program = main_program._prune(targets=vars) - inference_program = pruned_program._inference_optimize() - return inference_program - - def prepend_feed_ops(inference_program, feed_target_names, feed_holder_name='feed'): diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8e86bec8609f17c973389047df66b4d725113e6e..574d0d727cba9fa9de0cffbe116f71b9e65a7092 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -39,6 +39,7 @@ __all__ = [ 'detection_map', 'rpn_target_assign', 'anchor_generator', + 'roi_perspective_transform', 'generate_proposal_labels', 'generate_proposals', ] @@ -1262,6 +1263,54 @@ def anchor_generator(input, return anchor, var +def roi_perspective_transform(input, + rois, + transformed_height, + transformed_width, + spatial_scale=1.0): + """ + ROI perspective transform op. + + Args: + input (Variable): The input of ROIPerspectiveTransformOp. The format of + input tensor is NCHW. Where N is batch size, C is the + number of input channels, H is the height of the feature, + and W is the width of the feature. + rois (Variable): ROIs (Regions of Interest) to be transformed. It should be + a 2-D LoDTensor of shape (num_rois, 8). Given as + [[x1, y1, x2, y2, x3, y3, x4, y4], ...], (x1, y1) is the + top left coordinates, and (x2, y2) is the top right + coordinates, and (x3, y3) is the bottom right coordinates, + and (x4, y4) is the bottom left coordinates. + transformed_height (integer): The height of transformed output. + transformed_height (integer): The width of transformed output. + spatial_scale (float): Spatial scale factor to scale ROI coords. Default: 1.0 + + Returns: + Variable: The output of ROIPerspectiveTransformOp which is a 4-D tensor with shape + (num_rois, channels, transformed_h, transformed_w). + + Examples: + .. code-block:: python + + out = fluid.layers.roi_perspective_transform(input, rois, 7, 7, 1.0) + """ + helper = LayerHelper('roi_perspective_transform', **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="roi_perspective_transform", + inputs={"X": input, + "ROIs": rois}, + outputs={"Out": out}, + attrs={ + "transformed_height": transformed_height, + "transformed_width": transformed_width, + "spatial_scale": spatial_scale + }) + return out + + def generate_proposal_labels(rpn_rois, gt_classes, is_crowd, diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index 7c6cb932a3767efd6a3e8d5ad1ebfb2badf6211e..dfd801a098d6451dbdb20d9ba44187d1e3f8a91a 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -27,7 +27,7 @@ from . import nn from . import ops from . import tensor from ..initializer import init_on_cpu -from ..framework import default_main_program, Parameter +from ..framework import default_main_program, Parameter, unique_name __all__ = [ 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', @@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps): Returns: The decayed learning rate. """ - global_step = _decay_step_counter(1) + with default_main_program()._lr_schedule_guard(): + global_step = _decay_step_counter(1) - a = global_step**-0.5 - b = (warmup_steps**-1.5) * global_step - lr_value = (d_model**-0.5) * nn.elementwise_min(a, b) + a = global_step**-0.5 + b = (warmup_steps**-1.5) * global_step + lr_value = (d_model**-0.5) * nn.elementwise_min(a, b) return lr_value @@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): sgd_optimizer.minimize(avg_cost) """ - global_step = _decay_step_counter() + with default_main_program()._lr_schedule_guard(): + global_step = _decay_step_counter() - div_res = global_step / decay_steps - if staircase: - div_res = ops.floor(div_res) - decayed_lr = learning_rate * (decay_rate**div_res) + div_res = global_step / decay_steps + if staircase: + div_res = ops.floor(div_res) + decayed_lr = learning_rate * (decay_rate**div_res) - return decayed_lr + return decayed_lr def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): @@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): Returns: The decayed learning rate """ - global_step = _decay_step_counter() + with default_main_program()._lr_schedule_guard(): + global_step = _decay_step_counter() - div_res = global_step / decay_steps - if staircase: - div_res = ops.floor(div_res) - decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res) + div_res = global_step / decay_steps + if staircase: + div_res = ops.floor(div_res) + decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res) - return decayed_lr + return decayed_lr def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): @@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): staircase=True)) sgd_optimizer.minimize(avg_cost) """ - global_step = _decay_step_counter() + with default_main_program()._lr_schedule_guard(): + global_step = _decay_step_counter() - div_res = global_step / decay_steps - if staircase: - div_res = ops.floor(div_res) + div_res = global_step / decay_steps + if staircase: + div_res = ops.floor(div_res) - decayed_lr = learning_rate / (1 + decay_rate * div_res) + decayed_lr = learning_rate / (1 + decay_rate * div_res) - return decayed_lr + return decayed_lr def polynomial_decay(learning_rate, @@ -220,25 +224,28 @@ def polynomial_decay(learning_rate, Returns: Variable: The decayed learning rate """ - global_step = _decay_step_counter() - - if cycle: - div_res = ops.ceil(global_step / decay_steps) - zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0) - one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0) - - with control_flow.Switch() as switch: - with switch.case(global_step == zero_var): - tensor.assign(input=one_var, output=div_res) - decay_steps = decay_steps * div_res - else: - decay_steps_var = tensor.fill_constant( - shape=[1], dtype='float32', value=float(decay_steps)) - global_step = nn.elementwise_min(x=global_step, y=decay_steps_var) + with default_main_program()._lr_schedule_guard(): + global_step = _decay_step_counter() + + if cycle: + div_res = ops.ceil(global_step / decay_steps) + zero_var = tensor.fill_constant( + shape=[1], dtype='float32', value=0.0) + one_var = tensor.fill_constant( + shape=[1], dtype='float32', value=1.0) + + with control_flow.Switch() as switch: + with switch.case(global_step == zero_var): + tensor.assign(input=one_var, output=div_res) + decay_steps = decay_steps * div_res + else: + decay_steps_var = tensor.fill_constant( + shape=[1], dtype='float32', value=float(decay_steps)) + global_step = nn.elementwise_min(x=global_step, y=decay_steps_var) - decayed_lr = (learning_rate - end_learning_rate) * \ - ((1 - global_step / decay_steps) ** power) + end_learning_rate - return decayed_lr + decayed_lr = (learning_rate - end_learning_rate) * \ + ((1 - global_step / decay_steps) ** power) + end_learning_rate + return decayed_lr def piecewise_decay(boundaries, values): @@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values): """ + with default_main_program()._lr_schedule_guard(): + if len(values) - len(boundaries) != 1: + raise ValueError("len(values) - len(boundaries) should be 1") - if len(values) - len(boundaries) != 1: - raise ValueError("len(values) - len(boundaries) should be 1") - - global_step = _decay_step_counter() + global_step = _decay_step_counter() - lr = tensor.create_global_var( - shape=[1], - value=0.0, - dtype='float32', - persistable=True, - name="learning_rate") + lr = tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate") - with control_flow.Switch() as switch: - for i in range(len(boundaries)): - boundary_val = tensor.fill_constant( + with control_flow.Switch() as switch: + for i in range(len(boundaries)): + boundary_val = tensor.fill_constant( + shape=[1], + dtype='float32', + value=float(boundaries[i]), + force_cpu=True) + value_var = tensor.fill_constant( + shape=[1], dtype='float32', value=float(values[i])) + with switch.case(global_step < boundary_val): + tensor.assign(value_var, lr) + last_value_var = tensor.fill_constant( shape=[1], dtype='float32', - value=float(boundaries[i]), - force_cpu=True) - value_var = tensor.fill_constant( - shape=[1], dtype='float32', value=float(values[i])) - with switch.case(global_step < boundary_val): - tensor.assign(value_var, lr) - last_value_var = tensor.fill_constant( - shape=[1], dtype='float32', value=float(values[len(values) - 1])) - with switch.default(): - tensor.assign(last_value_var, lr) + value=float(values[len(values) - 1])) + with switch.default(): + tensor.assign(last_value_var, lr) return lr diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4b696b06baa5f4df29edb8d7aefd7c75d3ca765f..8aeb4986475baf7e9733aa389c69f3fe6e2ac2eb 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -107,6 +107,12 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'elu', + 'relu6', + 'pow', + 'stanh', + 'hard_sigmoid', + 'swish', 'prelu', 'flatten', 'sequence_mask', @@ -5903,6 +5909,148 @@ def pad2d(input, return out +@templatedoc() +def elu(x, alpha=1.0, name=None): + """ + ${comment} + Args: + x(${x_type}): ${x_comment} + alpha(${alpha_type}|1.0): ${alpha_comment} + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + output(${out_type}): ${out_comment} + """ + helper = LayerHelper('elu', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='elu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'alpha': alpha}) + return out + + +@templatedoc() +def relu6(x, threshold=6.0, name=None): + """ + ${comment} + Args: + x(${x_type}): ${x_comment} + threshold(${threshold_type}|6.0): ${threshold_comment} + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + output(${out_type}): ${out_comment} + """ + helper = LayerHelper('relu6', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='relu6', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'threshold': threshold}) + return out + + +@templatedoc() +def pow(x, factor=1.0, name=None): + """ + ${comment} + Args: + x(${x_type}): ${x_comment} + factor(${factor_type}|1.0): ${factor_comment} + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + output(${out_type}): ${out_comment} + """ + helper = LayerHelper('pow', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='pow', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'factor': factor}) + return out + + +@templatedoc() +def stanh(x, scale_a=2.0 / 3.0, scale_b=1.7159, name=None): + """ + ${comment} + Args: + x(${x_type}): ${x_comment} + scale_a(${scale_a_type}|2.0 / 3.0): ${scale_a_comment} + scale_b(${scale_b_type}|1.7159): ${scale_b_comment} + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + output(${out_type}): ${out_comment} + """ + helper = LayerHelper('stanh', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='stanh', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'scale_a': scale_a, + 'scale_b': scale_b}) + return out + + +@templatedoc() +def hard_sigmoid(x, slope=0.2, offset=0.5, name=None): + """ + ${comment} + Args: + x(${x_type}): ${x_comment} + slope(${slope_type}|0.2): ${slope_comment} + offset(${offset_type}|0.5): ${offset_comment} + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + output(${out_type}): ${out_comment} + """ + helper = LayerHelper('hard_sigmoid', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='hard_sigmoid', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'slope': slope, + 'offset': offset}) + return out + + +@templatedoc() +def swish(x, beta=1.0, name=None): + """ + ${comment} + Args: + x(${x_type}): ${x_comment} + beta(${beta_type}|1.0): ${beta_comment} + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + output(${out_type}): ${out_comment} + """ + helper = LayerHelper('swish', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='swish', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'slope': beta}) + return out + + def prelu(x, mode, param_attr=None, name=None): """ Equation: diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 3e25394a05fd9e6fa1f6ac968b2d12a1c0e43c34..c087122a4c5198071a4c51e6bcf200edb8f2ce1b 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -36,12 +36,6 @@ __activations__ = [ 'brelu', 'leaky_relu', 'soft_relu', - 'elu', - 'relu6', - 'pow', - 'stanh', - 'hard_sigmoid', - 'swish', ] __all__ = [ diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f53fe6d69d0855c8ba88eac8059708b690d2475b..d02c890209e65bdceb5da23ba5b9c7c0356174b8 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -80,7 +80,8 @@ if(WITH_DISTRIBUTE) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL) endif(NOT APPLE) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) - py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) + #FIXME(gongwb): random fails. + #py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index 3ec79f8ef6e6f70f1365eaa32352c284d294a1ea..175bd130e5a8324227953eeeb769474e78f94fd2 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -437,13 +437,8 @@ def split_data(data, num_part): ] -def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, +def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names, sum_cost, token_num): - # Context to do validation. - test_program = train_progm.clone() - with fluid.program_guard(test_program): - test_program = fluid.io.get_inference_program([avg_cost]) - val_data = DataReader( src_vocab_fpath=TrainTaskConfig.src_vocab_fpath, trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath, @@ -505,7 +500,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, - token_num, predict): + token_num, predict, test_program): # Initialize the parameters. if TrainTaskConfig.ckpt_path: lr_scheduler.current_steps = TrainTaskConfig.start_step @@ -554,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, -1] + label_data_input_fields if TrainTaskConfig.val_file_pattern is not None: - test = test_context(train_progm, avg_cost, train_exe, dev_count, + test = test_context(test_program, avg_cost, train_exe, dev_count, data_input_names, sum_cost, token_num) # the best cross-entropy value with label smoothing @@ -1647,6 +1642,8 @@ def get_model(is_dist, is_async): local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model, TrainTaskConfig.warmup_steps, TrainTaskConfig.learning_rate) + # Context to do validation. + test_program = fluid.default_main_program().clone(for_test=True) if not is_dist: optimizer = fluid.optimizer.Adam( @@ -1671,7 +1668,7 @@ def get_model(is_dist, is_async): epsilon=TrainTaskConfig.eps) optimizer.minimize(sum_cost) - return sum_cost, avg_cost, predict, token_num, local_lr_scheduler + return sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program def update_args(): @@ -1705,7 +1702,7 @@ class DistTransformer2x2(TestDistRunnerBase): def run_trainer(self, use_cuda, args): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() TrainTaskConfig.use_gpu = use_cuda - sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( + sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model( args.is_dist, not args.sync_mode) if args.is_dist: @@ -1726,7 +1723,7 @@ class DistTransformer2x2(TestDistRunnerBase): TrainTaskConfig.local = not args.is_dist train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost, - local_lr_scheduler, token_num, predict) + local_lr_scheduler, token_num, predict, test_program) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index e97643cddef22465436051a41ef4b825e9634d23..b5549c507ed753f4504afd655be59b444164e6f3 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -345,7 +345,7 @@ class OpTest(unittest.TestCase): actual_t, expect_t, atol=atol, equal_nan=equal_nan), "Output (" + out_name + ") has diff at " + str(place) + "\nExpect " + str(expect_t) + "\n" + "But Got" + - str(actual_t) + " in class " + self.__class__.__name__) + str(actual_t)) if isinstance(expect, tuple): self.assertListEqual(actual.recursive_sequence_lengths(), expect[1], "Output (" + out_name + diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py index 0c5343a97d5ef0f97fc6b144dfc82174eacb8573..f6eb8f2c6d8b94f92e24ff789c91efb53a645a46 100644 --- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py +++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py @@ -20,7 +20,6 @@ import six import sys import collections import math -import paddle.fluid as fluid from op_test import OpTest @@ -33,7 +32,7 @@ class TestDetectionMAPOp(OpTest): self.detect = np.array(self.detect).astype('float32') self.mAP = np.array(self.mAP).astype('float32') - if len(self.class_pos_count) > 0: + if (len(self.class_pos_count) > 0): self.class_pos_count = np.array(self.class_pos_count).astype( 'int32') self.true_pos = np.array(self.true_pos).astype('float32') @@ -274,7 +273,7 @@ class TestDetectionMAPOp11Point(TestDetectionMAPOp): class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp): def init_test_case(self): super(TestDetectionMAPOpMultiBatch, self).init_test_case() - self.class_pos_count = [0, 2, 1, 0] + self.class_pos_count = [0, 2, 1] self.true_pos_lod = [[0, 3, 2]] self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]] self.false_pos_lod = [[0, 3, 2]] diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 59a137c18c9435ef5c5772d0cc08f197c1d86603..09b1c546e49bd02bf336f31885bf4c7339cc5a2c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -22,7 +22,7 @@ class TestDistMnist2x2(TestDistBase): self._sync_mode = True self._use_reduce = False - def test_se_resnext(self): + def test_dist_train(self): self.check_with_place("dist_mnist.py", delta=1e-7) @@ -31,7 +31,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase): self._sync_mode = True self._mem_opt = True - def test_se_resnext(self): + def test_dist_train(self): self.check_with_place("dist_mnist.py", delta=1e-7) @@ -40,7 +40,7 @@ class TestDistMnistAsync(TestDistBase): self._sync_mode = False self._use_reduce = False - def test_se_resnext(self): + def test_dist_train(self): self.check_with_place("dist_mnist.py", delta=200) diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index c0e9fa38e7d1eadd89eff9a8ba4442f888b8120e..7c3ed0916845d0a0cc0c462ff00927b91f546b21 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -21,7 +21,16 @@ class TestDistSeResneXt2x2(TestDistBase): def _setup_config(self): self._sync_mode = True - def test_se_resnext(self): + def test_dist_train(self): + self.check_with_place("dist_se_resnext.py", delta=1e-7) + + +class TestDistseResnXt2x2WithMemopt(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._mem_opt = True + + def test_dist_train(self): self.check_with_place("dist_se_resnext.py", delta=1e-7) @@ -29,7 +38,7 @@ class TestDistSeResneXt2x2Async(TestDistBase): def _setup_config(self): self._sync_mode = False - def test_se_resnext(self): + def test_dist_train(self): self.check_with_place("dist_se_resnext.py", delta=100) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transformer.py b/python/paddle/fluid/tests/unittests/test_dist_transformer.py index 47083ca7e954c85bb42fcc88639f3e757283cbea..47e8dfaf03ceb27a74f5e48d662d2b534d2d152b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transformer.py @@ -59,7 +59,7 @@ class TestDistTransformer2x2Sync(TestDistBase): def _setup_config(self): self._sync_mode = True - def test_transformer(self): + def test_dist_train(self): download_files() self.check_with_place("dist_transformer.py", delta=1e-5) @@ -68,7 +68,7 @@ class TestDistTransformer2x2Async(TestDistBase): def _setup_config(self): self._sync_mode = False - def test_transformer(self): + def test_dist_train(self): download_files() self.check_with_place("dist_transformer.py", delta=1.0) diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py index 9a3e92e8d775a37e0c24ee1bcc5435628d61bb91..33b39b262b95b0013e3696c3f15a288a2e801ce1 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py @@ -17,19 +17,28 @@ import unittest from test_dist_base import TestDistBase -class TestDistSeResneXt2x2(TestDistBase): +class TestDistW2V2x2(TestDistBase): def _setup_config(self): self._sync_mode = True - def test_se_resnext(self): + def test_dist_train(self): self.check_with_place("dist_word2vec.py", delta=1e-4) -class TestDistSeResneXt2x2Async(TestDistBase): +class TestDistW2V2x2WithMemOpt(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._mem_opt = True + + def test_dist_train(self): + self.check_with_place("dist_word2vec.py", delta=1e-4) + + +class TestDistW2V2x2Async(TestDistBase): def _setup_config(self): self._sync_mode = False - def test_se_resnext(self): + def test_dist_train(self): self.check_with_place("dist_word2vec.py", delta=1) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 9a17d3213c902e0c18123097025349434d271d7f..6855a0e2c0437429fe35b1dd97188716945a743b 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -573,6 +573,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_roi_perspective_transform(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[8], dtype="float32", lod_level=1) + output = layers.roi_perspective_transform(x, rois, 7, 7, 0.6) + self.assertIsNotNone(output) + print(str(program)) + def test_sequence_enumerate(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py new file mode 100644 index 0000000000000000000000000000000000000000..de675131564db43926f97ff4e6dedcaa02ff15b0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py @@ -0,0 +1,306 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUWARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import sys +import paddle.compat as cpt +from op_test import OpTest +from math import sqrt +from math import floor + + +def gt_e(a, b): + return a > b or abs(a - b) < 1e-4 + + +def gt(a, b): + return (a - b) > 1e-4 + + +def lt_e(a, b): + return a < b or abs(a - b) < 1e-4 + + +def in_quad(x, y, roi_x, roi_y): + # check if (x, y) is in the boundary of roi + for i in range(4): + xs = roi_x[i] + ys = roi_y[i] + xe = roi_x[(i + 1) % 4] + ye = roi_y[(i + 1) % 4] + if abs(ys - ye) < 1e-4: + if abs(y - ys) < 1e-4 and abs(y - ye) < 1e-4 and gt_e( + x, min(xs, xe)) and lt_e(x, max(xs, xe)): + return True + else: + intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs + if abs(intersec_x - x) < 1e-4 and gt_e(y, min(ys, ye)) and lt_e( + y, max(ys, ye)): + return True + n_cross = 0 + for i in range(4): + xs = roi_x[i] + ys = roi_y[i] + xe = roi_x[(i + 1) % 4] + ye = roi_y[(i + 1) % 4] + if abs(ys - ye) < 1e-4: + continue + if lt_e(y, min(ys, ye)) or gt(y, max(ys, ye)): + continue + intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs + if abs(intersec_x - x) < 1e-4: + return True + if gt(intersec_x, x): + n_cross += 1 + return (n_cross % 2 == 1) + + +def get_transform_matrix(transformed_width, transformed_height, roi_x, roi_y): + x0 = roi_x[0] + x1 = roi_x[1] + x2 = roi_x[2] + x3 = roi_x[3] + y0 = roi_y[0] + y1 = roi_y[1] + y2 = roi_y[2] + y3 = roi_y[3] + + len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1)) + len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)) + len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3)) + len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0)) + estimated_height = (len2 + len4) / 2.0 + estimated_width = (len1 + len3) / 2.0 + + normalized_height = transformed_height + normalized_width = round(estimated_width * + (normalized_height - 1) / estimated_height) + 1 + normalized_width = min(normalized_width, transformed_width) + + dx1 = x1 - x2 + dx2 = x3 - x2 + dx3 = x0 - x1 + x2 - x3 + dy1 = y1 - y2 + dy2 = y3 - y2 + dy3 = y0 - y1 + y2 - y3 + matrix = np.zeros([9]) + matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / ( + normalized_width - 1) + matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / ( + normalized_height - 1) + matrix[8] = 1 + + matrix[3] = (y1 - y0 + matrix[6] * + (normalized_width - 1) * y1) / (normalized_width - 1) + matrix[4] = (y3 - y0 + matrix[7] * + (normalized_height - 1) * y3) / (normalized_height - 1) + matrix[5] = y0 + + matrix[0] = (x1 - x0 + matrix[6] * + (normalized_width - 1) * x1) / (normalized_width - 1) + matrix[1] = (x3 - x0 + matrix[7] * + (normalized_height - 1) * x3) / (normalized_height - 1) + matrix[2] = x0 + return matrix + + +def get_source_coords(matrix, out_w, out_h): + u = matrix[0] * out_w + matrix[1] * out_h + matrix[2] + v = matrix[3] * out_w + matrix[4] * out_h + matrix[5] + w = matrix[6] * out_w + matrix[7] * out_h + matrix[8] + in_w = u / w + in_h = v / w + return in_w, in_h + + +def bilinear_interpolate(in_data, in_n, in_c, in_w, in_h): + + batch_size = in_data.shape[0] + channels = in_data.shape[1] + height = in_data.shape[2] + width = in_data.shape[3] + + if gt(-0.5, in_w) or gt(in_w, width - 0.5) or gt(-0.5, in_h) or gt( + in_h, height - 0.5): + return 0.0 + + if gt(0, in_w): + in_w = 0 + if gt(0, in_h): + in_h = 0 + + in_w_floor = floor(in_w) + in_h_floor = floor(in_h) + + if gt_e(in_w_floor, width - 1): + in_w_ceil = width - 1 + in_w_floor = width - 1 + in_w = in_w_floor + else: + in_w_ceil = in_w_floor + 1 + + if gt_e(in_h_floor, height - 1): + in_h_ceil = height - 1 + in_h_floor = height - 1 + in_h = in_h_floor + else: + in_h_ceil = in_h_floor + 1 + + w_floor = in_w - in_w_floor + h_floor = in_h - in_h_floor + w_ceil = 1 - w_floor + h_ceil = 1 - h_floor + v1 = in_data[in_n][in_c][int(in_h_floor)][int(in_w_floor)] + v2 = in_data[in_n][in_c][int(in_h_ceil)][int(in_w_floor)] + v3 = in_data[in_n][in_c][int(in_h_ceil)][int(in_w_ceil)] + v4 = in_data[in_n][in_c][int(in_h_floor)][int(in_w_ceil)] + w1 = w_ceil * h_ceil + w2 = w_ceil * h_floor + w3 = w_floor * h_floor + w4 = w_floor * h_ceil + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + return val + + +def lod_convert(lod): + ret = [0] + for count in lod: + ret.append(ret[-1] + count) + return ret + + +def roi_transform(in_data, rois, rois_lod, transformed_height, + transformed_width, spatial_scale): + channels = in_data.shape[1] + in_height = in_data.shape[2] + in_width = in_data.shape[3] + rois_num = rois.shape[0] + + roi2image = [0] * rois_num + rois_lod = lod_convert(rois_lod[0]) + for i in range(len(rois_lod) - 1): + for j in range(rois_lod[i], rois_lod[i + 1]): + roi2image[j] = i + + out = np.zeros([rois_num, channels, transformed_height, transformed_width]) + + for n in range(rois_num): + roi_x = [] + roi_y = [] + for k in range(4): + roi_x.append(rois[n][2 * k] * spatial_scale) + roi_y.append(rois[n][2 * k + 1] * spatial_scale) + image_id = roi2image[n] + transform_matrix = get_transform_matrix( + transformed_width, transformed_height, roi_x, roi_y) + + for c in range(channels): + for out_h in range(transformed_height): + for out_w in range(transformed_width): + in_w, in_h = get_source_coords(transform_matrix, out_w, + out_h) + if in_quad(in_w, in_h, roi_x, roi_y) and gt_e( + in_w, -0.5) and lt_e(in_w, in_width - 0.5) and gt_e( + in_h, -0.5) and lt_e(in_h, in_height - 0.5): + out[n][c][out_h][out_w] = bilinear_interpolate( + in_data, image_id, c, in_w, in_h) + else: + out[n][c][out_h][out_w] = 0.0 + return out.astype("float32") + + +class TestROIPoolOp(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + + self.inputs = {'X': self.x, 'ROIs': (self.rois, self.rois_lod)} + + self.attrs = { + 'spatial_scale': self.spatial_scale, + 'transformed_height': self.transformed_height, + 'transformed_width': self.transformed_width + } + out = roi_transform(self.x, self.rois, self.rois_lod, + self.transformed_height, self.transformed_width, + self.spatial_scale) + self.outputs = {'Out': out} + + def init_test_case(self): + self.batch_size = 2 + self.channels = 2 + self.height = 8 + self.width = 8 + + # n, c, h, w + self.x_dim = (self.batch_size, self.channels, self.height, self.width) + + self.spatial_scale = 1.0 / 2.0 + self.transformed_height = 2 + self.transformed_width = 3 + + self.x = np.random.random(self.x_dim).astype('float32') + + def make_rois(self): + rois = [] + self.rois_lod = [[]] + for bno in range(self.batch_size): + self.rois_lod[0].append(bno + 1) + for i in range(bno + 1): + x1 = np.random.randint( + 0, + self.width // self.spatial_scale - self.transformed_width) + y1 = np.random.randint( + 0, + self.height // self.spatial_scale - self.transformed_height) + + x2 = np.random.randint(x1 + self.transformed_width, + self.width // self.spatial_scale) + y2 = np.random.randint( + 0, + self.height // self.spatial_scale - self.transformed_height) + + x3 = np.random.randint(x1 + self.transformed_width, + self.width // self.spatial_scale) + y3 = np.random.randint(y1 + self.transformed_height, + self.height // self.spatial_scale) + + x4 = np.random.randint( + 0, + self.width // self.spatial_scale - self.transformed_width) + y4 = np.random.randint(y1 + self.transformed_height, + self.height // self.spatial_scale) + + roi = [x1, y1, x2, y2, x3, y3, x4, y4] + rois.append(roi) + self.rois_num = len(rois) + self.rois = np.array(rois).astype("float32") + + def setUp(self): + self.op_type = "roi_perspective_transform" + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 200175cfe87e24a53e1e229e41d1ff2a25fd66ec..59899e7e9ab98f661699d5ac0645c92bd23a1512 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -21,13 +21,12 @@ import paddle def delete_ops(block, ops): - try: - start = list(block.ops).index(ops[0]) - end = list(block.ops).index(ops[-1]) - [block._remove_op(start) for _ in six.moves.range(end - start + 1)] - except Exception as e: - raise e - block.program._sync_with_cpp() + for op in ops: + try: + idx = list(block.ops).index(op) + block._remove_op(idx) + except Exception as e: + print(e) def find_op_by_input_arg(block, arg_name): @@ -37,10 +36,18 @@ def find_op_by_input_arg(block, arg_name): return -1 -def find_op_by_output_arg(block, arg_name): - for index, op in enumerate(block.ops): - if arg_name in op.output_arg_names: - return index +def find_op_by_output_arg(block, arg_name, reverse=False): + if reverse: + pos = len(block.ops) - 1 + while pos >= 0: + op = block.ops[pos] + if arg_name in op.output_arg_names: + return pos + pos -= 1 + else: + for index, op in enumerate(block.ops): + if arg_name in op.output_arg_names: + return index return -1 diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index f58f1883a407a3123856e19b5ec8fc01862466a7..3f8c7b844a9fdc8404560ba4c78f9d328af2852a 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( ) RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC +DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist +LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched + +PRINT_LOG = False + + +def log(*args): + if PRINT_LOG: + print(args) class VarBlock: @@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object): slice_var_up = True split_method = None min_block_size = 8192 + print_log = False class DistributeTranspiler(object): @@ -174,6 +184,9 @@ class DistributeTranspiler(object): if self.config.split_method is None: self.config.split_method = RoundRobin + global PRINT_LOG + if self.config.print_log: + PRINT_LOG = True assert (self.config.min_block_size >= 8192) assert (self.config.split_method.__bases__[0] == PSDispatcher) @@ -257,12 +270,12 @@ class DistributeTranspiler(object): splited_grad_varname = grad_varname if len(splited_vars) == 1: splited_grad_varname = splited_vars[0].name - index = find_op_by_output_arg(program.global_block(), - splited_grad_varname) + index = find_op_by_output_arg( + program.global_block(), splited_grad_varname, reverse=True) elif len(splited_vars) > 1: orig_var = program.global_block().vars[splited_grad_varname] - index = find_op_by_output_arg(program.global_block(), - splited_grad_varname) + index = find_op_by_output_arg( + program.global_block(), splited_grad_varname, reverse=True) self._insert_split_op(program, orig_var, index, splited_vars) index += 1 else: @@ -301,7 +314,7 @@ class DistributeTranspiler(object): self.grad_name_to_send_dummy_out[ self.table_name] = program.global_block().create_var( name=framework.generate_control_dev_var_name()) - input_deps = self.grad_name_to_send_dummy_out.values() + input_deps = list(self.grad_name_to_send_dummy_out.values()) program.global_block().append_op( type="send_barrier", @@ -377,7 +390,10 @@ class DistributeTranspiler(object): type="concat", inputs={"X": splited_var}, outputs={"Out": [orig_param]}, - attrs={"axis": 0}) + attrs={ + "axis": 0, + RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE + }) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) @@ -496,9 +512,9 @@ class DistributeTranspiler(object): # NOTE: assume blocks of the same variable is not distributed # on the same pserver, only change param/grad varnames for # trainers to fetch. - sys.stderr.write("get_pserver_program() is deprecated, call\ - get_pserver_programs() to get pserver main and startup\ - in a single call.") + sys.stderr.write("get_pserver_program() is deprecated, call \ +get_pserver_programs() to get pserver main and startup \ +in a single call.") # step1 pserver_program = Program() pserver_program.random_seed = self.origin_program.random_seed @@ -615,22 +631,31 @@ class DistributeTranspiler(object): for idx, opt_op in enumerate(opt_op_on_pserver): per_opt_block = pserver_program._create_block(pre_block_idx) optimize_blocks.append(per_opt_block) + optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0] # append grad merging ops before clip and weight decay - # cases may like: - # L2Decay op -> clip op -> optimize + # e.g. merge grad -> L2Decay op -> clip op -> optimize + merged_var = None for _, op in enumerate(self.optimize_ops): - # find the origin @GRAD var before clipping - grad_varname_for_block = __op_have_grad_input__(op) - if ufind.is_connected(op, opt_op) and grad_varname_for_block: + # find the origin grad var before clipping/L2Decay, + # merged_var should be the input var name of L2Decaybuil + grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] + if op.attr(OP_ROLE_VAR_ATTR_NAME)[ + 0] == optimize_target_param_name: merged_var = self._append_pserver_grad_merge_ops( per_opt_block, grad_varname_for_block, endpoint, grad_to_block_id, self.origin_program) - break # append optimize op once then append other ops. - for _, op in enumerate(self.optimize_ops): - # optimizer is connected to itself - if ufind.is_connected(op, opt_op) and op not in global_ops: - __append_optimize_op__(op, per_opt_block, grad_to_block_id, - merged_var, lr_ops) + if merged_var: + break # append optimize op once then append other ops. + if merged_var: + for _, op in enumerate(self.optimize_ops): + # optimizer is connected to itself + if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \ + op not in global_ops: + log("append opt op: ", op.type, op.input_arg_names, + merged_var) + __append_optimize_op__(op, per_opt_block, + grad_to_block_id, merged_var, + lr_ops) # dedup grad to ids list grad_to_block_id = list(set(grad_to_block_id)) @@ -726,17 +751,17 @@ class DistributeTranspiler(object): Returns: Program: parameter server side startup program. """ - sys.stderr.write("get_startup_program() is deprecated, call\ - get_pserver_programs() to get pserver main and startup\ - in a single call.") + sys.stderr.write("get_startup_program() is deprecated, call \ +get_pserver_programs() to get pserver main and startup \ +in a single call.") if pserver_program != None: - sys.stderr.write("passing pserver_program to get_startup_program()\ - is deprecated, you can use new API get_pserver_programs() to\ - get both pserver main program and startup program.") + sys.stderr.write("passing pserver_program to get_startup_program() \ +is deprecated, you can use new API get_pserver_programs() to \ +get both pserver main program and startup program.") if startup_program != None: - sys.stderr.write("passing startup_program to get_startup_program()\ - is deprecated, use fluid.program_guard() or pass this argument\ - to transpile() call.") + sys.stderr.write("passing startup_program to get_startup_program() \ +is deprecated, use fluid.program_guard() or pass this argument \ +to transpile() call.") s_prog = Program() orig_s_prog = self.startup_program @@ -1302,7 +1327,10 @@ class DistributeTranspiler(object): type="split_selected_rows", inputs={"X": orig_var}, outputs={"Out": splited_vars}, - attrs={"height_sections": height_sections}) + attrs={ + "height_sections": height_sections, + RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE + }) elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: sections = [] for v in splited_vars: @@ -1312,8 +1340,10 @@ class DistributeTranspiler(object): type="split_byref", inputs={"X": orig_var}, outputs={"Out": splited_vars}, - attrs={"sections": sections} # assume split evenly - ) + attrs={ + "sections": sections, + RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE + }) else: AssertionError("Variable type should be in set " "[LOD_TENSOR, SELECTED_ROWS]") @@ -1381,15 +1411,15 @@ class DistributeTranspiler(object): if not grad_block: # do not append this op if current endpoint # is not dealing with this grad block - return + return None orig_varname, block_name, trainer_name = self._get_varname_parts( grad_block.name) if block_name: merged_var_name = '.'.join([orig_varname, block_name]) else: merged_var_name = orig_varname - merged_var = \ - pserver_block.vars[merged_var_name] + + merged_var = pserver_block.vars[merged_var_name] grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx)) if self.sync_mode and self.trainer_num > 1: vars2merge = [] @@ -1473,7 +1503,6 @@ class DistributeTranspiler(object): outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op) outputs["ParamOut"] = new_inputs["Param"] - optimize_block.append_op( type=opt_op.type, inputs=new_inputs, @@ -1618,6 +1647,16 @@ class DistributeTranspiler(object): return iomap def _get_lr_ops(self): + lr_ops = [] + block = self.origin_program.global_block() + for op in block.ops: + if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int( + LR_SCHED_OP_ROLE_ATTR_VALUE): + lr_ops.append(op) + log("append lr op: ", op.type) + return lr_ops + + def _get_lr_ops_deprecated(self): lr_ops = [] # find learning rate variables by optimize op lr_vars = set() @@ -1670,20 +1709,21 @@ class DistributeTranspiler(object): block = self.origin_program.global_block() opt_ops = [] params_grads = [] + # tmp set to dedup + optimize_params = set() origin_var_dict = self.origin_program.global_block().vars for op in block.ops: if self._is_opt_role_op(op): opt_ops.append(op) - # HACK(wuyi): if we find grad vars from input of optimize - # ops, we may get the output of clip op. Use syntax "@GRAD" - # and op_role_var to get the pair. - for input_name in op.input_arg_names: - if input_name.find("@GRAD") != -1 and \ - op.attr(RPC_OP_ROLE_ATTR_NAME): - param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] + if op.attr(OP_ROLE_VAR_ATTR_NAME): + param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] + grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] + if not param_name in optimize_params: + optimize_params.add(param_name) + log("adding param_grad pair: ", param_name, grad_name) params_grads.append([ origin_var_dict[param_name], - origin_var_dict[input_name] + origin_var_dict[grad_name] ]) else: pass diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index d4517059a4b033eec20ef6903894426ccbd597d7..d5aa54d752305b188d292f95f05cd70d27702c35 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,10 +14,10 @@ from __future__ import print_function -from collections import defaultdict +from collections import defaultdict, OrderedDict, Callable from .. import core from ... import compat as cpt -from ..framework import Program, default_main_program, Parameter +from ..framework import Program, default_main_program, Parameter, Variable from ..backward import _rename_arg_ from functools import reduce from six.moves import range @@ -113,8 +113,10 @@ class ControlFlowGraph(object): def _fill_pool(self, i, is_forward): block_desc = self._ops[i].block() in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i]) + # NOTE: must sort the in_diff set for cases that get different cache var. + # FIXME(typhoonzero): maybe use a "sorted set" is better than this. can_optimize = [ - x for x in in_diff + x for x in sorted(list(in_diff)) if self._check_var_validity(block_desc, x, is_forward) ] if can_optimize: @@ -220,8 +222,9 @@ class ControlFlowGraph(object): block_desc = op.block() is_forward = i < self._forward_num if self.pool: + # NOTE: must sort the in_diff set for cases that get different cache var. defs_can_optimize = [ - x for x in self._defs[i] + x for x in sorted(list(self._defs[i])) if self._check_var_validity(block_desc, x, is_forward) ] out_pair = [ @@ -271,6 +274,8 @@ class ControlFlowGraph(object): self._program.block(block_desc.id).var(cpt.to_text( x)).desc = self._find_var(block_desc, cache_var, is_forward) + self._program.block(block_desc.id).vars[cpt.to_text(x)] = \ + Variable(self._program.block(block_desc.id), name=cpt.to_text(x)) self._update_graph(x, cache_var, begin_idx=i) break self._fill_pool(i, is_forward)