Commit cdb9605b authored by chengduoZH

refine

Parent commit: c809fee3
@@ -181,8 +181,8 @@ class Vector {
   template <typename It>
   void Extend(It begin, It end) {
     MutableCPU();
-    cpu_.reserve((end - begin) + cpu_.size());
-    std::copy(begin, end, std::back_inserter<std::vector<T>>(cpu_));
+    auto out_it = std::back_inserter<std::vector<T>>(this->cpu_);
+    std::copy(begin, end, out_it);
   }
 
   // resize the vector
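The refined Extend drops the explicit cpu_.reserve((end - begin) + cpu_.size()) call and appends through std::back_inserter, which grows the underlying std::vector as elements are copied; one plausible motivation is that end - begin is only well-defined for random-access iterators, while a back-insert iterator works with any input iterator. A standalone sketch of the same pattern in plain STL (the Extend helper below is illustrative, not the framework::Vector method itself):

#include <algorithm>
#include <iterator>
#include <list>
#include <vector>

// Extend-style helper: appends [begin, end) to dst. std::back_inserter lets
// dst grow as needed, so `It` only needs to be an input iterator.
template <typename T, typename It>
void Extend(std::vector<T> *dst, It begin, It end) {
  auto out_it = std::back_inserter(*dst);
  std::copy(begin, end, out_it);
}

int main() {
  std::vector<int> v{1, 2, 3};
  std::list<int> extra{4, 5};  // std::list iterators are not random-access
  Extend(&v, extra.begin(), extra.end());
  return v.size() == 5 ? 0 : 1;
}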
@@ -291,8 +291,11 @@ class Vector {
     void *src = cpu_.data();
     gpu_.Resize(place, cpu_.size() * sizeof(T));
     void *dst = gpu_.data_;
-    memory::Copy(boost::get<platform::CUDAPlace>(place), dst,
-                 platform::CPUPlace(), src, gpu_.size_, nullptr);
+    auto stream = static_cast<platform::CUDADeviceContext *>(
+                      platform::DeviceContextPool::Instance().Get(place))
+                      ->stream();
+    memory::Copy(gpu_.place_, dst, platform::CPUPlace(), src, gpu_.size_,
+                 stream);
   }
 
   void ImmutableCPU() const {
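The CPU-to-GPU copy now runs on the stream owned by the device's CUDADeviceContext, fetched from the global DeviceContextPool, instead of passing nullptr; issuing the transfer on the same stream as the kernels already enqueued for that device keeps the copy ordered with that work. A hedged sketch of the same idea using the raw CUDA runtime (Paddle's memory::Copy wraps calls like these; names below are illustrative):

#include <cuda_runtime.h>
#include <vector>

// Copy host data to the device on an explicit stream so the transfer is
// ordered with respect to other work already enqueued on that stream.
int main() {
  std::vector<float> host(1024, 1.0f);
  float *dev = nullptr;
  cudaStream_t stream;

  cudaStreamCreate(&stream);
  cudaMalloc(&dev, host.size() * sizeof(float));

  // Asynchronous with respect to the host, but ordered within `stream`.
  cudaMemcpyAsync(dev, host.data(), host.size() * sizeof(float),
                  cudaMemcpyHostToDevice, stream);

  // ... kernels launched on `stream` here would observe the copied data ...

  cudaStreamSynchronize(stream);  // wait before the host buffer goes away
  cudaFree(dev);
  cudaStreamDestroy(stream);
  return 0;
}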
@@ -399,10 +402,16 @@ class Vector {
   }
 
   // resize the vector
-  void resize(size_t size) { m_->resize(size); }
+  void resize(size_t size) {
+    if (m_.Data().size() != size) {
+      m_->resize(size);
+    }
+  }
 
   // get cuda ptr. immutable
-  const T *CUDAData(platform::Place place) const { return m_->CUDAData(place); }
+  const T *CUDAData(platform::Place place) const {
+    return m_.Data().CUDAData(place);
+  }
 
   // get cuda ptr. mutable
   T *CUDAMutableData(platform::Place place) {
@@ -450,6 +459,8 @@ class Vector {
     return true;
   }
 
+  const void *Handle() const { return &m_.Data(); }
+
  private:
   // Vector is an COW object.
   details::COWPtr<VectorData> m_;
......
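The member m_ is a copy-on-write pointer: m_.Data() is a const read that never copies, while operator-> requests mutable access and may first clone the shared VectorData. Guarding resize with a size check and routing the const CUDAData path through Data() therefore avoids copies on read-only accesses, and the new Handle() exposes the address of the current payload so callers can tell whether two vectors still share storage. A minimal sketch of such a COW pointer (illustrative only; Paddle's details::COWPtr may differ in detail):

#include <cassert>
#include <memory>
#include <vector>

// Minimal copy-on-write pointer: reads are shared, the first mutable access
// after sharing clones the payload.
template <typename T>
class COWPtr {
 public:
  explicit COWPtr(T *payload) : ptr_(payload) {}

  // Read-only access: never copies.
  const T &Data() const { return *ptr_; }

  // Mutable access: detach from other owners before handing out a
  // non-const pointer.
  T *operator->() {
    if (ptr_.use_count() > 1) {
      ptr_ = std::make_shared<T>(*ptr_);
    }
    return ptr_.get();
  }

 private:
  std::shared_ptr<T> ptr_;
};

int main() {
  COWPtr<std::vector<int>> a(new std::vector<int>{1, 2, 3});
  COWPtr<std::vector<int>> b = a;   // shares the payload
  assert(&a.Data() == &b.Data());   // same "handle": storage is shared

  b->push_back(4);                  // mutable access clones first
  assert(&a.Data() != &b.Data());   // storage no longer shared
  assert(a.Data().size() == 3 && b.Data().size() == 4);
  return 0;
}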
@@ -50,7 +50,7 @@ class ExtractRowsOp : public framework::OperatorBase {
     auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
     auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
-    auto in_rows = in.rows();
+    auto &in_rows = in.rows();
     auto out_dim = framework::make_ddim(
         std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
     auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
......
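In ExtractRowsOp, `auto in_rows = in.rows();` deduces a value type, so the whole rows vector is copied just to be read; `auto &in_rows` binds a reference to the existing storage instead. The SelectedRowsAddTo change further down applies the same reasoning. A small standalone illustration of the deduction difference (the Rows/rows() names below are hypothetical stand-ins for SelectedRows::rows()):

#include <cassert>
#include <vector>

// Hypothetical holder whose getter returns a const reference, standing in
// for SelectedRows::rows().
struct Rows {
  const std::vector<long> &rows() const { return data_; }
  std::vector<long> data_{1, 2, 3};
};

int main() {
  Rows r;

  auto copied = r.rows();          // auto drops the reference: a full copy
  const auto &aliased = r.rows();  // binds to the existing vector, no copy

  assert(&copied != &r.data_);
  assert(&aliased == &r.data_);
  return 0;
}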
@@ -60,11 +60,9 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
     auto out_place = context.GetPlace();
     PADDLE_ENFORCE(platform::is_gpu_place(out_place));
 
-    memory::Copy(
-        boost::get<platform::CUDAPlace>(out_place), out_data,
-        boost::get<platform::CUDAPlace>(in1_place), in1_data,
-        in1_value.numel() * sizeof(T),
-        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
+    memory::Copy(boost::get<platform::CUDAPlace>(out_place), out_data,
+                 boost::get<platform::CUDAPlace>(in1_place), in1_data,
+                 in1_value.numel() * sizeof(T), context.stream());
 
     auto* in2_data = in2_value.data<T>();
     memory::Copy(boost::get<platform::CUDAPlace>(out_place),
@@ -148,7 +146,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(in1_height, input2->height());
 
-    framework::Vector<int64_t> in1_rows(input1.rows());
+    auto& in1_rows = input1.rows();
     auto& in2_rows = *(input2->mutable_rows());
     auto& in1_value = input1.value();
......
@@ -123,7 +123,6 @@ class SumKernel : public framework::OpKernel<T> {
       out_value->Resize(framework::make_ddim(in_dim));
       out_value->mutable_data<T>(context.GetPlace());
-
       // if all the input sparse vars are empty, no need to
       // merge these vars.
       if (first_dim == 0UL) {
......