提交 788636f0 编写于 作者: T typhoonzero

Update per review comments

上级 e2d56832
......@@ -98,9 +98,6 @@ class Tensor {
/*!
 * Make this tensor share the same memory block as src.
 * Returns a reference to this tensor.
 */
inline Tensor& ShareDataWith(const Tensor& src);
/*!
 * Share part of src's memory with this tensor, starting `offset` bytes past
 * the beginning of src's data. Returns a reference to this tensor.
 */
inline Tensor& ShareDataWith(const Tensor* src, size_t offset);
/**
* @brief Return a sub-tensor of the given tensor.
*
......
......@@ -162,37 +162,6 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
return *this;
}
/**
 * @brief Make this tensor refer to part of src's memory.
 *
 * The shared region starts `offset` bytes past the beginning of src's data
 * and spans exactly the number of bytes this tensor needs, i.e.
 * numel() * SizeOfType(type()) of *this* tensor. The region must lie
 * entirely inside src's data, otherwise an enforce failure is raised.
 *
 * @param src     Tensor whose memory is shared. Must be non-null, must hold
 *                memory, and must have the same data type as this tensor.
 * @param offset  Byte offset into src's data at which the shared region
 *                starts.
 * @return Reference to this tensor.
 *
 * @note If src's place is not a CPU, CUDA or CUDA-pinned place, the holder
 *       of this tensor is left unchanged.
 * @note In builds without PADDLE_WITH_CUDA, a CUDA or CUDA-pinned src place
 *       raises via PADDLE_THROW.
 */
inline Tensor& Tensor::ShareDataWith(const Tensor* src, size_t offset) {
  src->check_memory_size();
  PADDLE_ENFORCE_EQ(src->type(), this->type(),
                    "tensor data type must be the same when sharing data");
  auto place = src->place();
  auto type = src->type();

  // NOTE: data size is determined by current tensor shape and data type.
  // Using src's full size here would let the shared region run past the end
  // of src's data whenever offset > 0.
  const size_t src_bytes = src->numel() * SizeOfType(type);
  const size_t size = this->numel() * SizeOfType(type);
  PADDLE_ENFORCE_GE(src_bytes, offset,
                    "offset must not exceed the source tensor's data size");
  // Written as a subtraction so that offset + size cannot overflow.
  PADDLE_ENFORCE_GE(src_bytes - offset, size,
                    "shared region must fit inside the source tensor's data");

  // Byte-wise pointer arithmetic: offset is expressed in bytes.
  auto* ref = src->data<uint8_t>() + offset;

  if (platform::is_cpu_place(place)) {
    holder_.reset(new SharedPlaceholderImpl<platform::CPUPlace>(
        boost::get<platform::CPUPlace>(place), ref, size, type));
  } else if (platform::is_gpu_place(place) ||
             platform::is_cuda_pinned_place(place)) {
    // Braces are balanced inside each preprocessor branch so that either
    // configuration reads (and parses) on its own.
#ifdef PADDLE_WITH_CUDA
    if (platform::is_gpu_place(place)) {
      holder_.reset(new SharedPlaceholderImpl<platform::CUDAPlace>(
          boost::get<platform::CUDAPlace>(place), ref, size, type));
    } else {
      // The enclosing condition guarantees this is a CUDA pinned place.
      holder_.reset(new SharedPlaceholderImpl<platform::CUDAPinnedPlace>(
          boost::get<platform::CUDAPinnedPlace>(place), ref, size, type));
    }
#else
    PADDLE_THROW(
        "CUDAPlace or CUDAPinnedPlace is not supported in CPU-only mode.");
#endif
  }
  return *this;
}
inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
......
......@@ -26,15 +26,14 @@ class SplitByrefOpKernel : public framework::OpKernel<T> {
/**
 * @brief Split input "X" along its first dimension into the outputs "Out".
 *
 * Each output is replaced by the sub-tensor of the input returned by
 * Tensor::Slice for a consecutive range of rows; no output memory is
 * allocated here. Output i receives as many rows as its own first dimension
 * specifies on entry (presumably set during shape inference -- confirm
 * against the op's InferShape).
 *
 * @param ctx Execution context providing input "X" and the outputs "Out".
 */
void Compute(const framework::ExecutionContext& ctx) const override {
  auto* in = ctx.Input<framework::Tensor>("X");
  auto outs = ctx.MultiOutput<framework::Tensor>("Out");
  size_t row_offset = 0;
  for (size_t i = 0; i < outs.size(); ++i) {
    // NOTE: no need to call mutable_data here to allocate memory.
    auto* out = outs[i];
    // Read the row count before *out is overwritten by the slice below.
    const size_t rows = static_cast<size_t>(out->dims()[0]);
    // Slice takes [begin_idx, end_idx), so the end is the running row offset
    // plus this output's row count, not the row count alone.
    *out = in->Slice(static_cast<int>(row_offset),
                     static_cast<int>(row_offset + rows));
    row_offset += rows;
  }
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册