From cfffb1a36251e7d06535dac6db220131e36fe9f8 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Wed, 14 Feb 2018 10:32:46 -0800 Subject: [PATCH] Update tensor_util.h (#8422) * Update tensor_util.h * Update with moved TensorDesc * Fix tensur_utils.cu * Update * Update * Update * Update * Make tensor_util.cu a symbolic link --- .../fluid/framework/data_device_transform.cc | 2 +- .../framework/data_device_transform_test.cu | 4 +- paddle/fluid/framework/executor.cc | 6 +- paddle/fluid/framework/lod_tensor.cc | 12 +- paddle/fluid/framework/lod_tensor.h | 4 +- paddle/fluid/framework/mixed_vector.h | 12 +- paddle/fluid/framework/reader.cc | 2 +- paddle/fluid/framework/selected_rows.cc | 4 +- paddle/fluid/framework/tensor_util.cc | 198 ++++++++++++- paddle/fluid/framework/tensor_util.cu | 120 +------- paddle/fluid/framework/tensor_util.h | 261 ++---------------- paddle/fluid/framework/tensor_util_test.cc | 65 ++--- paddle/fluid/framework/tensor_util_test.cu | 8 +- paddle/fluid/framework/threadpool.h | 2 +- paddle/fluid/operators/array_operator.h | 2 +- .../fluid/operators/array_to_lod_tensor_op.cc | 4 +- paddle/fluid/operators/assign_op.cc | 4 +- paddle/fluid/operators/assign_value_op.h | 2 +- .../fluid/operators/beam_search_decode_op.h | 4 +- paddle/fluid/operators/detection_output_op.h | 34 +-- paddle/fluid/operators/expand_op.h | 3 +- paddle/fluid/operators/feed_op.cc | 2 +- paddle/fluid/operators/fetch_op.cc | 2 +- paddle/fluid/operators/fill_op.cc | 2 +- paddle/fluid/operators/layer_norm_op.h | 4 +- paddle/fluid/operators/load_combine_op.cc | 2 +- paddle/fluid/operators/load_op.cc | 2 +- paddle/fluid/operators/lod_reset_op.h | 4 +- .../fluid/operators/lod_tensor_to_array_op.cc | 6 +- paddle/fluid/operators/math/context_project.h | 6 +- paddle/fluid/operators/math/im2col_test.cc | 14 +- .../operators/math/math_function_test.cu | 36 +-- .../math/selected_rows_functor_test.cu | 8 +- .../fluid/operators/math/sequence_padding.cu | 4 +- .../operators/math/sequence_padding_test.cc | 4 +- paddle/fluid/operators/math/vol2col_test.cc | 8 +- paddle/fluid/operators/merge_lod_tensor_op.cc | 7 +- .../fluid/operators/mine_hard_examples_op.cc | 3 +- paddle/fluid/operators/multiplex_op.cu | 4 +- paddle/fluid/operators/nccl_op_test.cu.cc | 2 +- paddle/fluid/operators/parallel_do_op.cc | 6 +- paddle/fluid/operators/print_op.cc | 2 +- paddle/fluid/operators/recurrent_op.cc | 8 +- .../reorder_lod_tensor_by_rank_op.cc | 2 +- paddle/fluid/operators/reshape_op.h | 4 +- paddle/fluid/operators/sequence_reshape_op.h | 4 +- paddle/fluid/operators/sequence_slice_op.h | 16 +- .../fluid/operators/shrink_rnn_memory_op.cc | 2 +- paddle/fluid/operators/split_lod_tensor_op.cc | 9 +- paddle/fluid/operators/sum_op.h | 4 +- .../operators/tensor_array_read_write_op.cc | 4 +- paddle/fluid/operators/warpctc_op.h | 5 +- paddle/fluid/pybind/tensor_py.h | 6 +- 53 files changed, 411 insertions(+), 534 deletions(-) mode change 100644 => 120000 paddle/fluid/framework/tensor_util.cu diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index 728a2fb6f33..85dbb39e6fb 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -37,7 +37,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, << " dst_place: " << dst_place; auto* dev_ctx = GetDeviceContext(in.place(), dst_place); dev_ctx->Wait(); - Copy(in, dst_place, *dev_ctx, out); + TensorCopy(in, dst_place, *dev_ctx, out); dev_ctx->Wait(); } diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu index c9ba0711755..db6687985df 100644 --- a/paddle/fluid/framework/data_device_transform_test.cu +++ b/paddle/fluid/framework/data_device_transform_test.cu @@ -157,8 +157,8 @@ TEST(Operator, CPUtoGPU) { auto dev_ctx = pool.Get(cuda_place); paddle::framework::Tensor output_tensor; - Copy(output2->Get(), paddle::platform::CPUPlace(), *dev_ctx, - &output_tensor); + TensorCopy(output2->Get(), paddle::platform::CPUPlace(), *dev_ctx, + &output_tensor); dev_ctx->Wait(); float* output2_ptr = output_tensor.data(); diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index ebfd54fdc55..23eeb276c07 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -73,8 +73,10 @@ static void CheckTensorNANOrInf(const std::string& name, tensor.type().hash_code() != typeid(double).hash_code()) { return; } - PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); - PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name); + PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), + "Tensor %s contains Inf", name); + PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor), + "Tensor %s contains NAN", name); } void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 89768bcfd51..4cf14c8da54 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -46,7 +46,7 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { if (!platform::is_cpu_place(t.place())) { LoDTensor tt; - framework::Copy(t, platform::CPUPlace(), &tt); + framework::TensorCopy(t, platform::CPUPlace(), &tt); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(t.place()); dev_ctx.Wait(); @@ -255,7 +255,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor, } } // the 3st field, Tensor - SerializeToStream(os, static_cast(tensor), dev_ctx); + TensorToStream(os, static_cast(tensor), dev_ctx); } void DeserializeFromStream(std::istream &is, LoDTensor *tensor, @@ -282,7 +282,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, } } // the 3st filed, Tensor - DeserializeFromStream(is, static_cast(tensor), dev_ctx); + TensorFromStream(is, static_cast(tensor), dev_ctx); } std::vector LoDTensor::SplitLoDTensor( @@ -308,14 +308,14 @@ std::vector LoDTensor::SplitLoDTensor( if (lod().empty()) { auto src = Slice(begin, end); auto &dst_place = places[i]; - framework::Copy(src, dst_place, &dst); + framework::TensorCopy(src, dst_place, &dst); } else { auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0); auto &offset = lod_and_offset.second; auto src = Slice(offset.first, offset.second); auto &dst_place = places[i]; - framework::Copy(src, dst_place, &dst); + framework::TensorCopy(src, dst_place, &dst); LoD my_lod; for (auto &l : lod_and_offset.first) { @@ -369,7 +369,7 @@ void LoDTensor::MergeLoDTensor( for (auto *src : lod_tensors) { int end = begin + src->dims()[0]; auto dst = Slice(begin, end); - framework::Copy(*src, dst_place, &dst); + framework::TensorCopy(*src, dst_place, &dst); begin = end; } } diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 948389afb66..94d5a6e9fd9 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -175,8 +175,8 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level, for (size_t ins = 0; ins < num_instances; ins++) { for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) { auto slice = tensor.Slice(elem, elem + 1); - Copy(source.Slice(ins, ins + 1), platform::CPUPlace(), - platform::CPUDeviceContext(), &slice); + TensorCopy(source.Slice(ins, ins + 1), platform::CPUPlace(), + platform::CPUDeviceContext(), &slice); } } return tensor; diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index c1a89a1261c..6a6fa538718 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -291,7 +291,7 @@ class Vector { void CopyToCPU() const { // COPY GPU Data To CPU - Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); + TensorCopy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); WaitPlace(cuda_vec_.place()); } @@ -305,13 +305,14 @@ class Vector { void ImmutableCUDA(platform::Place place) const { if (IsDirty()) { if (IsInCPU()) { - Copy(cpu_vec_, boost::get(place), &cuda_vec_); + TensorCopy(cpu_vec_, boost::get(place), + &cuda_vec_); WaitPlace(place); UnsetFlag(kDirty); SetFlag(kDataInCUDA); } else if (IsInCUDA() && !(place == cuda_vec_.place())) { framework::Tensor tmp; - Copy(cuda_vec_, boost::get(place), &tmp); + TensorCopy(cuda_vec_, boost::get(place), &tmp); WaitPlace(cuda_vec_.place()); cuda_vec_.ShareDataWith(tmp); // Still dirty @@ -322,13 +323,14 @@ class Vector { } else { if (!IsInCUDA()) { // Even data is not dirty. However, data is not in CUDA. Copy data. - Copy(cpu_vec_, boost::get(place), &cuda_vec_); + TensorCopy(cpu_vec_, boost::get(place), + &cuda_vec_); WaitPlace(place); SetFlag(kDataInCUDA); } else if (!(place == cuda_vec_.place())) { framework::Tensor tmp; WaitPlace(cuda_vec_.place()); - Copy(cuda_vec_, boost::get(place), &tmp); + TensorCopy(cuda_vec_, boost::get(place), &tmp); WaitPlace(cuda_vec_.place()); WaitPlace(place); cuda_vec_.ShareDataWith(tmp); diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 1ef0c482111..dc1caa72a4c 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -105,7 +105,7 @@ void BatchReader::ReadNext(std::vector* out) { } } Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]); - Copy(buffer_[i][j], platform::CPUPlace(), &dst); + TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst); dst_offset += ins_shape[0]; } out_tensor.set_lod(batch_lod); diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 08c319002d1..504344e937d 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -34,7 +34,7 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, os.write(reinterpret_cast(&height), sizeof(height)); } // the 4st field, Tensor data - SerializeToStream(os, selected_rows.value(), dev_ctx); + TensorToStream(os, selected_rows.value(), dev_ctx); } void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, @@ -62,7 +62,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, selected_rows->set_height(height); } // the 4st field, tensor which contains the data - DeserializeFromStream(is, selected_rows->mutable_value(), dev_ctx); + TensorFromStream(is, selected_rows->mutable_value(), dev_ctx); } } // namespace framework diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 537fb4614ca..9b465b85b0a 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -16,6 +16,76 @@ namespace paddle { namespace framework { + +void TensorCopy(const Tensor& src, const platform::Place& dst_place, + const platform::DeviceContext& ctx, Tensor* dst) { + VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " + << dst_place; + src.check_memory_size(); + + dst->Resize(src.dims()); + dst->set_layout(src.layout()); + auto src_place = src.place(); + auto src_ptr = src.data(); + + auto dst_ptr = dst->mutable_data(dst_place, src.type()); + + auto size = src.numel() * SizeOfType(src.type()); + + if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), dst_ptr, + boost::get(src_place), src_ptr, size); + } +#ifdef PADDLE_WITH_CUDA + else if (platform::is_gpu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { + auto src_gpu_place = boost::get(src_place); + auto dst_cpu_place = boost::get(dst_place); + auto ctx_place = ctx.GetPlace(); + PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); + auto ctx_gpu_place = boost::get(ctx_place); + PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); + memory::Copy( + dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, + reinterpret_cast(ctx).stream()); + } else if (platform::is_cpu_place(src_place) && + platform::is_gpu_place(dst_place)) { + auto src_cpu_place = boost::get(src_place); + auto dst_gpu_place = boost::get(dst_place); + auto ctx_place = ctx.GetPlace(); + PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); + auto ctx_gpu_place = boost::get(ctx_place); + PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place); + memory::Copy( + dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, + reinterpret_cast(ctx).stream()); + } else if (platform::is_gpu_place(src_place) && + platform::is_gpu_place(dst_place)) { + auto src_gpu_place = boost::get(src_place); + auto dst_gpu_place = boost::get(dst_place); + auto ctx_place = ctx.GetPlace(); + PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); + auto ctx_gpu_place = boost::get(ctx_place); + PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); + memory::Copy( + dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, + reinterpret_cast(ctx).stream()); + } +#endif +} + +void TensorCopy(const Tensor& src, const platform::Place& dst_place, + Tensor* dst) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + const platform::DeviceContext* dev_ctx; + if (platform::is_gpu_place(src.place())) { + dev_ctx = pool.Get(src.place()); + } else { + dev_ctx = pool.Get(dst_place); + } + TensorCopy(src, dst_place, *dev_ctx, dst); +} + template struct AnyDTypeVisitor { Predicate predicate_; @@ -69,7 +139,7 @@ struct AnyVisitor : public boost::static_visitor { tmp.mutable_data(cpu); auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu); gpuctx->Wait(); - Copy(out, cpu, *gpuctx, &tmp); + TensorCopy(out, cpu, *gpuctx, &tmp); gpuctx->Wait(); return GetResult(tmp, cpu); } @@ -87,7 +157,7 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) { return platform::VisitPlace(place, visitor); } -struct HasNANPredicate { +struct ContainsNANPredicate { template auto operator()(const T& eigen_vec) const -> decltype(std::declval().isnan()) { @@ -96,12 +166,12 @@ struct HasNANPredicate { } }; -bool HasNAN(const framework::Tensor& tensor) { - HasNANPredicate predicate; +bool TensorContainsNAN(const framework::Tensor& tensor) { + ContainsNANPredicate predicate; return Any(tensor, predicate); } -struct HasInfPredicate { +struct ContainsInfPredicate { template auto operator()(const T& eigen_vec) const -> decltype(std::declval().isinf()) { @@ -110,10 +180,124 @@ struct HasInfPredicate { } }; -bool HasInf(const framework::Tensor& tensor) { - HasInfPredicate predicate; +bool TensorContainsInf(const framework::Tensor& tensor) { + ContainsInfPredicate predicate; return Any(tensor, predicate); } +void TensorToStream(std::ostream& os, const Tensor& tensor, + const platform::DeviceContext& dev_ctx) { + // TODO(typhoonzero): serialize to ostream + { // the 1st field, uint32_t version + constexpr uint32_t version = 0; + os.write(reinterpret_cast(&version), sizeof(version)); + } + { // the 2nd field, tensor description + // int32_t size + // void* protobuf message + proto::VarType::TensorDesc desc; + desc.set_data_type(framework::ToDataType(tensor.type())); + auto dims = framework::vectorize(tensor.dims()); + auto* pb_dims = desc.mutable_dims(); + pb_dims->Resize(static_cast(dims.size()), 0); + std::copy(dims.begin(), dims.end(), pb_dims->begin()); + int32_t size = desc.ByteSize(); + os.write(reinterpret_cast(&size), sizeof(size)); + auto out = desc.SerializeAsString(); + os.write(out.data(), size); + } + { // the 3rd field, tensor data + uint64_t size = tensor.memory_size(); + auto* data_ptr = tensor.data(); + PADDLE_ENFORCE(size < std::numeric_limits::max(), + "Index overflow when writing tensor"); + if (platform::is_gpu_place(tensor.place())) { +#ifdef PADDLE_WITH_CUDA + constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB + std::unique_ptr buf(new char[kBufSize]); + auto& gpu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + uintptr_t data = reinterpret_cast(data_ptr); + while (size != 0) { + size_t size_to_write = std::min(kBufSize, static_cast(size)); + memory::Copy(cpu, buf.get(), + boost::get(tensor.place()), + reinterpret_cast(data), size_to_write, + gpu_dev_ctx.stream()); + gpu_dev_ctx.Wait(); + os.write(buf.get(), size_to_write); + data += size_to_write; + size -= size_to_write; + } +#else + PADDLE_THROW("Unexpected branch"); +#endif + } else { + os.write(static_cast(data_ptr), + static_cast(size)); + } + } +} + +struct DeserializedDataFunctor { + DeserializedDataFunctor(void** buf, Tensor* tensor, + const platform::Place& place) + : buf_(buf), tensor_(tensor), place_(place) {} + + template + void operator()() { + *buf_ = tensor_->mutable_data(place_); + } + + void** buf_; + Tensor* tensor_; + platform::Place place_; +}; + +void TensorFromStream(std::istream& is, Tensor* tensor, + const platform::DeviceContext& dev_ctx) { + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + proto::VarType::TensorDesc desc; + { // int32_t size + // proto buffer + int32_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + std::unique_ptr buf(new char[size]); + is.read(reinterpret_cast(buf.get()), size); + PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), + "Cannot parse tensor desc"); + } + { // read tensor + std::vector dims; + dims.reserve(static_cast(desc.dims().size())); + std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); + tensor->Resize(framework::make_ddim(dims)); + void* buf; + auto ctx = platform::CPUDeviceContext(); + if (platform::is_gpu_place(dev_ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + Tensor cpu_tensor; + cpu_tensor.Resize(framework::make_ddim(dims)); + framework::VisitDataType( + desc.data_type(), + DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace())); + is.read(static_cast(buf), cpu_tensor.memory_size()); + auto dst_place = dev_ctx.GetPlace(); + framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor); +#else + PADDLE_THROW("Unexpected branch"); +#endif + } else { + framework::VisitDataType( + desc.data_type(), + DeserializedDataFunctor(&buf, tensor, ctx.GetPlace())); + is.read(static_cast(buf), tensor->memory_size()); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor_util.cu b/paddle/fluid/framework/tensor_util.cu deleted file mode 100644 index 537fb4614ca..00000000000 --- a/paddle/fluid/framework/tensor_util.cu +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/framework/tensor_util.h" - -namespace paddle { -namespace framework { -template -struct AnyDTypeVisitor { - Predicate predicate_; - const Tensor& tensor_; - const DevCtx& ctx_; - Tensor* out_; - - AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, - Tensor* out) - : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} - - template - void operator()() const { - auto t = EigenVector::Flatten(tensor_); - auto o = EigenScalar::From(*out_); - // return any of predicate_(t) is true. - o.device(*ctx_.eigen_device()) = predicate_(t).any(); - } -}; - -template -inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, - const DevCtx& ctx, framework::Tensor* out) { - VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( - predicate, tensor, ctx, out)); -} - -template -struct AnyVisitor : public boost::static_visitor { - const framework::Tensor& tensor_; - Predicate predicate_; - - AnyVisitor(const framework::Tensor& tensor, Predicate predicate) - : tensor_(tensor), predicate_(std::move(predicate)) {} - - template - bool operator()(const Place& place) const { - framework::Tensor out; - out.Resize({1}); - out.mutable_data(place); - auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); - AnyImpl(predicate_, tensor_, *ctx, &out); - return this->GetResult(out, place); - } - - bool GetResult(const framework::Tensor& out, - const platform::CUDAPlace& gpu) const { - platform::CPUPlace cpu; - framework::Tensor tmp; - tmp.Resize({1}); - tmp.mutable_data(cpu); - auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu); - gpuctx->Wait(); - Copy(out, cpu, *gpuctx, &tmp); - gpuctx->Wait(); - return GetResult(tmp, cpu); - } - - bool GetResult(const framework::Tensor& out, - const platform::CPUPlace& cpu) const { - return *out.data(); - } -}; - -template -inline bool Any(const framework::Tensor& tensor, Predicate predicate) { - AnyVisitor visitor(tensor, predicate); - auto place = tensor.place(); - return platform::VisitPlace(place, visitor); -} - -struct HasNANPredicate { - template - auto operator()(const T& eigen_vec) const - -> decltype(std::declval().isnan()) { - // Cast eigen_vector to vector of bool. true if is inf. - return eigen_vec.isnan(); - } -}; - -bool HasNAN(const framework::Tensor& tensor) { - HasNANPredicate predicate; - return Any(tensor, predicate); -} - -struct HasInfPredicate { - template - auto operator()(const T& eigen_vec) const - -> decltype(std::declval().isinf()) { - // Cast eigen_vector to vector of bool. true if is inf. - return eigen_vec.isinf(); - } -}; - -bool HasInf(const framework::Tensor& tensor) { - HasInfPredicate predicate; - return Any(tensor, predicate); -} - -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/tensor_util.cu b/paddle/fluid/framework/tensor_util.cu new file mode 120000 index 00000000000..edd88c4e547 --- /dev/null +++ b/paddle/fluid/framework/tensor_util.cu @@ -0,0 +1 @@ +tensor_util.cc \ No newline at end of file diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index f0464d48078..38b6d1c5c46 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -22,106 +22,38 @@ limitations under the License. */ namespace paddle { namespace framework { -/** - * @brief Copy the content of external tensor to a new place. - * - * @param[in] src The external tensor. - * @param[in] dst_place The dst place. - * @param[in] ctx The device context contains device resources. - * - * @note Copy supports CPU <-> GPU, GPU <-> GPU. - */ -inline void Copy(const Tensor& src, const platform::Place& dst_place, - const platform::DeviceContext& ctx, Tensor* dst) { - VLOG(3) << "Copy " << src.dims() << " from " << src.place() << " to " - << dst_place; - src.check_memory_size(); +void TensorCopy(const Tensor& src, const platform::Place& dst_place, + const platform::DeviceContext& ctx, Tensor* dst); +void TensorCopy(const Tensor& src, const platform::Place& dst_place, + Tensor* dst); - dst->Resize(src.dims()); - dst->set_layout(src.layout()); - auto src_place = src.place(); - auto src_ptr = src.data(); +template +void TensorFromVector(const std::vector& src, + const platform::DeviceContext& ctx, Tensor* dst); +template +void TensorFromVector(const std::vector& src, Tensor* dst); - auto dst_ptr = dst->mutable_data(dst_place, src.type()); +template +void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, + std::vector* dst); +template +void TesnorToVector(const Tensor& src, std::vector* dst); - auto size = src.numel() * SizeOfType(src.type()); +bool TensorContainsNAN(const framework::Tensor& tensor); +bool TensorContainsInf(const framework::Tensor& tensor); - if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { - memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size); - } -#ifdef PADDLE_WITH_CUDA - else if (platform::is_gpu_place(src_place) && // NOLINT - platform::is_cpu_place(dst_place)) { - auto src_gpu_place = boost::get(src_place); - auto dst_cpu_place = boost::get(dst_place); - auto ctx_place = ctx.GetPlace(); - PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); - auto ctx_gpu_place = boost::get(ctx_place); - PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); - memory::Copy( - dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); - } else if (platform::is_cpu_place(src_place) && - platform::is_gpu_place(dst_place)) { - auto src_cpu_place = boost::get(src_place); - auto dst_gpu_place = boost::get(dst_place); - auto ctx_place = ctx.GetPlace(); - PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); - auto ctx_gpu_place = boost::get(ctx_place); - PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place); - memory::Copy( - dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); - } else if (platform::is_gpu_place(src_place) && - platform::is_gpu_place(dst_place)) { - auto src_gpu_place = boost::get(src_place); - auto dst_gpu_place = boost::get(dst_place); - auto ctx_place = ctx.GetPlace(); - PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); - auto ctx_gpu_place = boost::get(ctx_place); - PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); - memory::Copy( - dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); - } -#endif -} +void TensorToStream(std::ostream& os, const Tensor& tensor, + const platform::DeviceContext& dev_ctx); +void TensorFromStream(std::istream& is, Tensor* tensor, + const platform::DeviceContext& dev_ctx); -/** - * @brief Wrapper on - * Copy(const Tensor& src, const platform::Place& dst_place, - * const platform::DeviceContext& ctx, Tensor* dst); - * - * @param[in] src The external tensor. - * @param[in] dst_place The dst place. - * - * @note Copy supports CPU <-> GPU, GPU <-> GPU. - */ -inline void Copy(const Tensor& src, const platform::Place& dst_place, - Tensor* dst) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - const platform::DeviceContext* dev_ctx; - if (platform::is_gpu_place(src.place())) { - dev_ctx = pool.Get(src.place()); - } else { - dev_ctx = pool.Get(dst_place); - } - Copy(src, dst_place, *dev_ctx, dst); -} +// +// The implementation of template functions. +// -/** - * @brief Copy the content of an external vector to a tensor. - * - * @param[in] src The external tensor. - * @param[in] ctx The device context contains device resources. - * - * * @note CopyFromVector will resize dst to an 1D tensor with the same - * size as src. - */ template -inline void CopyFromVector(const std::vector& src, - const platform::DeviceContext& ctx, Tensor* dst) { +void TensorFromVector(const std::vector& src, + const platform::DeviceContext& ctx, Tensor* dst) { auto dst_place = ctx.GetPlace(); auto src_ptr = static_cast(src.data()); platform::CPUPlace src_place; @@ -143,11 +75,8 @@ inline void CopyFromVector(const std::vector& src, #endif } -/** - * @brief CopyFromVector CPU vector -> CPU Tensor - */ template -inline void CopyFromVector(const std::vector& src, Tensor* dst) { +void TensorFromVector(const std::vector& src, Tensor* dst) { platform::CPUPlace dst_place = platform::CPUPlace(); auto src_ptr = static_cast(src.data()); platform::CPUPlace src_place; @@ -158,18 +87,9 @@ inline void CopyFromVector(const std::vector& src, Tensor* dst) { memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); } -/** - * @brief Copy the content of a tensor to a vector - * - * @param[in] src The external tensor. - * @param[in] ctx The device context contains device resources. - * - * * @note CopyFromVector assumes that the tensor has been resized - * before invoking. - */ template -inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx, - std::vector* dst) { +void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, + std::vector* dst) { auto src_ptr = static_cast(src.data()); auto size = src.numel() * sizeof(T); @@ -191,11 +111,8 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx, #endif } -/** - * @brief CopyToVector CPUTensor <-> CPU Vector - */ template -inline void CopyToVector(const Tensor& src, std::vector* dst) { +void TensorToVector(const Tensor& src, std::vector* dst) { auto src_ptr = static_cast(src.data()); auto size = src.numel() * sizeof(T); @@ -209,125 +126,5 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { src_ptr, size); } -// Returns true if a tensor contains NAN, i.e., Not A Number. -bool HasNAN(const framework::Tensor& tensor); - -// Returns true if a tensor contains Inf, i.e., Infinity. -bool HasInf(const framework::Tensor& tensor); - -inline void SerializeToStream(std::ostream& os, const Tensor& tensor, - const platform::DeviceContext& dev_ctx) { - // TODO(typhoonzero): serialize to ostream - { // the 1st field, uint32_t version - constexpr uint32_t version = 0; - os.write(reinterpret_cast(&version), sizeof(version)); - } - { // the 2nd field, tensor description - // int32_t size - // void* protobuf message - proto::VarType::TensorDesc desc; - desc.set_data_type(framework::ToDataType(tensor.type())); - auto dims = framework::vectorize(tensor.dims()); - auto* pb_dims = desc.mutable_dims(); - pb_dims->Resize(static_cast(dims.size()), 0); - std::copy(dims.begin(), dims.end(), pb_dims->begin()); - int32_t size = desc.ByteSize(); - os.write(reinterpret_cast(&size), sizeof(size)); - auto out = desc.SerializeAsString(); - os.write(out.data(), size); - } - { // the 3rd field, tensor data - uint64_t size = tensor.memory_size(); - auto* data_ptr = tensor.data(); - PADDLE_ENFORCE(size < std::numeric_limits::max(), - "Index overflow when writing tensor"); - if (platform::is_gpu_place(tensor.place())) { -#ifdef PADDLE_WITH_CUDA - constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB - std::unique_ptr buf(new char[kBufSize]); - auto& gpu_dev_ctx = - static_cast(dev_ctx); - platform::CPUPlace cpu; - uintptr_t data = reinterpret_cast(data_ptr); - while (size != 0) { - size_t size_to_write = std::min(kBufSize, static_cast(size)); - memory::Copy(cpu, buf.get(), - boost::get(tensor.place()), - reinterpret_cast(data), size_to_write, - gpu_dev_ctx.stream()); - gpu_dev_ctx.Wait(); - os.write(buf.get(), size_to_write); - data += size_to_write; - size -= size_to_write; - } -#else - PADDLE_THROW("Unexpected branch"); -#endif - } else { - os.write(static_cast(data_ptr), - static_cast(size)); - } - } -} - -struct DeserializedDataFunctor { - DeserializedDataFunctor(void** buf, Tensor* tensor, - const platform::Place& place) - : buf_(buf), tensor_(tensor), place_(place) {} - - template - void operator()() { - *buf_ = tensor_->mutable_data(place_); - } - - void** buf_; - Tensor* tensor_; - platform::Place place_; -}; - -inline void DeserializeFromStream(std::istream& is, Tensor* tensor, - const platform::DeviceContext& dev_ctx) { - uint32_t version; - is.read(reinterpret_cast(&version), sizeof(version)); - PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); - proto::VarType::TensorDesc desc; - { // int32_t size - // proto buffer - int32_t size; - is.read(reinterpret_cast(&size), sizeof(size)); - std::unique_ptr buf(new char[size]); - is.read(reinterpret_cast(buf.get()), size); - PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), - "Cannot parse tensor desc"); - } - { // read tensor - std::vector dims; - dims.reserve(static_cast(desc.dims().size())); - std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); - tensor->Resize(framework::make_ddim(dims)); - void* buf; - auto ctx = platform::CPUDeviceContext(); - if (platform::is_gpu_place(dev_ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA - Tensor cpu_tensor; - cpu_tensor.Resize(framework::make_ddim(dims)); - framework::VisitDataType( - desc.data_type(), - DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace())); - is.read(static_cast(buf), cpu_tensor.memory_size()); - auto dst_place = dev_ctx.GetPlace(); - framework::Copy(cpu_tensor, dst_place, dev_ctx, tensor); -#else - PADDLE_THROW("Unexpected branch"); -#endif - } else { - framework::VisitDataType( - desc.data_type(), - DeserializedDataFunctor(&buf, tensor, ctx.GetPlace())); - is.read(static_cast(buf), tensor->memory_size()); - } - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor_util_test.cc b/paddle/fluid/framework/tensor_util_test.cc index dcdbf9d3958..8aebfcb3b62 100644 --- a/paddle/fluid/framework/tensor_util_test.cc +++ b/paddle/fluid/framework/tensor_util_test.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { -TEST(Copy, Tensor) { +TEST(TensorCopy, Tensor) { Tensor src_tensor; Tensor dst_tensor; platform::CPUDeviceContext cpu_ctx((platform::CPUPlace())); @@ -33,7 +33,7 @@ TEST(Copy, Tensor) { src_tensor.set_layout(DataLayout::kAnyLayout); auto cpu_place = new platform::CPUPlace(); - Copy(src_tensor, *cpu_place, &dst_tensor); + TensorCopy(src_tensor, *cpu_place, &dst_tensor); const int* dst_ptr = dst_tensor.data(); ASSERT_NE(src_ptr, dst_ptr); @@ -44,7 +44,7 @@ TEST(Copy, Tensor) { EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout()); Tensor slice_tensor = src_tensor.Slice(1, 2); - Copy(slice_tensor, *cpu_place, &dst_tensor); + TensorCopy(slice_tensor, *cpu_place, &dst_tensor); const int* slice_ptr = slice_tensor.data(); dst_ptr = dst_tensor.data(); ASSERT_NE(dst_ptr, slice_ptr); @@ -68,11 +68,11 @@ TEST(Copy, Tensor) { // CPU Tensor to GPU Tensor auto gpu_place = new platform::CUDAPlace(0); platform::CUDADeviceContext gpu_ctx(*gpu_place); - Copy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); + TensorCopy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); // GPU Tensor to CPU Tensor auto cpu_place = new platform::CPUPlace(); - Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); + TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); // Sync before Compare Tensors gpu_ctx.Wait(); @@ -85,10 +85,10 @@ TEST(Copy, Tensor) { Tensor slice_tensor = src_tensor.Slice(1, 2); // CPU Slice Tensor to GPU Tensor - Copy(slice_tensor, *gpu_place, gpu_ctx, &gpu_tensor); + TensorCopy(slice_tensor, *gpu_place, gpu_ctx, &gpu_tensor); // GPU Tensor to CPU Tensor - Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); + TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); // Sync before Compare Slice Tensors gpu_ctx.Wait(); @@ -104,7 +104,7 @@ TEST(Copy, Tensor) { #endif } -TEST(CopyFromVector, Tensor) { +TEST(TensorFromVector, Tensor) { using namespace paddle::framework; using namespace paddle::platform; { @@ -114,7 +114,7 @@ TEST(CopyFromVector, Tensor) { // Copy to CPU Tensor cpu_tensor.Resize(make_ddim({3, 3})); auto cpu_place = new paddle::platform::CPUPlace(); - CopyFromVector(src_vec, &cpu_tensor); + TensorFromVector(src_vec, &cpu_tensor); // Compare Tensors const int* cpu_ptr = cpu_tensor.data(); @@ -126,7 +126,7 @@ TEST(CopyFromVector, Tensor) { src_vec.erase(src_vec.begin(), src_vec.begin() + 5); cpu_tensor.Resize(make_ddim({2, 2})); - CopyFromVector(src_vec, &cpu_tensor); + TensorFromVector(src_vec, &cpu_tensor); cpu_ptr = cpu_tensor.data(); src_ptr = src_vec.data(); ASSERT_NE(src_ptr, cpu_ptr); @@ -148,15 +148,15 @@ TEST(CopyFromVector, Tensor) { cpu_tensor.Resize(make_ddim({3, 3})); auto cpu_place = new paddle::platform::CPUPlace(); CPUDeviceContext cpu_ctx(*cpu_place); - CopyFromVector(src_vec, cpu_ctx, &cpu_tensor); + TensorFromVector(src_vec, cpu_ctx, &cpu_tensor); // Copy to GPUTensor gpu_tensor.Resize(make_ddim({3, 3})); auto gpu_place = new paddle::platform::CUDAPlace(); CUDADeviceContext gpu_ctx(*gpu_place); - CopyFromVector(src_vec, gpu_ctx, &gpu_tensor); + TensorFromVector(src_vec, gpu_ctx, &gpu_tensor); // Copy from GPU to CPU tensor for comparison - Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); + TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); // Sync before Compare Tensors gpu_ctx.Wait(); @@ -173,10 +173,10 @@ TEST(CopyFromVector, Tensor) { src_vec.erase(src_vec.begin(), src_vec.begin() + 5); cpu_tensor.Resize(make_ddim({2, 2})); - CopyFromVector(src_vec, cpu_ctx, &cpu_tensor); + TensorFromVector(src_vec, cpu_ctx, &cpu_tensor); gpu_tensor.Resize(make_ddim({2, 2})); - CopyFromVector(src_vec, gpu_ctx, &gpu_tensor); - Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); + TensorFromVector(src_vec, gpu_ctx, &gpu_tensor); + TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); // Sync before Compare Tensors gpu_ctx.Wait(); @@ -196,7 +196,7 @@ TEST(CopyFromVector, Tensor) { #endif } -TEST(CopyToVector, Tensor) { +TEST(TensorToVector, Tensor) { using namespace paddle::framework; using namespace paddle::platform; { @@ -208,7 +208,7 @@ TEST(CopyToVector, Tensor) { CPUPlace place; std::vector dst; - CopyToVector(src, &dst); + TensorToVector(src, &dst); for (int i = 0; i < 3 * 3; ++i) { EXPECT_EQ(src_ptr[i], dst[i]); @@ -220,10 +220,10 @@ TEST(CopyToVector, Tensor) { Tensor gpu_tensor; CUDAPlace place; CUDADeviceContext gpu_ctx(place); - CopyFromVector(src_vec, gpu_ctx, &gpu_tensor); + TensorFromVector(src_vec, gpu_ctx, &gpu_tensor); std::vector dst; - CopyToVector(gpu_tensor, gpu_ctx, &dst); + TensorToVector(gpu_tensor, gpu_ctx, &dst); for (int i = 0; i < 3 * 3; ++i) { EXPECT_EQ(src_vec[i], dst[i]); @@ -232,7 +232,7 @@ TEST(CopyToVector, Tensor) { #endif } -TEST(HasNAN, CPU) { +TEST(TensorContainsNAN, CPU) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; @@ -240,11 +240,12 @@ TEST(HasNAN, CPU) { buf[0] = 0.0; buf[1] = NAN; buf[2] = 0.0; - - ASSERT_TRUE(HasNAN(src)); + ASSERT_TRUE(TensorContainsNAN(src)); + buf[1] = 0.0; + ASSERT_FALSE(TensorContainsNAN(src)); } -TEST(HasInf, CPU) { +TEST(TensorContainsInf, CPU) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; @@ -252,10 +253,12 @@ TEST(HasInf, CPU) { buf[0] = 1.0; buf[1] = INFINITY; buf[2] = 0.0; - ASSERT_TRUE(HasInf(src)); + ASSERT_TRUE(TensorContainsInf(src)); + buf[1] = 1.0; + ASSERT_FALSE(TensorContainsInf(src)); } -TEST(Tensor, SerializeAndDeserialize) { +TEST(Tensor, FromAndToStream) { framework::Tensor src_tensor; int array[6] = {1, 2, 3, 4, 5, 6}; src_tensor.Resize({2, 3}); @@ -268,10 +271,10 @@ TEST(Tensor, SerializeAndDeserialize) { auto place = new platform::CPUPlace(); platform::CPUDeviceContext cpu_ctx(*place); std::ostringstream oss; - SerializeToStream(oss, src_tensor, cpu_ctx); + TensorToStream(oss, src_tensor, cpu_ctx); std::istringstream iss(oss.str()); - DeserializeFromStream(iss, &dst_tensor, cpu_ctx); + TensorFromStream(iss, &dst_tensor, cpu_ctx); int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < 5; ++i) { ASSERT_EQ(dst_ptr[i], array[i]); @@ -288,13 +291,13 @@ TEST(Tensor, SerializeAndDeserialize) { auto gpu_place = new platform::CUDAPlace(); platform::CUDADeviceContext gpu_ctx(*gpu_place); - Copy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); + TensorCopy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); std::ostringstream oss; - SerializeToStream(oss, gpu_tensor, gpu_ctx); + TensorToStream(oss, gpu_tensor, gpu_ctx); std::istringstream iss(oss.str()); - DeserializeFromStream(iss, &dst_tensor, gpu_ctx); + TensorFromStream(iss, &dst_tensor, gpu_ctx); int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < 6; ++i) { diff --git a/paddle/fluid/framework/tensor_util_test.cu b/paddle/fluid/framework/tensor_util_test.cu index 1982b642bcd..d630ec44a2a 100644 --- a/paddle/fluid/framework/tensor_util_test.cu +++ b/paddle/fluid/framework/tensor_util_test.cu @@ -31,7 +31,7 @@ static __global__ void FillInf(float* buf) { buf[2] = 0.5; } -TEST(HasNAN, GPU) { +TEST(TensorContainsNAN, GPU) { Tensor tensor; platform::CUDAPlace gpu(0); auto& pool = platform::DeviceContextPool::Instance(); @@ -39,10 +39,10 @@ TEST(HasNAN, GPU) { float* buf = tensor.mutable_data({3}, gpu); FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); cuda_ctx->Wait(); - ASSERT_TRUE(HasNAN(tensor)); + ASSERT_TRUE(TensorContainsNAN(tensor)); } -TEST(HasInf, GPU) { +TEST(TensorContainsInf, GPU) { Tensor tensor; platform::CUDAPlace gpu(0); auto& pool = platform::DeviceContextPool::Instance(); @@ -50,7 +50,7 @@ TEST(HasInf, GPU) { float* buf = tensor.mutable_data({3}, gpu); FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); cuda_ctx->Wait(); - ASSERT_TRUE(HasInf(tensor)); + ASSERT_TRUE(TensorContainsInf(tensor)); } } // namespace framework diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 606a93e13be..3adc260caf5 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -64,7 +64,6 @@ class ThreadPool { Task task([fn]() -> std::unique_ptr { try { fn(); - return nullptr; } catch (platform::EnforceNotMet ex) { return std::unique_ptr( new platform::EnforceNotMet(ex)); @@ -73,6 +72,7 @@ class ThreadPool { << "Unexpected exception is catched in thread pool. All " "throwable exception in Fluid should be an EnforceNotMet."; } + return nullptr; }); std::future> f = task.get_future(); tasks_.push(std::move(task)); diff --git a/paddle/fluid/operators/array_operator.h b/paddle/fluid/operators/array_operator.h index d0fc1533470..dbcc7abb099 100644 --- a/paddle/fluid/operators/array_operator.h +++ b/paddle/fluid/operators/array_operator.h @@ -42,7 +42,7 @@ class ArrayOp : public framework::OperatorBase { if (platform::is_gpu_place(i_tensor.place())) { // FIXME: Avoid copy from GPU to CPU framework::Tensor t; - framework::Copy(i_tensor, platform::CPUPlace(), dev_ctx, &t); + framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t); dev_ctx.Wait(); offset = static_cast(*t.data()); } else { diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index f59bfad6cca..5db2e4540ef 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -112,8 +112,8 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::Copy(x[x_idx].Slice(start_offset, end_offset), place, - dev_ctx, &slice); + framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place, + dev_ctx, &slice); out_offset += len; } } diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index e21dc6d77f3..39ae3c0040d 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -45,7 +45,7 @@ class AssignFunctor { out_rows.set_height(rows.height()); auto &t = rows.value(); auto *m = out_rows.mutable_value(); - framework::Copy(t, t.place(), dev_ctx_, m); + framework::TensorCopy(t, t.place(), dev_ctx_, m); } template @@ -57,7 +57,7 @@ class AssignFunctor { void copy_tensor(const framework::LoDTensor &lod_tensor, framework::LoDTensor *out) const { auto &out_tensor = *out; - Copy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); + TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); out_tensor.set_lod(lod_tensor.lod()); } diff --git a/paddle/fluid/operators/assign_value_op.h b/paddle/fluid/operators/assign_value_op.h index 90c9496a3c1..d51b215a083 100644 --- a/paddle/fluid/operators/assign_value_op.h +++ b/paddle/fluid/operators/assign_value_op.h @@ -41,7 +41,7 @@ class AssignValueKernel : public framework::OpKernel { break; } auto values = ctx.Attr>(value_name); - framework::CopyFromVector(values, ctx.device_context(), out); + framework::TensorFromVector(values, ctx.device_context(), out); out->Resize(framework::make_ddim(shape)); } }; diff --git a/paddle/fluid/operators/beam_search_decode_op.h b/paddle/fluid/operators/beam_search_decode_op.h index 40147ce1eb2..3cc6ed31057 100644 --- a/paddle/fluid/operators/beam_search_decode_op.h +++ b/paddle/fluid/operators/beam_search_decode_op.h @@ -232,12 +232,12 @@ void BeamSearchDecoder::ConvertSentenceVectorToLodTensor( id_tensor->set_lod(lod); id_tensor->Resize({static_cast(id_data.size())}); id_tensor->mutable_data(paddle::platform::CPUPlace()); - framework::CopyFromVector(id_data, cpu_ctx, id_tensor); + framework::TensorFromVector(id_data, cpu_ctx, id_tensor); score_tensor->set_lod(lod); score_tensor->Resize({static_cast(score_data.size())}); score_tensor->mutable_data(paddle::platform::CPUPlace()); - framework::CopyFromVector(score_data, cpu_ctx, score_tensor); + framework::TensorFromVector(score_data, cpu_ctx, score_tensor); } template diff --git a/paddle/fluid/operators/detection_output_op.h b/paddle/fluid/operators/detection_output_op.h index 0aa5fc010de..af9081c9343 100644 --- a/paddle/fluid/operators/detection_output_op.h +++ b/paddle/fluid/operators/detection_output_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -Indicesou may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + Indicesou may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" @@ -98,16 +98,16 @@ class DetectionOutputKernel : public framework::OpKernel { T* conf_data = conf_tensor.data(); if (platform::is_gpu_place(context.GetPlace())) { loc_cpu.mutable_data(loc_tensor.dims(), platform::CPUPlace()); - framework::Copy(loc_tensor, platform::CPUPlace(), - context.device_context(), &loc_cpu); + framework::TensorCopy(loc_tensor, platform::CPUPlace(), + context.device_context(), &loc_cpu); loc_data = loc_cpu.data(); conf_cpu.mutable_data(conf_tensor.dims(), platform::CPUPlace()); - framework::Copy(conf_tensor, platform::CPUPlace(), - context.device_context(), &conf_cpu); + framework::TensorCopy(conf_tensor, platform::CPUPlace(), + context.device_context(), &conf_cpu); conf_data = conf_cpu.data(); priorbox_cpu.mutable_data(in_priorbox->dims(), platform::CPUPlace()); - framework::Copy(*in_priorbox, platform::CPUPlace(), - context.device_context(), &priorbox_cpu); + framework::TensorCopy(*in_priorbox, platform::CPUPlace(), + context.device_context(), &priorbox_cpu); priorbox_data = priorbox_cpu.data(); } // get decode bboxes @@ -158,8 +158,8 @@ class DetectionOutputKernel : public framework::OpKernel { batch_size, all_indices, all_decoded_bboxes, out_data); if (platform::is_gpu_place(context.GetPlace())) { - framework::Copy(out_cpu, platform::CUDAPlace(), context.device_context(), - out); + framework::TensorCopy(out_cpu, platform::CUDAPlace(), + context.device_context(), out); } } }; diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index 953d75adae5..2c2d5c7c42c 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -126,7 +126,8 @@ class ExpandGradKernel : public framework::OpKernel { auto* in0 = context.Input(framework::GradVarName("Out")); auto* out0 = context.Output(framework::GradVarName("X")); out0->mutable_data(context.GetPlace()); - framework::Copy(*in0, context.GetPlace(), context.device_context(), out0); + framework::TensorCopy(*in0, context.GetPlace(), context.device_context(), + out0); } else { switch (dims) { REP_EXPAND_GRAD_TEMPLATE(72) diff --git a/paddle/fluid/operators/feed_op.cc b/paddle/fluid/operators/feed_op.cc index 438d9754298..90c31877f6a 100644 --- a/paddle/fluid/operators/feed_op.cc +++ b/paddle/fluid/operators/feed_op.cc @@ -57,7 +57,7 @@ class FeedOp : public framework::OperatorBase { if (platform::is_same_place(feed_item.place(), place)) { out_item->ShareDataWith(feed_item); } else { - framework::Copy(feed_item, place, dev_ctx, out_item); + framework::TensorCopy(feed_item, place, dev_ctx, out_item); } out_item->set_lod(feed_item.lod()); } diff --git a/paddle/fluid/operators/fetch_op.cc b/paddle/fluid/operators/fetch_op.cc index 2684e646340..d66f01d1b7c 100644 --- a/paddle/fluid/operators/fetch_op.cc +++ b/paddle/fluid/operators/fetch_op.cc @@ -56,7 +56,7 @@ class FetchOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(src_item.place()); - Copy(src_item, platform::CPUPlace(), dev_ctx, &dst_item); + TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item); dev_ctx.Wait(); dst_item.set_lod(src_item.lod()); diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index c505c739d46..3b4b4092311 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -74,7 +74,7 @@ class FillOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::Copy(tensor, place, dev_ctx, &out); + framework::TensorCopy(tensor, place, dev_ctx, &out); } } }; diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 84f5a40aacb..605b5c258ca 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -196,7 +196,7 @@ class LayerNormGradKernel : public framework::OpKernel { // dy_dx ElementwiseComputeEx, DeviceContext, T>( ctx, &d_y, scale, /*axis*/ 1, MulFunctor(), &temp); - framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x); + framework::TensorCopy(temp, ctx.GetPlace(), ctx.device_context(), d_x); // dy_dmean_dx row_mean(dev_ctx, temp, &temp_vec); @@ -208,7 +208,7 @@ class LayerNormGradKernel : public framework::OpKernel { ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); } else { // dy_dx - framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); + framework::TensorCopy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); // dy_dmean_dx row_mean(dev_ctx, d_y, &temp_vec); diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index ba8fc4a6836..e5353144e91 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase { out_var->Clear(); tensor = out_var->GetMutable(); tensor->set_lod(cpu_tensor.lod()); - Copy(cpu_tensor, place, dev_ctx, tensor); + TensorCopy(cpu_tensor, place, dev_ctx, tensor); } } } diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index d72b7a7eb96..05f809ac562 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -55,7 +55,7 @@ class LoadOp : public framework::OperatorBase { out_var->Clear(); tensor = out_var->GetMutable(); tensor->set_lod(cpu_tensor.lod()); - Copy(cpu_tensor, place, dev_ctx, tensor); + TensorCopy(cpu_tensor, place, dev_ctx, tensor); } } }; diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index e612bc2d367..8186d4f8262 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -33,8 +33,8 @@ class LoDResetKernel : public framework::OpKernel { auto* lod = lod_t->data(); if (platform::is_gpu_place(ctx.GetPlace())) { framework::Tensor lod_cpu; - framework::Copy(*lod_t, platform::CPUPlace(), ctx.device_context(), - &lod_cpu); + framework::TensorCopy(*lod_t, platform::CPUPlace(), + ctx.device_context(), &lod_cpu); lod = lod_cpu.data(); } level0 = std::vector(lod, lod + lod_t->numel()); diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index b5e778a5811..543495ce4e6 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -94,9 +94,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::Copy(x.Slice(static_cast(each_range.begin), - static_cast(each_range.end)), - x.place(), dev_ctx, &slice); + framework::TensorCopy(x.Slice(static_cast(each_range.begin), + static_cast(each_range.end)), + x.place(), dev_ctx, &slice); offset += len; } } diff --git a/paddle/fluid/operators/math/context_project.h b/paddle/fluid/operators/math/context_project.h index 83f6ae45fca..4da94383af6 100644 --- a/paddle/fluid/operators/math/context_project.h +++ b/paddle/fluid/operators/math/context_project.h @@ -149,7 +149,8 @@ class ContextProjectFunctor { Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); Tensor w_sub = padding_data.Slice(k, k + padding_size); - framework::Copy(w_sub, context.GetPlace(), context, &out_t_sub); + framework::TensorCopy(w_sub, context.GetPlace(), context, + &out_t_sub); } } if (down_pad > 0) { // add down pad @@ -179,7 +180,8 @@ class ContextProjectFunctor { (down_pad_begin_row + t) * context_length); Tensor w_sub = padding_data.Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); - framework::Copy(w_sub, context.GetPlace(), context, &out_t_sub); + framework::TensorCopy(w_sub, context.GetPlace(), context, + &out_t_sub); } } out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/fluid/operators/math/im2col_test.cc b/paddle/fluid/operators/math/im2col_test.cc index 30519253154..b3978536bca 100644 --- a/paddle/fluid/operators/math/im2col_test.cc +++ b/paddle/fluid/operators/math/im2col_test.cc @@ -62,7 +62,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - Copy(input_tmp, *place, *context, &input); + TensorCopy(input_tmp, *place, *context, &input); } output_cfo.mutable_data( {1, filter_size, filter_size, output_height, output_width}, *place); @@ -87,7 +87,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { out_cfo_ptr = output_cfo.data(); } else { - Copy(output_cfo, paddle::platform::CPUPlace(), *context, &output_tmp); + TensorCopy(output_cfo, paddle::platform::CPUPlace(), *context, &output_tmp); out_cfo_ptr = output_tmp.data(); } for (int i = 0; i < 6; ++i) { @@ -98,7 +98,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { out_ocf_ptr = output_ocf.data(); } else { - Copy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp); + TensorCopy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp); out_ocf_ptr = output_tmp.data(); } @@ -119,7 +119,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - Copy(input_tmp, *place, *context, &input); + TensorCopy(input_tmp, *place, *context, &input); } col2im(*context, output_cfo, dilation, stride, padding, &input); @@ -128,7 +128,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); } else { - Copy(input, paddle::platform::CPUPlace(), *context, &input_tmp); + TensorCopy(input, paddle::platform::CPUPlace(), *context, &input_tmp); in_ptr = input_tmp.data(); } for (int i = 0; i < 6; ++i) { @@ -140,7 +140,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - Copy(input_tmp, *place, *context, &input); + TensorCopy(input_tmp, *place, *context, &input); } col2im_ocf(*context, output_ocf, dilation, stride, padding, &input); @@ -148,7 +148,7 @@ void testIm2col() { if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); } else { - Copy(input, paddle::platform::CPUPlace(), *context, &input_tmp); + TensorCopy(input, paddle::platform::CPUPlace(), *context, &input_tmp); in_ptr = input_tmp.data(); } for (int i = 0; i < 6; ++i) { diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index f333c6c98ed..207d6a87bce 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -29,15 +29,15 @@ TEST(math_function, notrans_mul_trans) { auto* gpu_place = new paddle::platform::CUDAPlace(0); paddle::platform::CUDADeviceContext context(*gpu_place); - paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::Copy(input1, *gpu_place, context, &input2_gpu); + paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); + paddle::framework::TensorCopy(input1, *gpu_place, context, &input2_gpu); out_gpu.mutable_data({2, 2}, *gpu_place); paddle::operators::math::matmul( context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); - paddle::framework::Copy(out_gpu, *cpu_place, context, &out); + paddle::framework::TensorCopy(out_gpu, *cpu_place, context, &out); float* out_ptr = out.data(); context.Wait(); @@ -63,15 +63,15 @@ TEST(math_function, trans_mul_notrans) { auto* gpu_place = new paddle::platform::CUDAPlace(0); paddle::platform::CUDADeviceContext context(*gpu_place); - paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::Copy(input1, *gpu_place, context, &input2_gpu); + paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); + paddle::framework::TensorCopy(input1, *gpu_place, context, &input2_gpu); out_gpu.mutable_data({3, 3}, *gpu_place); paddle::operators::math::matmul( context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); - paddle::framework::Copy(out_gpu, *cpu_place, context, &out); + paddle::framework::TensorCopy(out_gpu, *cpu_place, context, &out); float* out_ptr = out.data(); context.Wait(); @@ -112,9 +112,9 @@ TEST(math_function, gemm_notrans_cublas) { auto* gpu_place = new paddle::platform::CUDAPlace(0); paddle::platform::CUDADeviceContext context(*gpu_place); - paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::Copy(input2, *gpu_place, context, &input2_gpu); - paddle::framework::Copy(input3, *gpu_place, context, &input3_gpu); + paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); + paddle::framework::TensorCopy(input2, *gpu_place, context, &input2_gpu); + paddle::framework::TensorCopy(input3, *gpu_place, context, &input3_gpu); float* a = input1_gpu.data(); float* b = input2_gpu.data(); float* c = input3_gpu.mutable_data(*gpu_place); @@ -122,7 +122,7 @@ TEST(math_function, gemm_notrans_cublas) { paddle::operators::math::gemm( context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); - paddle::framework::Copy(input3_gpu, *cpu_place, context, &input3); + paddle::framework::TensorCopy(input3_gpu, *cpu_place, context, &input3); // numpy code: // a = np.arange(6).reshape(2, 3) @@ -167,9 +167,9 @@ TEST(math_function, gemm_trans_cublas) { auto* gpu_place = new paddle::platform::CUDAPlace(0); paddle::platform::CUDADeviceContext context(*gpu_place); - paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::Copy(input2, *gpu_place, context, &input2_gpu); - paddle::framework::Copy(input3, *gpu_place, context, &input3_gpu); + paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); + paddle::framework::TensorCopy(input2, *gpu_place, context, &input2_gpu); + paddle::framework::TensorCopy(input3, *gpu_place, context, &input3_gpu); float* a = input1_gpu.data(); float* b = input2_gpu.data(); float* c = input3_gpu.mutable_data(*gpu_place); @@ -177,7 +177,7 @@ TEST(math_function, gemm_trans_cublas) { paddle::operators::math::gemm( context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); - paddle::framework::Copy(input3_gpu, *cpu_place, context, &input3); + paddle::framework::TensorCopy(input3_gpu, *cpu_place, context, &input3); context.Wait(); EXPECT_EQ(input3_ptr[0], 0); @@ -218,15 +218,15 @@ void GemvTest(int m, int n, bool trans) { } paddle::platform::CUDADeviceContext context(*gpu_place); - paddle::framework::Copy(mat_a, *gpu_place, context, &g_mat_a); - paddle::framework::Copy(vec_b, *gpu_place, context, &g_vec_b); + paddle::framework::TensorCopy(mat_a, *gpu_place, context, &g_mat_a); + paddle::framework::TensorCopy(vec_b, *gpu_place, context, &g_vec_b); paddle::operators::math::gemv( context, trans, static_cast(m), static_cast(n), 1., g_data_a, g_data_b, 0., g_data_c); - paddle::framework::Copy(g_vec_c, paddle::platform::CPUPlace(), context, - &vec_c); + paddle::framework::TensorCopy(g_vec_c, paddle::platform::CPUPlace(), context, + &vec_c); if (!trans) { for (int i = 0; i < m; ++i) { diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu index cefe239bd28..942d9b13fc1 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cu +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu @@ -67,7 +67,7 @@ TEST(selected_rows_functor, gpu_add) { EXPECT_EQ(out_rows[6], 9); Tensor out_cpu; - Copy(*out_value, cpu_place, ctx, &out_cpu); + TensorCopy(*out_value, cpu_place, ctx, &out_cpu); ctx.Wait(); auto* out_cpu_data = out_cpu.data(); @@ -94,7 +94,7 @@ TEST(selected_rows_functor, gpu_add) { add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); Tensor tensor2_cpu; - Copy(*tensor2, cpu_place, ctx, &tensor2_cpu); + TensorCopy(*tensor2, cpu_place, ctx, &tensor2_cpu); ctx.Wait(); auto* tensor2_cpu_data = tensor2_cpu.data(); @@ -167,7 +167,7 @@ TEST(selected_rows_functor, gpu_add_to) { EXPECT_EQ(out_rows[6], 9); Tensor out_cpu; - Copy(*out_value, cpu_place, ctx, &out_cpu); + TensorCopy(*out_value, cpu_place, ctx, &out_cpu); ctx.Wait(); auto* out_cpu_data = out_cpu.data(); @@ -191,7 +191,7 @@ TEST(selected_rows_functor, gpu_add_to) { add_to_tensor_functor(ctx, *output, tensor1.get()); Tensor tensor1_cpu; - Copy(*tensor1, cpu_place, ctx, &tensor1_cpu); + TensorCopy(*tensor1, cpu_place, ctx, &tensor1_cpu); ctx.Wait(); auto* tensor1_cpu_data = tensor1_cpu.data(); diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 9eb52f6fd92..c044e6fc32b 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -97,7 +97,7 @@ class PaddingLoDTensorFunctor { "width of sequence in LoDTensor seq."); if (!norm_by_times && num_sequences == 1UL) { - Copy(seq, context.GetPlace(), context, &padding); + TensorCopy(seq, context.GetPlace(), context, &padding); padding.Resize(padding_dims); return; } @@ -172,7 +172,7 @@ class UnpaddingLoDTensorFunctor { "width of sequence in LoDTensor seq."); if (!norm_by_times && num_sequences == 1UL) { - Copy(padding, context.GetPlace(), context, &seq); + TensorCopy(padding, context.GetPlace(), context, &seq); seq.Resize(seq_dims); return; } diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index e1177fb0d77..bece46e7537 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -40,7 +40,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, if (paddle::platform::is_cpu_place(*place)) { seq = cpu_seq; } else { - Copy(cpu_seq, *place, *context, &seq); + TensorCopy(cpu_seq, *place, *context, &seq); seq.set_lod(lod); } @@ -63,7 +63,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, if (paddle::platform::is_cpu_place(*place)) { cpu_seq_back = seq_back; } else { - Copy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back); + TensorCopy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back); cpu_seq_back.set_lod(lod); } diff --git a/paddle/fluid/operators/math/vol2col_test.cc b/paddle/fluid/operators/math/vol2col_test.cc index 751d3ef19a2..eb91f862e39 100644 --- a/paddle/fluid/operators/math/vol2col_test.cc +++ b/paddle/fluid/operators/math/vol2col_test.cc @@ -71,7 +71,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - Copy(input_tmp, *place, *context, &input); + paddle::framework::TensorCopy(input_tmp, *place, *context, &input); } output.mutable_data({1, filter_size, filter_size, filter_size, output_depth, output_height, output_width}, @@ -85,7 +85,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { out_cfo_ptr = output.data(); } else { - Copy(output, paddle::platform::CPUPlace(), *context, &output_tmp); + TensorCopy(output, paddle::platform::CPUPlace(), *context, &output_tmp); out_cfo_ptr = output_tmp.data(); } @@ -99,7 +99,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; } else { - Copy(input_tmp, *place, *context, &input); + TensorCopy(input_tmp, *place, *context, &input); } paddle::operators::math::Col2VolFunctor col2vol; @@ -109,7 +109,7 @@ void testVol2col() { if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); } else { - Copy(input, paddle::platform::CPUPlace(), *context, &input_tmp); + TensorCopy(input, paddle::platform::CPUPlace(), *context, &input_tmp); in_ptr = input_tmp.data(); } diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index 42ebc8e471b..4ebf20cbba6 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -51,7 +51,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { cpu_mask->ShareDataWith(mask); } else if (platform::is_gpu_place(mask.place())) { #ifdef PADDLE_WITH_CUDA - framework::Copy(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get()); + framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx, + cpu_mask.get()); #else PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); #endif @@ -106,8 +107,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { continue; } auto slice = out->Slice(out_offset, out_offset + len); - framework::Copy(input->Slice(start_offset, end_offset), place, dev_ctx, - &slice); + framework::TensorCopy(input->Slice(start_offset, end_offset), place, + dev_ctx, &slice); out_offset += len; (*in_idx) += 1; } diff --git a/paddle/fluid/operators/mine_hard_examples_op.cc b/paddle/fluid/operators/mine_hard_examples_op.cc index 2128979faee..b7e9f4e2248 100644 --- a/paddle/fluid/operators/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/mine_hard_examples_op.cc @@ -67,7 +67,8 @@ class MineHardExamplesKernel : public framework::OpKernel { auto out_match_indices = ctx.Output("UpdatedMatchIndices"); - framework::Copy(*in_matched_indices, ctx.GetPlace(), out_match_indices); + framework::TensorCopy(*in_matched_indices, ctx.GetPlace(), + out_match_indices); int batch_size = in_matched_indices->dims()[0]; int prior_num = in_matched_indices->dims()[1]; diff --git a/paddle/fluid/operators/multiplex_op.cu b/paddle/fluid/operators/multiplex_op.cu index cb89eeecfb2..45a25507935 100644 --- a/paddle/fluid/operators/multiplex_op.cu +++ b/paddle/fluid/operators/multiplex_op.cu @@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel { auto cols = ins[0]->numel() / rows; // copy index to cpu Tensor index_t_cpu; - Copy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); + TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); auto* index = index_t_cpu.data(); auto stream = ctx.cuda_device_context().stream(); platform::CUDAPlace place = boost::get(ctx.GetPlace()); @@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel { auto cols = ins[0]->numel() / rows; // copy index to cpu Tensor index_t_cpu; - Copy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); + TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); auto* index = index_t_cpu.data(); auto stream = ctx.cuda_device_context().stream(); diff --git a/paddle/fluid/operators/nccl_op_test.cu.cc b/paddle/fluid/operators/nccl_op_test.cu.cc index 24e30f54a13..b4021a5dacd 100644 --- a/paddle/fluid/operators/nccl_op_test.cu.cc +++ b/paddle/fluid/operators/nccl_op_test.cu.cc @@ -98,7 +98,7 @@ class NCCLTester : public ::testing::Test { send_tensor->mutable_data(kDims, place); std::vector send_vector(f::product(kDims), gpu_id); - paddle::framework::CopyFromVector(send_vector, *ctx, send_tensor); + paddle::framework::TensorFromVector(send_vector, *ctx, send_tensor); ctx->Wait(); VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel(); } diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index 88c83ee213f..b21f9937ef5 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -78,7 +78,7 @@ inline void CopyOrShare(const framework::Variable &src, dst->GetMutable()->ShareDataWith(src.Get()); dst->GetMutable()->set_lod(src.Get().lod()); } else { - Copy(src.Get(), dst_place, dst->GetMutable()); + TensorCopy(src.Get(), dst_place, dst->GetMutable()); } } else if (src.IsType()) { auto &src_sr = src.Get(); @@ -88,7 +88,7 @@ inline void CopyOrShare(const framework::Variable &src, dst_sr->mutable_value()->ShareDataWith(src_sr.value()); dst_sr->set_rows(src_sr.rows()); } else { - Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); + TensorCopy(src_sr.value(), dst_place, dst_sr->mutable_value()); } } else { PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); @@ -146,7 +146,7 @@ class ParallelDoOp : public framework::OperatorBase { auto &place = places[i]; auto *sub_scope = sub_scopes[i]; auto *dst = sub_scope->Var(param)->GetMutable(); - framework::Copy(src, place, dst); + framework::TensorCopy(src, place, dst); } } WaitOnPlaces(places); diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index 7fa2b060afd..fc09b4aa1da 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -179,7 +179,7 @@ class TensorPrintOp : public framework::OperatorBase { } else { // copy data to cpu to print platform::CPUPlace place; - framework::Copy(in_tensor, place, &printed_tensor); + framework::TensorCopy(in_tensor, place, &printed_tensor); } Formater formater; diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index 8435d6bcf07..00241e76821 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -291,7 +291,7 @@ class RecurrentOp : public RecurrentBase { auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1); // Explicit copy output since the local RNN scope can be destroyed // early. - framework::Copy(src_tensor, place, dev_ctx, &dst_out); + framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out); }); scopes.Next(); @@ -378,7 +378,7 @@ class RecurrentGradOp : public RecurrentBase { auto *cur_grad_var = cur_scope.Var(cur_grad); auto cur_grad_tensor = cur_grad_var->GetMutable(); - framework::Copy(ex_tensor, place, dev_ctx, cur_grad_tensor); + framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor); } } @@ -452,7 +452,7 @@ class RecurrentGradOp : public RecurrentBase { } auto dst = outside->Slice(seq_offset, seq_offset + 1); - framework::Copy(inside, place, dev_ctx, &dst); + framework::TensorCopy(inside, place, dev_ctx, &dst); }); VLOG(5) << "Link outside gradient finished "; @@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase { framework::LoDTensor *outside) { outside->Resize(inside.dims()); outside->mutable_data(place, inside.type()); - framework::Copy(inside, place, dev_ctx, outside); + framework::TensorCopy(inside, place, dev_ctx, outside); }); VLOG(5) << "Link initialize state gradient finished "; } diff --git a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc index b0df932f436..5c3e1f5678d 100644 --- a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc @@ -170,7 +170,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::Copy(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); + framework::TensorCopy(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); out_offset += len; return out_offset; } diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index c01100ef4df..1357bce4b7e 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -28,7 +28,7 @@ class ReshapeKernel : public framework::OpKernel { auto* in = ctx.Input("X"); auto out_dims = out->dims(); out->mutable_data(ctx.GetPlace()); - framework::Copy(*in, ctx.GetPlace(), ctx.device_context(), out); + framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); out->Resize(out_dims); } }; @@ -42,7 +42,7 @@ class ReshapeGradKernel : public framework::OpKernel { d_x->mutable_data(ctx.GetPlace()); auto in_dims = d_x->dims(); - framework::Copy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); d_x->Resize(in_dims); } }; diff --git a/paddle/fluid/operators/sequence_reshape_op.h b/paddle/fluid/operators/sequence_reshape_op.h index f0b5be0218c..2893808ee9c 100644 --- a/paddle/fluid/operators/sequence_reshape_op.h +++ b/paddle/fluid/operators/sequence_reshape_op.h @@ -61,7 +61,7 @@ class SequenceReshapeKernel : public framework::OpKernel { } } - framework::Copy(*in, context.GetPlace(), out); + framework::TensorCopy(*in, context.GetPlace(), out); out->Resize({static_cast(out->lod()[0].back()), out_width}); } }; @@ -77,7 +77,7 @@ class SequenceReshapeGradKernel : public framework::OpKernel { context.Output(framework::GradVarName("X")); xg_tensor_ptr->mutable_data(context.GetPlace()); - framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr); + framework::TensorCopy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr); xg_tensor_ptr->Resize(x_tensor_ptr->dims()); } }; diff --git a/paddle/fluid/operators/sequence_slice_op.h b/paddle/fluid/operators/sequence_slice_op.h index 4f6d70483ec..b9c565cac95 100644 --- a/paddle/fluid/operators/sequence_slice_op.h +++ b/paddle/fluid/operators/sequence_slice_op.h @@ -66,13 +66,13 @@ class SequenceSliceOpKernel : public framework::OpKernel { if (platform::is_gpu_place(ctx.GetPlace())) { offset_cpu.mutable_data(offset->dims(), platform::CPUPlace()); - framework::Copy(*offset, platform::CPUPlace(), ctx.device_context(), - &offset_cpu); + framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(), + &offset_cpu); offset_data = offset_cpu.data(); length_cpu.mutable_data(length->dims(), platform::CPUPlace()); - framework::Copy(*length, platform::CPUPlace(), ctx.device_context(), - &length_cpu); + framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(), + &length_cpu); length_data = length_cpu.data(); } @@ -127,13 +127,13 @@ class SequenceSliceGradOpKernel : public framework::OpKernel { if (platform::is_gpu_place(ctx.GetPlace())) { offset_cpu.mutable_data(offset->dims(), platform::CPUPlace()); - framework::Copy(*offset, platform::CPUPlace(), ctx.device_context(), - &offset_cpu); + framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(), + &offset_cpu); offset_data = offset_cpu.data(); length_cpu.mutable_data(length->dims(), platform::CPUPlace()); - framework::Copy(*length, platform::CPUPlace(), ctx.device_context(), - &length_cpu); + framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(), + &length_cpu); length_data = length_cpu.data(); } diff --git a/paddle/fluid/operators/shrink_rnn_memory_op.cc b/paddle/fluid/operators/shrink_rnn_memory_op.cc index 183982f90fb..a1871a8e7fb 100644 --- a/paddle/fluid/operators/shrink_rnn_memory_op.cc +++ b/paddle/fluid/operators/shrink_rnn_memory_op.cc @@ -133,7 +133,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { auto &dout_tensor = dout_var->Get(); auto height = dout_tensor.dims()[0]; auto slice = dx_tensor.Slice(0, static_cast(height)); - framework::Copy(dout_tensor, dout_tensor.place(), dev_ctx, &slice); + framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice); if (dx_tensor.dims()[0] > height) { auto rest_tensor = dx_tensor.Slice( static_cast(height), static_cast(dx_tensor.dims()[0])); diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc index 1c5d647600d..3222cce2399 100644 --- a/paddle/fluid/operators/split_lod_tensor_op.cc +++ b/paddle/fluid/operators/split_lod_tensor_op.cc @@ -55,7 +55,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { cpu_mask->ShareDataWith(mask); } else if (platform::is_gpu_place(mask.place())) { #ifdef PADDLE_WITH_CUDA - framework::Copy(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get()); + framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx, + cpu_mask.get()); #else PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); #endif @@ -113,9 +114,9 @@ class SplitLoDTensorOp : public framework::OperatorBase { // out[offset: offset+len] = x[each_range.begin: each_range.end] auto slice = out->Slice(static_cast(offset), static_cast(offset + len)); - framework::Copy(x.Slice(static_cast(each_range.begin), - static_cast(each_range.end)), - x.place(), dev_ctx, &slice); + framework::TensorCopy(x.Slice(static_cast(each_range.begin), + static_cast(each_range.end)), + x.place(), dev_ctx, &slice); offset += len; } } diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index c9f22237d93..48b2d2779ae 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -137,8 +137,8 @@ class SumKernel : public framework::OpKernel { out_array.resize(i + 1); } if (out_array[i].numel() == 0) { - framework::Copy(in_array[i], in_array[i].place(), - context.device_context(), &out_array[i]); + framework::TensorCopy(in_array[i], in_array[i].place(), + context.device_context(), &out_array[i]); out_array[i].set_lod(in_array[i].lod()); } else { PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod()); diff --git a/paddle/fluid/operators/tensor_array_read_write_op.cc b/paddle/fluid/operators/tensor_array_read_write_op.cc index 9b484cda121..2636812c429 100644 --- a/paddle/fluid/operators/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/tensor_array_read_write_op.cc @@ -45,7 +45,7 @@ class WriteToArrayOp : public ArrayOp { platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - Copy(x_tensor, place, dev_ctx, out_tensor); + TensorCopy(x_tensor, place, dev_ctx, out_tensor); out_tensor->set_lod(x_tensor.lod()); } else { VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so " @@ -138,7 +138,7 @@ class ReadFromArrayOp : public ArrayOp { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::Copy(x_array[offset], place, dev_ctx, out_tensor); + framework::TensorCopy(x_array[offset], place, dev_ctx, out_tensor); out_tensor->set_lod(x_array[offset].lod()); } else { VLOG(10) << "offset " << offset << " >= " << x_array.size(); diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index aefb58bdcd1..3e3e3089315 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -185,7 +185,8 @@ class WarpCTCKernel : public framework::OpKernel { // warpctc accesses labels in CPU memory Tensor warpctc_label; - Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label); + TensorCopy(*label, platform::CPUPlace(), ctx.device_context(), + &warpctc_label); const int* warpctc_label_data = warpctc_label.data(); // warpctc stores loss in CPU memory Tensor warpctc_loss; @@ -200,7 +201,7 @@ class WarpCTCKernel : public framework::OpKernel { sequence_width, num_sequences, blank, warpctc_loss_data); // Copy the loss back - Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss); + TensorCopy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss); } }; diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 7e7fb554ac1..1b0916ea037 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -101,7 +101,7 @@ T TensorGetElement(framework::Tensor &self, size_t offset) { return self.data()[offset]; } else { std::shared_ptr dst(new framework::Tensor); - framework::Copy(self, platform::CPUPlace(), dst.get()); + framework::TensorCopy(self, platform::CPUPlace(), dst.get()); return dst->data()[offset]; } } @@ -111,9 +111,9 @@ template void TensorSetElement(framework::Tensor &self, size_t offset, T elem) { if (platform::is_gpu_place(self.place())) { std::shared_ptr dst(new framework::Tensor); - framework::Copy(self, platform::CPUPlace(), dst.get()); + framework::TensorCopy(self, platform::CPUPlace(), dst.get()); dst->data()[offset] = elem; - framework::Copy(*dst.get(), self.place(), &self); + framework::TensorCopy(*dst.get(), self.place(), &self); } else if (platform::is_cpu_place(self.place())) { self.data()[offset] = elem; -- GitLab