提交 cdb9605b 编写于 作者: C chengduoZH

refine

上级 c809fee3
......@@ -181,8 +181,8 @@ class Vector {
template <typename It>
void Extend(It begin, It end) {
MutableCPU();
cpu_.reserve((end - begin) + cpu_.size());
std::copy(begin, end, std::back_inserter<std::vector<T>>(cpu_));
auto out_it = std::back_inserter<std::vector<T>>(this->cpu_);
std::copy(begin, end, out_it);
}
// resize the vector
......@@ -291,8 +291,11 @@ class Vector {
void *src = cpu_.data();
gpu_.Resize(place, cpu_.size() * sizeof(T));
void *dst = gpu_.data_;
memory::Copy(boost::get<platform::CUDAPlace>(place), dst,
platform::CPUPlace(), src, gpu_.size_, nullptr);
auto stream = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
memory::Copy(gpu_.place_, dst, platform::CPUPlace(), src, gpu_.size_,
stream);
}
void ImmutableCPU() const {
......@@ -399,10 +402,16 @@ class Vector {
}
// resize the vector
void resize(size_t size) { m_->resize(size); }
void resize(size_t size) {
if (m_.Data().size() != size) {
m_->resize(size);
}
}
// get cuda ptr. immutable
const T *CUDAData(platform::Place place) const { return m_->CUDAData(place); }
const T *CUDAData(platform::Place place) const {
return m_.Data().CUDAData(place);
}
// get cuda ptr. mutable
T *CUDAMutableData(platform::Place place) {
......@@ -450,6 +459,8 @@ class Vector {
return true;
}
const void *Handle() const { return &m_.Data(); }
private:
// Vector is an COW object.
details::COWPtr<VectorData> m_;
......
......@@ -50,7 +50,7 @@ class ExtractRowsOp : public framework::OperatorBase {
auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto in_rows = in.rows();
auto &in_rows = in.rows();
auto out_dim = framework::make_ddim(
std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
......
......@@ -60,11 +60,9 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
auto out_place = context.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(out_place));
memory::Copy(
boost::get<platform::CUDAPlace>(out_place), out_data,
boost::get<platform::CUDAPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
memory::Copy(boost::get<platform::CUDAPlace>(out_place), out_data,
boost::get<platform::CUDAPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T), context.stream());
auto* in2_data = in2_value.data<T>();
memory::Copy(boost::get<platform::CUDAPlace>(out_place),
......@@ -148,7 +146,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2->height());
framework::Vector<int64_t> in1_rows(input1.rows());
auto& in1_rows = input1.rows();
auto& in2_rows = *(input2->mutable_rows());
auto& in1_value = input1.value();
......
......@@ -123,7 +123,6 @@ class SumKernel : public framework::OpKernel<T> {
out_value->Resize(framework::make_ddim(in_dim));
out_value->mutable_data<T>(context.GetPlace());
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册