From cdb9605badfd51bc6a1c1ea59c0eea6dc1f602c0 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 17 Sep 2018 20:59:41 +0800 Subject: [PATCH] refine --- paddle/fluid/framework/mixed_vector.h | 23 ++++++++++++++----- paddle/fluid/operators/extract_rows_op.cc | 2 +- .../operators/math/selected_rows_functor.cu | 10 ++++---- paddle/fluid/operators/sum_op.h | 1 - 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 4a2b378887e..ba2c41eb896 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -181,8 +181,8 @@ class Vector { template void Extend(It begin, It end) { MutableCPU(); - cpu_.reserve((end - begin) + cpu_.size()); - std::copy(begin, end, std::back_inserter>(cpu_)); + auto out_it = std::back_inserter>(this->cpu_); + std::copy(begin, end, out_it); } // resize the vector @@ -291,8 +291,11 @@ class Vector { void *src = cpu_.data(); gpu_.Resize(place, cpu_.size() * sizeof(T)); void *dst = gpu_.data_; - memory::Copy(boost::get(place), dst, - platform::CPUPlace(), src, gpu_.size_, nullptr); + auto stream = static_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + memory::Copy(gpu_.place_, dst, platform::CPUPlace(), src, gpu_.size_, + stream); } void ImmutableCPU() const { @@ -399,10 +402,16 @@ class Vector { } // resize the vector - void resize(size_t size) { m_->resize(size); } + void resize(size_t size) { + if (m_.Data().size() != size) { + m_->resize(size); + } + } // get cuda ptr. immutable - const T *CUDAData(platform::Place place) const { return m_->CUDAData(place); } + const T *CUDAData(platform::Place place) const { + return m_.Data().CUDAData(place); + } // get cuda ptr. mutable T *CUDAMutableData(platform::Place place) { @@ -450,6 +459,8 @@ class Vector { return true; } + const void *Handle() const { return &m_.Data(); } + private: // Vector is an COW object. details::COWPtr m_; diff --git a/paddle/fluid/operators/extract_rows_op.cc b/paddle/fluid/operators/extract_rows_op.cc index 9a297d03cfb..3acae3bcdf4 100644 --- a/paddle/fluid/operators/extract_rows_op.cc +++ b/paddle/fluid/operators/extract_rows_op.cc @@ -50,7 +50,7 @@ class ExtractRowsOp : public framework::OperatorBase { auto &in = scope.FindVar(Input("X"))->Get(); auto out = scope.FindVar(Output("Out"))->GetMutable(); - auto in_rows = in.rows(); + auto &in_rows = in.rows(); auto out_dim = framework::make_ddim( std::vector{static_cast(in_rows.size()), 1}); auto dst_ptr = out->mutable_data(out_dim, in.place()); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a92762c7fea..d559aaa7210 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -60,11 +60,9 @@ struct SelectedRowsAdd { auto out_place = context.GetPlace(); PADDLE_ENFORCE(platform::is_gpu_place(out_place)); - memory::Copy( - boost::get(out_place), out_data, - boost::get(in1_place), in1_data, - in1_value.numel() * sizeof(T), - reinterpret_cast(context).stream()); + memory::Copy(boost::get(out_place), out_data, + boost::get(in1_place), in1_data, + in1_value.numel() * sizeof(T), context.stream()); auto* in2_data = in2_value.data(); memory::Copy(boost::get(out_place), @@ -148,7 +146,7 @@ struct SelectedRowsAddTo { auto in1_height = input1.height(); PADDLE_ENFORCE_EQ(in1_height, input2->height()); - framework::Vector in1_rows(input1.rows()); + auto& in1_rows = input1.rows(); auto& in2_rows = *(input2->mutable_rows()); auto& in1_value = input1.value(); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 6dffe527c10..2c4c2411259 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -123,7 +123,6 @@ class SumKernel : public framework::OpKernel { out_value->Resize(framework::make_ddim(in_dim)); out_value->mutable_data(context.GetPlace()); - // if all the input sparse vars are empty, no need to // merge these vars. if (first_dim == 0UL) { -- GitLab