diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index a230346a8e0056b2fdb162846b4da90d98db04c7..c72e53308e65857c6488a484d0da568dc740a5d4 100755
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -190,7 +190,7 @@ cc_test(
 cc_library(
   var_type_traits
   SRCS var_type_traits.cc
-  DEPS framework_proto scope)
+  DEPS framework_proto scope tensor_array)
 if(WITH_GPU)
   target_link_libraries(var_type_traits dynload_cuda)
 endif()
diff --git a/paddle/fluid/framework/lod_tensor_array.h b/paddle/fluid/framework/lod_tensor_array.h
index 7aa180ed75ce217dcadb747f8014d92f6e4931fe..4849cfbc6e8e98d71c3824adce30f6d5cc1c7615 100644
--- a/paddle/fluid/framework/lod_tensor_array.h
+++ b/paddle/fluid/framework/lod_tensor_array.h
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <vector>
 
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
 
 namespace paddle {
 namespace framework {
 
-using LoDTensorArray = std::vector<LoDTensor>;
+using LoDTensorArray = phi::TensorArray;
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 32933ef66170d13d8da791d924d0c0438c93997d..fe64f81ddf00148918b11df4d7d9dc94d1f91d55 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -2665,13 +2665,8 @@ void OperatorWithKernel::BuildPhiKernelContext(
           phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
         } else if (var->IsType<framework::LoDTensorArray>()) {
           need_prepare_phi_data_ = true;
-          paddle::small_vector<const phi::TensorBase*> tensor_vector;
-          auto& tensor_array = var->Get<framework::LoDTensorArray>();
-          for (auto& t : tensor_array) {
-            tensor_vector.emplace_back(&t);
-          }
-          phi_kernel_context->EmplaceBackInputsWithoutSetRange(tensor_vector);
-          end_idx += tensor_array.size() - 1;
+          tensor_in = &(var->Get<framework::LoDTensorArray>());
+          phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
         } else {
           PADDLE_THROW(platform::errors::Unimplemented(
               "Unsupported input `%s` type when call pt kernel.",
@@ -2714,16 +2709,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
           tensor_out = var->template GetMutable<phi::SelectedRows>();
           phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
         } else if (var->template IsType<framework::LoDTensorArray>()) {
-          paddle::small_vector<phi::TensorBase*> tensor_vector;
-          auto* tensor_array =
-              var->template GetMutable<framework::LoDTensorArray>();
+          tensor_out = var->template GetMutable<framework::LoDTensorArray>();
           // Note: If the input LoDTensorArray size is 0, the output
           // LoDTensorArray is also 0
-          for (auto& t : *tensor_array) {
-            tensor_vector.emplace_back(&t);
-          }
-          phi_kernel_context->EmplaceBackOutputsWithoutSetRange(tensor_vector);
-          end_idx += tensor_array->size() - 1;
+          phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
         } else {
           PADDLE_THROW(platform::errors::Unimplemented(
               "Unsupported output `%s` type when call pt kernel.",
diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc
index 5ab09d546df10fb5931867775709b84de0c93baa..499884208bebd1abe3d37afb5f3a36a8ec381084 100644
--- a/paddle/fluid/framework/reader.cc
+++ b/paddle/fluid/framework/reader.cc
@@ -19,7 +19,7 @@
 namespace paddle {
 namespace framework {
 
-void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
+void ReaderBase::ReadNext(paddle::framework::LoDTensorArray *out) {
   std::lock_guard<std::mutex> lock(mu_);
   PADDLE_ENFORCE_EQ(status_,
                     ReaderStatus::kRunning,
diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h
index d708e01803c3ce994fe2fab4ad6cc63deb0c8e59..b2c48c5877dc71192ce2494d94e963295a94ca29 100644
--- a/paddle/fluid/framework/reader.h
+++ b/paddle/fluid/framework/reader.h
@@ -48,7 +48,7 @@ class ReaderBase {
                           "and need_check_feed"));
   }
 
-  virtual void ReadNext(std::vector<LoDTensor>* out);
+  virtual void ReadNext(paddle::framework::LoDTensorArray* out);
 
   virtual void Shutdown();
 
@@ -73,7 +73,7 @@ class ReaderBase {
   virtual ~ReaderBase();
 
  protected:
-  virtual void ReadNextImpl(std::vector<LoDTensor>* out) {}
+  virtual void ReadNextImpl(paddle::framework::LoDTensorArray* out) {}
 
   virtual void ShutdownImpl() {}
 
@@ -167,7 +167,7 @@ class ReaderHolder {
 
   const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
 
-  void ReadNext(std::vector<LoDTensor>* out) {
+  void ReadNext(paddle::framework::LoDTensorArray* out) {
     PADDLE_ENFORCE_NOT_NULL(
         reader_,
         platform::errors::InvalidArgument(
diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc
index f47a36c3b41345f9d573d72f86ad39ce09a618ed..bca4f7de8ad0a0371bfd9cfec75ffbda35c39572 100644
--- a/paddle/fluid/framework/reader_test.cc
+++ b/paddle/fluid/framework/reader_test.cc
@@ -24,7 +24,7 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader {
   explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
       : DecoratedReader(reader) {}
 
-  void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
+  void ReadNextImpl(paddle::framework::LoDTensorArray *out) override {}
 };
 
 class StubRootReader : public paddle::framework::ReaderBase {
@@ -34,7 +34,7 @@ class StubRootReader : public paddle::framework::ReaderBase {
       const std::vector<paddle::framework::proto::VarType::Type> &var_types,
       const std::vector<bool> &need_check_feed)
       : paddle::framework::ReaderBase(dims, var_types, need_check_feed) {}
-  void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
+  void ReadNextImpl(paddle::framework::LoDTensorArray *out) override {}
 };
 
 TEST(READER, decorate_chain) {
diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h
index b6c78c47a287c1b6d85dda4510c6b9d207dd30f3..58cae0faead9f0b21f5c7779c9e7b3acf60a583f 100644
--- a/paddle/fluid/imperative/prepared_operator.h
+++ b/paddle/fluid/imperative/prepared_operator.h
@@ -330,13 +330,8 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
       tensor_in = &(var.template Get<phi::SelectedRows>());
       kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
     } else if (var.template IsType<framework::LoDTensorArray>()) {
-      paddle::small_vector<const phi::TensorBase*> tensor_vector;
-      auto& tensor_array = var.template Get<framework::LoDTensorArray>();
-      for (auto& t : tensor_array) {
-        tensor_vector.emplace_back(&t);
-      }
-      kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector);
-      end_idx += tensor_array.size() - 1;
+      tensor_in = &(var.template Get<framework::LoDTensorArray>());
+      kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported input `%s` type when call pt kernel.",
@@ -377,14 +372,8 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
       tensor_out = var->template GetMutable<phi::SelectedRows>();
      kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
     } else if (var->template IsType<framework::LoDTensorArray>()) {
-      paddle::small_vector<phi::TensorBase*> tensor_vector;
-      auto* tensor_array =
-          var->template GetMutable<framework::LoDTensorArray>();
-      for (auto& t : *tensor_array) {
-        tensor_vector.emplace_back(&t);
-      }
-      kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector);
-      end_idx += tensor_array->size() - 1;
+      tensor_out = var->template GetMutable<framework::LoDTensorArray>();
+      kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported output `%s` type when call pt kernel.",
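The two kernel-context files above carry the core behavioral change of this PR: a `LoDTensorArray` argument used to be flattened into one `TensorBase*` entry per element, with `end_idx` bookkeeping to keep the argument ranges aligned; now the whole array is emplaced as a single entry, which works because `phi::TensorArray` itself derives from `TensorBase`. Below is a minimal sketch of the two dispatch shapes, using stand-in types rather than Paddle's real `KernelContext`:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in types: Paddle's real TensorBase/DenseTensor/TensorArray and
// KernelContext are far richer; this only models the dispatch shape.
struct TensorBase { virtual ~TensorBase() = default; };
struct DenseTensor : TensorBase {};
struct TensorArray : TensorBase { std::vector<DenseTensor> tensors; };
struct KernelContext { std::vector<const TensorBase*> inputs; };

// Old scheme: flatten the array into N entries and patch the range end.
void EmplaceFlattened(KernelContext* ctx, const TensorArray& arr,
                      size_t* end_idx) {
  for (const auto& t : arr.tensors) ctx->inputs.push_back(&t);
  *end_idx += arr.tensors.size() - 1;  // the bookkeeping the PR deletes
}

// New scheme: one polymorphic entry, no range fix-up needed.
void EmplaceWhole(KernelContext* ctx, const TensorArray& arr) {
  ctx->inputs.push_back(&arr);
}

int main() {
  TensorArray arr;
  arr.tensors.resize(3);
  KernelContext old_ctx, new_ctx;
  size_t end_idx = 1;
  EmplaceFlattened(&old_ctx, arr, &end_idx);
  EmplaceWhole(&new_ctx, arr);
  std::cout << old_ctx.inputs.size() << " vs " << new_ctx.inputs.size()
            << "\n";  // prints: 3 vs 1
}
```

Note that the single-entry form also sidesteps the awkward `size() - 1` arithmetic for an empty array, which is what the retained "If the input LoDTensorArray size is 0" comment hints at.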
diff --git a/paddle/fluid/operators/dequeue_op.cc b/paddle/fluid/operators/dequeue_op.cc
index a0b8fed5b8a1bea4d6ae5f6889cbbf3238f8eb06..a23b408d0b798d3ddcf032cd777745adc7dc7085 100644
--- a/paddle/fluid/operators/dequeue_op.cc
+++ b/paddle/fluid/operators/dequeue_op.cc
@@ -65,7 +65,7 @@ class DequeueOp : public framework::OperatorBase {
           platform::errors::InvalidArgument(
               "Variable with name %s has not been initialized.", out_names[i]));
 
-      std::vector<LoDTensor> lod_tensor_vec;
+      paddle::framework::LoDTensorArray lod_tensor_vec;
       bool success = false;
       lod_tensor_vec = queue_holder->GetQueue()->Pop(&success);
       PADDLE_ENFORCE_EQ(lod_tensor_vec.size(),
diff --git a/paddle/fluid/operators/enqueue_op.cc b/paddle/fluid/operators/enqueue_op.cc
index be7afee223e5867e3eea0bac8f8864945c4d5c6b..b118852870ed8e083e23cd89f7ed2a9886a7d1ae 100644
--- a/paddle/fluid/operators/enqueue_op.cc
+++ b/paddle/fluid/operators/enqueue_op.cc
@@ -65,7 +65,7 @@ class EnqueueOp : public framework::OperatorBase {
     auto* queue_holder =
         queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
 
-    std::vector<LoDTensor> lod_tensor_vec;
+    paddle::framework::LoDTensorArray lod_tensor_vec;
     lod_tensor_vec.emplace_back(*in_tensor);
     queue_holder->GetQueue()->Push(lod_tensor_vec);
   }
diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc
index b9c608b62e7db55bb3b41bb6e3f12a78d58fc445..e9205e3ccb8c22131a962abc5e62d321c08e9029 100644
--- a/paddle/fluid/operators/reader/buffered_reader.cc
+++ b/paddle/fluid/operators/reader/buffered_reader.cc
@@ -502,7 +502,7 @@ void BufferedReader::StartImpl() {
   ReadTillBufferFullAsync();
 }
 
-void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
+void BufferedReader::ReadNextImpl(paddle::framework::LoDTensorArray *out) {
   if (position_.empty()) {
     out->clear();
     return;
diff --git a/paddle/fluid/operators/reader/buffered_reader.h b/paddle/fluid/operators/reader/buffered_reader.h
index 06aaf4c12057da32febfaa4b1fc8d93bca5f5c0f..e506601358e5558ad8db337cc8af338958dd152c 100644
--- a/paddle/fluid/operators/reader/buffered_reader.h
+++ b/paddle/fluid/operators/reader/buffered_reader.h
@@ -46,7 +46,7 @@ namespace operators {
 namespace reader {
 
 class BufferedReader : public framework::DecoratedReader {
-  using TensorVec = std::vector<framework::LoDTensor>;
+  using TensorVec = paddle::framework::LoDTensorArray;
   using VecFuture = std::future<TensorVec>;
 
  public:
@@ -65,7 +65,7 @@ class BufferedReader : public framework::DecoratedReader {
  protected:
   void ShutdownImpl() override;
   void StartImpl() override;
-  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
+  void ReadNextImpl(paddle::framework::LoDTensorArray* out) override;
 
  private:
   ThreadPool thread_pool_;
diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc
index 5285d14ec7d5396d05464f7ae8e62e822d4587d3..76c57956e9b5ed8616647244882d689c32fd8f0f 100644
--- a/paddle/fluid/operators/reader/create_custom_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc
@@ -156,9 +156,9 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
   }
 };
 
-void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
+void CustomReader::ReadNextImpl(paddle::framework::LoDTensorArray* out) {
   out->clear();
-  std::vector<framework::LoDTensor> underlying_outs;
+  paddle::framework::LoDTensorArray underlying_outs;
   reader_->ReadNext(&underlying_outs);
   if (underlying_outs.empty()) {
     // There is not next data.
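Because `phi::TensorArray` keeps a vector-like surface (`clear()`, `empty()`, `emplace_back()`, iterators), the reader and queue call sites above compile with only the signature swap. A hedged consumer-side sketch (the `DrainReader` helper is hypothetical; it assumes a Paddle build tree and an already-initialized holder):

```cpp
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"

// Hypothetical helper: `holder` is an already-initialized ReaderHolder.
void DrainReader(paddle::framework::ReaderHolder* holder) {
  paddle::framework::LoDTensorArray batch;  // now an alias of phi::TensorArray
  holder->ReadNext(&batch);
  while (!batch.empty()) {        // readers signal exhaustion via clear()
    for (auto& tensor : batch) {  // range-for still works after the swap
      (void)tensor;               // ... consume each DenseTensor here ...
    }
    holder->ReadNext(&batch);
  }
}
```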
diff --git a/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h b/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
index ec50a21eb44e1afde1a86ed6d8c9795865db49fa..e19d8d3219db206fdc23be0a2e01e1699802014f 100644
--- a/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
+++ b/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
@@ -34,16 +34,16 @@ class LoDTensorBlockingQueue {
 
   ~LoDTensorBlockingQueue() { VLOG(10) << "Destruct LoDTensorBlockingQueue"; }
 
-  bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
+  bool Push(const paddle::framework::LoDTensorArray& lod_tensor_vec) {
     return queue_.Send(lod_tensor_vec);
   }
 
-  bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) {
+  bool Push(paddle::framework::LoDTensorArray&& lod_tensor_vec) {
     return queue_.Send(std::move(lod_tensor_vec));
   }
 
-  std::vector<framework::LoDTensor> Pop(bool* ok = nullptr) {
-    std::vector<framework::LoDTensor> lod_tensor_vec;
+  paddle::framework::LoDTensorArray Pop(bool* ok = nullptr) {
+    paddle::framework::LoDTensorArray lod_tensor_vec;
     bool success = queue_.Receive(&lod_tensor_vec);
     if (ok != nullptr) *ok = success;
     return lod_tensor_vec;
@@ -67,7 +67,7 @@ class LoDTensorBlockingQueue {
   inline bool WaitForInited(size_t) { return true; }
 
  private:
-  BlockingQueue<std::vector<framework::LoDTensor>> queue_;
+  BlockingQueue<paddle::framework::LoDTensorArray> queue_;
 };
 
 class OrderedMultiDeviceLoDTensorBlockingQueue {
@@ -123,7 +123,7 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
     return queues_[idx];
   }
 
-  bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
+  bool Push(const paddle::framework::LoDTensorArray& lod_tensor_vec) {
     return CurQueue()->Push(lod_tensor_vec);
   }
diff --git a/paddle/fluid/operators/reader/py_reader.cc b/paddle/fluid/operators/reader/py_reader.cc
index ad79f6bbc4c4a82eff107e42fa60f139ccd5a818..89a5c256add4fa25b0fe64682aa880ea619198f3 100644
--- a/paddle/fluid/operators/reader/py_reader.cc
+++ b/paddle/fluid/operators/reader/py_reader.cc
@@ -30,7 +30,7 @@ PyReader::PyReader(
   queue_ = queue;
 }
 
-void PyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
+void PyReader::ReadNext(paddle::framework::LoDTensorArray* out) {
   bool success;
   *out = queue_->Pop(&success);
   if (!success) out->clear();
diff --git a/paddle/fluid/operators/reader/py_reader.h b/paddle/fluid/operators/reader/py_reader.h
index 3492d57804886309ebf1b63e8c161bc76d0c5abd..21a20c6ce95f5252b1ee980336d9bc6473bac1bb 100644
--- a/paddle/fluid/operators/reader/py_reader.h
+++ b/paddle/fluid/operators/reader/py_reader.h
@@ -35,7 +35,7 @@ class PyReader : public framework::FileReader {
            const std::vector<proto::VarType::Type>& var_types,
            const std::vector<bool>& need_check_feed);
 
-  void ReadNext(std::vector<framework::LoDTensor>* out) override;
+  void ReadNext(paddle::framework::LoDTensorArray* out) override;
 
   ~PyReader();
diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc
index deb0e4a49337f9f383f23f1d62690e2aa167c898..3d551412a9c130a8d9f4de014c33a48d3627bbdb 100644
--- a/paddle/fluid/operators/reader/read_op.cc
+++ b/paddle/fluid/operators/reader/read_op.cc
@@ -106,7 +106,7 @@ class ReadOp : public framework::OperatorBase {
                        scope.FindVar(Input("Reader")), "Input", "Reader", "Read")
                        .GetMutable<framework::ReaderHolder>();
     std::vector<std::string> out_arg_names = Outputs("Out");
-    std::vector<framework::LoDTensor> ins;
+    paddle::framework::LoDTensorArray ins;
 
     // For profiling
     platform::RecordEvent record_event(
diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc
index c50a80f64dd591e02a346dbae5ed5d49e7e01491..c4dbd905de8c9471c6073755e42ea1374fad171d 100644
--- a/paddle/fluid/pybind/eager_utils.cc
+++ b/paddle/fluid/pybind/eager_utils.cc
@@ -483,7 +483,10 @@ std::vector<paddle::framework::Tensor> CastPyArg2VectorOfTensorBase(
   } else if (PyObject_IsInstance(obj,
                                  reinterpret_cast<PyObject*>(
                                      g_framework_lodtensorarray_pytype))) {
-    return ::pybind11::handle(obj).cast<framework::LoDTensorArray>();
+    for (auto& tensor :
+         (::pybind11::handle(obj).cast<framework::LoDTensorArray>())) {
+      result.emplace_back(tensor);
+    }
   } else if (obj == Py_None) {
     return {};
   } else {
diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h
index 1878752f4284e8f3cb74b91af0b1804e33936fca..1f4a93dab91eb01858853fafd12ecbcf49a347f4 100644
--- a/paddle/fluid/pybind/eager_utils.h
+++ b/paddle/fluid/pybind/eager_utils.h
@@ -19,6 +19,7 @@ typedef SSIZE_T ssize_t;
 
 #include "paddle/fluid/eager/hooks.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/jit/function.h"
 #include "paddle/fluid/platform/place.h"
diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc
index 36c09f543a6c20cea94865bc01fc992d43e4184d..7a70c2356c8b5b2e13efa938a07d95b364979533 100644
--- a/paddle/fluid/pybind/reader_py.cc
+++ b/paddle/fluid/pybind/reader_py.cc
@@ -118,7 +118,7 @@ class MultiDeviceFeedReader {
  public:
   using ResultDictList =
       std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
-  using ResultList = std::vector<std::vector<framework::LoDTensor>>;
+  using ResultList = std::vector<paddle::framework::LoDTensorArray>;
 
   static constexpr bool kKeepOrder =
       std::is_same<QueueType,
                    reader::OrderedMultiDeviceLoDTensorBlockingQueue>::value;
@@ … @@ class MultiDeviceFeedReader {
   std::vector<std::future<Status>> futures_;
   std::vector<std::exception_ptr> exceptions_;
 
-  std::vector<std::vector<framework::LoDTensor>> ret_;
+  std::vector<framework::LoDTensorArray> ret_;
   bool drop_last_;
   bool pin_memory_;
 };
@@ -427,7 +427,7 @@ void BindReader(py::module *module) {
       .def(
           "push",
          [](reader::LoDTensorBlockingQueue &self,
-             const std::vector<framework::LoDTensor> &lod_tensor_vec) {
+             const paddle::framework::LoDTensorArray &lod_tensor_vec) {
            return self.Push(lod_tensor_vec);
          },
          py::call_guard<py::gil_scoped_release>())
@@ -445,7 +445,7 @@ void BindReader(py::module *module) {
      .def(
          "push",
          [](reader::OrderedMultiDeviceLoDTensorBlockingQueue &self,
-            const std::vector<framework::LoDTensor> &lod_tensor_vec) {
+            const paddle::framework::LoDTensorArray &lod_tensor_vec) {
            return self.Push(lod_tensor_vec);
          },
          py::call_guard<py::gil_scoped_release>())
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index f49045818177a06a0ba158ef316d588b0939dc39..1599dba981efb4546203f9ac2e22f92b09edee66 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -2193,15 +2193,6 @@
     func : reverse
     backward : reverse_grad
 
-- api : reverse_array
-  args : (Tensor[] x, IntArray axis)
-  output : Tensor[]{x.size()}
-  infer_meta :
-    func : ReverseArrayInferMeta
-  kernel :
-    func : reverse_array
-  backward : reverse_array_grad
-
 - api : rmsprop_
   args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, float epsilon, float decay, float momentum, bool centered)
   output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out)
diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt
index c353e21fbd8218cd97820a13c523a04ae48e4893..669ca6c63c41e4d5498a885374b9c01d473285b2 100644
--- a/paddle/phi/core/CMakeLists.txt
+++ b/paddle/phi/core/CMakeLists.txt
@@ -57,6 +57,11 @@ cc_library(
   SRCS string_tensor.cc
   DEPS convert_utils tensor_meta tensor_base)
 
+cc_library(
+  tensor_array
+  SRCS tensor_array.cc
+  DEPS dense_tensor tensor_base)
+
 cc_library(
   meta_tensor
   SRCS meta_tensor.cc
diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h
index 1cba62a86ef01d63b2fc7bee005057cdda1fae2c..28c750dd9d923879268679a81790bd2dd4877a5f 100644
--- a/paddle/phi/core/kernel_registry.h
+++ b/paddle/phi/core/kernel_registry.h
@@ -132,6 +132,11 @@ struct KernelArgsParseFunctor {
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
+      } else if (arg_type == std::type_index(typeid(const TensorArray&))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
         args_def->AppendOutput(default_key.backend(),
                                default_tensor_layout,
@@ -148,6 +153,11 @@ struct KernelArgsParseFunctor {
                                default_tensor_layout,
                                default_key.dtype(),
                                arg_type);
+      } else if (arg_type == std::type_index(typeid(TensorArray*))) {
+        args_def->AppendOutput(default_key.backend(),
+                               default_tensor_layout,
+                               default_key.dtype(),
+                               arg_type);
       } else if (arg_type == std::type_index(typeid(SparseCooTensor*))) {
         args_def->AppendOutput(default_key.backend(),
                                default_tensor_layout,
diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index df850389ff453ba324d5e3fb751533109c11df20..c87e5e2595e29627cd2f78caf9b890fcc2d8d9eb 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -30,6 +30,7 @@
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/core/string_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
 #include "paddle/phi/core/type_defs.h"
 
 namespace phi {
@@ -284,6 +285,9 @@ struct KernelImpl {
   PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(StringTensor);
   PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(StringTensor);
 
+  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(TensorArray);
+  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorArray);
+
   /* Attribute Helpers */
 
   PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool);
@@ -322,6 +326,8 @@ struct KernelImpl {
   PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(StringTensor);
   PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor);
 
+  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray);
+
   template <typename T>
   struct KernelCallHelper<TypeTag<T>> {
diff --git a/paddle/phi/core/tensor_array.cc b/paddle/phi/core/tensor_array.cc
new file mode 100644
--- /dev/null
+++ b/paddle/phi/core/tensor_array.cc
@@ -0,0 +1,… @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/tensor_array.h"
+
+namespace phi {
+
+TensorArray::TensorArray(const std::vector<DenseTensor>& vec) {
+  tensors_ = vec;
+}
+
+/// \brief Test whether the tensor's storage in TensorArray is allocated.
+/// \return Whether all tensors in TensorArray are allocated.
+bool TensorArray::initialized() const {
+  bool init = true;
+  for (auto tensor : tensors_) {
+    if (!tensor.IsInitialized()) {
+      init = false;
+    }
+  }
+  return init;
+}
+
+int64_t TensorArray::numel() const {
+  PADDLE_THROW(errors::Unavailable("numel() can't be used in TensorArray"));
+  return -1;
+}
+
+const DDim& TensorArray::dims() const {
+  PADDLE_THROW(errors::Unavailable("dims() can't be used in TensorArray"));
+  return tensors_[0].dims();
+}
+
+const Place& TensorArray::place() const {
+  PADDLE_THROW(errors::Unavailable("place() can't be used in TensorArray"));
+  return tensors_[0].place();
+}
+
+DataType TensorArray::dtype() const {
+  PADDLE_THROW(errors::Unavailable("dtype() can't be used in TensorArray"));
+  return DataType::UNDEFINED;
+}
+
+DataLayout TensorArray::layout() const {
+  PADDLE_THROW(errors::Unavailable("layout() can't be used in TensorArray"));
+  return DataLayout::UNDEFINED;
+}
+
+bool TensorArray::valid() const {
+  PADDLE_THROW(errors::Unavailable("valid() can't be used in TensorArray"));
+  return false;
+}
+
+/// \brief Allocate memory with requested size for all tensors from allocator.
+/// \return Void pointer
+void* TensorArray::AllocateFrom(Allocator* allocator,
+                                DataType dtype,
+                                size_t requested_size) {
+  for (size_t i = 0; i < tensors_.size(); i++) {
+    tensors_[i].AllocateFrom(allocator, tensors_[i].dtype(), requested_size);
+  }
+  return nullptr;
+}
+
+void TensorArray::push_back(const DenseTensor& tensor) {
+  tensors_.push_back(tensor);
+}
+
+void TensorArray::emplace_back(const DenseTensor& tensor) {
+  tensors_.emplace_back(tensor);
+}
+
+void TensorArray::emplace_back() {
+  DenseTensor t;
+  tensors_.emplace_back(t);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/core/tensor_array.h b/paddle/phi/core/tensor_array.h
new file mode 100644
index 0000000000000000000000000000000000000000..ade33099eee312c80f9adceb1c1c0ed2bede0f7f
--- /dev/null
+++ b/paddle/phi/core/tensor_array.h
@@ -0,0 +1,134 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+/// \brief The TensorArray stores a list of tensors. It is designed for
+/// compatibility with LoDTensorArray in Fluid and shouldn't be used widely
+/// in PHI. If you want to store a list of tensors in PHI, please use
+/// std::vector<DenseTensor> whenever possible.
+class TensorArray : public TensorBase,
+                    public TypeInfoTraits<TensorBase, TensorArray> {
+ public:
+  /// \brief Construct a TensorArray.
+  /// \param vec The vector of DenseTensor used to init TensorArray.
+  explicit TensorArray(const std::vector<DenseTensor>& vec);
+
+  explicit TensorArray(size_t n) {
+    for (size_t i = 0; i < n; i++) {
+      tensors_.emplace_back();
+    }
+  }
+
+  TensorArray() = default;
+
+  TensorArray(TensorArray&& other) = default;
+
+  TensorArray(const TensorArray& other) = default;
+
+  /// \brief TensorArray shallow copy assignment.
+  TensorArray& operator=(const TensorArray& other) = default;
+
+  TensorArray& operator=(TensorArray&& other) = default;
+
+  /// \brief Destroy the tensor object and release exclusive resources.
+  virtual ~TensorArray() = default;
+
+ public:
+  /// \brief Returns the name of the class for type traits.
+  /// \return The name of the class.
+  static const char* name() { return "TensorArray"; }
+
+  /// \brief This overridden function is not used in TensorArray.
+  int64_t numel() const override;
+
+  /// \brief This overridden function is not used in TensorArray.
+  const DDim& dims() const override;
+
+  /// \brief This overridden function is not used in TensorArray.
+  const Place& place() const override;
+
+  /// \brief This overridden function is not used in TensorArray.
+  DataType dtype() const override;
+
+  /// \brief This overridden function is not used in TensorArray.
+  DataLayout layout() const override;
+
+  /// \brief This overridden function is not used in TensorArray.
+  bool valid() const override;
+
+  /// \brief Test whether the tensor's storage in TensorArray is allocated.
+  /// \return Whether all tensors in TensorArray are allocated.
+  bool initialized() const override;
+
+  /// \brief Clear all tensors in TensorArray.
+  void clear() { tensors_.clear(); }
+
+  /// \brief Allocate memory with requested size for all tensors from
+  /// allocator.
+  /// \return Void pointer
+  void* AllocateFrom(Allocator* allocator,
+                     DataType dtype,
+                     size_t requested_size = 0);
+
+  bool empty() const { return tensors_.empty(); }
+
+  /// \brief Returns the number of tensors in TensorArray.
+  size_t size() const { return tensors_.size(); }
+
+  /// \brief Resizes the TensorArray so that it contains n tensors.
+  void resize(size_t n) { tensors_.resize(n); }
+
+  /// \brief Requests that the TensorArray capacity be at least enough to
+  /// contain n tensors.
+  void reserve(size_t n) { tensors_.reserve(n); }
+
+  /// \brief Add the tensor to the end of TensorArray
+  void push_back(const DenseTensor& tensor);
+
+  void emplace_back();
+
+  void emplace_back(const DenseTensor& tensor);
+
+  /// \brief Return the last tensor in TensorArray
+  DenseTensor& back() { return tensors_.back(); }
+
+  DenseTensor& at(size_t index) { return tensors_.at(index); }
+
+  const DenseTensor& at(size_t index) const { return tensors_.at(index); }
+
+  const DenseTensor& operator[](size_t index) const { return tensors_[index]; }
+
+  DenseTensor& operator[](size_t index) { return tensors_[index]; }
+
+  std::vector<DenseTensor>::iterator begin() { return tensors_.begin(); }
+
+  std::vector<DenseTensor>::const_iterator begin() const {
+    return tensors_.begin();
+  }
+
+  std::vector<DenseTensor>::iterator end() { return tensors_.end(); }
+
+  std::vector<DenseTensor>::const_iterator end() const {
+    return tensors_.end();
+  }
+
+ private:
+  std::vector<DenseTensor> tensors_;
+};
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 66867c938dd5ade7fbccafb1fa526c497d172c2a..275b9ef031bb4fed3dd7d1d50f06e53b11ff9ffa 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -25,6 +25,7 @@ set(COMMON_KERNEL_DEPS
     string_tensor
     sparse_coo_tensor
     sparse_csr_tensor
+    tensor_array
     kernel_context
     kernel_factory
     arg_map_context
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index bf030e6fb4b5fd45f6d2a6a800a3ac0c3ae1bdd1..77b9fbc0e1628ac031441dda67246e7e92948ef0 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -45,10 +45,10 @@ void AssignRawKernel(const Context& dev_ctx,
 // as input if needed
 template <typename Context>
 void AssignArrayKernel(const Context& dev_ctx,
-                       const std::vector<const DenseTensor*>& x,
-                       std::vector<DenseTensor*> out) {
+                       const TensorArray& x,
+                       TensorArray* out) {
   for (size_t i = 0; i < x.size(); ++i) {
-    AssignKernel<Context>(dev_ctx, *x[i], out.at(i));
+    AssignKernel<Context>(dev_ctx, x[i], &out->at(i));
   }
 }
diff --git a/paddle/phi/kernels/assign_kernel.h b/paddle/phi/kernels/assign_kernel.h
index 41be3e43a303d1233f2f18bbf2cd4af6f34d31be..7fa0350ad0ed6458adafb69e32ad2085b7330289 100644
--- a/paddle/phi/kernels/assign_kernel.h
+++ b/paddle/phi/kernels/assign_kernel.h
@@ -18,6 +18,7 @@
 
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
 #include "paddle/phi/infermeta/unary.h"
 
 namespace phi {
@@ -47,8 +48,8 @@ void AssignRawKernel(const Context& dev_ctx,
 
 template <typename Context>
 void AssignArrayKernel(const Context& dev_ctx,
-                       const std::vector<const DenseTensor*>& x,
-                       std::vector<DenseTensor*> out);
+                       const TensorArray& x,
+                       TensorArray* out);
 
 template <typename T, typename Context>
 void AssignValueKernel(const Context& dev_ctx,
diff --git a/paddle/phi/kernels/funcs/strided_slice.h b/paddle/phi/kernels/funcs/strided_slice.h
index 4d045bdeb596c8f55e9ca5c90000f66d7116b565..4a88c1e0660b79872668bc96e98b0f4808d2775c 100644
--- a/paddle/phi/kernels/funcs/strided_slice.h
+++ b/paddle/phi/kernels/funcs/strided_slice.h
@@ -20,6 +20,7 @@
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/tensor_array.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -297,14 +298,14 @@ void StridedSliceCompute(const Context& dev_ctx,
 
 template <typename Context, typename T>
 void StridedSliceCompute(const Context& dev_ctx,
-                         const std::vector<const DenseTensor*>& x,
+                         const TensorArray& x,
                          const std::vector<int>& axes,
                          const IntArray& starts,
                          const IntArray& ends,
                          const IntArray& strides,
                          const std::vector<int>& infer_flags,
                          const std::vector<int>& decrease_axis,
-                         std::vector<DenseTensor*> out) {
+                         TensorArray* out) {
   const int64_t size = x.size();
   auto in_dims = phi::make_ddim({size});
@@ -419,29 +420,29 @@ void StridedSliceCompute(const Context& dev_ctx,
           "dimension of Output should be 1, but received %d",
           out_dims_origin.size()));
 
-  out.resize(out_dims_origin[0]);
+  out->resize(out_dims_origin[0]);
   size_t const in_array_size = x.size();
-  for (size_t i = 0; i < out.size(); i++) {
+  for (size_t i = 0; i < out->size(); i++) {
     size_t in_offset =
         (starts_indices[0] % in_array_size) + i * strides_indices[0];
 
     int64_t out_offset = i;
     if (need_reverse) {
-      out_offset = out.size() - i - 1;
+      out_offset = out->size() - i - 1;
     }
 
-    auto* in_tensor = x.at(in_offset);
+    auto& in_tensor = x.at(in_offset);
     PADDLE_ENFORCE_GT(
-        in_tensor->memory_size(),
+        in_tensor.memory_size(),
         0,
         errors::PreconditionNotMet(
            "The input LoDTensorArray Input[%d] holds no memory.", in_offset));
-    auto* out_tensor = out.at(out_offset);
-    out_tensor->Resize(in_tensor->dims());
+    auto& out_tensor = out->at(out_offset);
+    out_tensor.Resize(in_tensor.dims());
     phi::Copy<Context>(
-        dev_ctx, *in_tensor, dev_ctx.GetPlace(), false, out_tensor);
-    out_tensor->set_lod(in_tensor->lod());
+        dev_ctx, in_tensor, dev_ctx.GetPlace(), false, &out_tensor);
+    out_tensor.set_lod(in_tensor.lod());
   }
 }
 
@@ -531,15 +532,15 @@ void StridedSliceGradCompute(const Context& dev_ctx,
 
 template <typename Context, typename T>
 void StridedSliceGradCompute(const Context& dev_ctx,
-                             const std::vector<const DenseTensor*>& x,
-                             const std::vector<const DenseTensor*>& out_grad,
+                             const TensorArray& x,
+                             const TensorArray& out_grad,
                              const std::vector<int>& axes,
                              const IntArray& starts,
                              const IntArray& ends,
                              const IntArray& strides,
                              const std::vector<int>& infer_flags,
                              const std::vector<int>& decrease_axis,
-                             std::vector<DenseTensor*> x_grad) {
+                             TensorArray* x_grad) {
   // Note(weixin):Since the shape of `framework::GradVarName("Input")` of
   // StridedSliceGrad cannot be calculated by
   // `framework::GradVarName("Output")`, the dim of "Input" is used to
@@ -619,11 +620,11 @@ void StridedSliceGradCompute(const Context& dev_ctx,
           "the dimension of output should be 1, but received %d.",
           out_dims.size()));
 
-  auto const d_out_array_size = x_grad.size();
+  auto const d_out_array_size = x_grad->size();
 
   for (size_t j = 0; j < d_out_array_size; j++) {
-    auto& dim = x.at(j)->dims();
-    auto* d_out_tensor = x_grad.at(j);
+    auto& dim = x.at(j).dims();
+    auto& d_out_tensor = x_grad->at(j);
 
     int64_t sub = j - starts_indices[0];
 
@@ -635,26 +636,26 @@ void StridedSliceGradCompute(const Context& dev_ctx,
 
     if ((sub % strides_indices[0] == 0) && (0 <= in_offset) &&
         (static_cast<size_t>(in_offset) < out_grad.size())) {
-      auto* in_tensor = out_grad.at(in_offset);
+      auto& in_tensor = out_grad.at(in_offset);
       PADDLE_ENFORCE_GT(
-          in_tensor->memory_size(),
+          in_tensor.memory_size(),
          0,
          errors::PreconditionNotMet(
              "The input LoDTensorArray Input[%d] holds no memory.",
              in_offset));
Input[%d] holds no memory.", in_offset)); phi::Copy( - dev_ctx, *in_tensor, dev_ctx.GetPlace(), false, d_out_tensor); - d_out_tensor->set_lod(in_tensor->lod()); + dev_ctx, in_tensor, dev_ctx.GetPlace(), false, &d_out_tensor); + d_out_tensor.set_lod(in_tensor.lod()); } else { - d_out_tensor->Resize(dim); + d_out_tensor.Resize(dim); - if (!d_out_tensor->IsInitialized()) { - dev_ctx.template Alloc(d_out_tensor); + if (!d_out_tensor.IsInitialized()) { + dev_ctx.template Alloc(&d_out_tensor); } phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, d_out_tensor, static_cast(0)); + set_zero(dev_ctx, &d_out_tensor, static_cast(0)); } } } diff --git a/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h b/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h index f8b604ef1179bc39a70e914003a0693a46c0cc4b..3824e301e2ec2b84b7a3f85f84f052aea8f0dbbc 100644 --- a/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h @@ -56,17 +56,16 @@ void StridedSliceRawGradKernel(const Context& dev_ctx, } template -void StridedSliceArrayGradKernel( - const Context& dev_ctx, - const std::vector& x, - const std::vector& out_grad, - const std::vector& axes, - const IntArray& starts, - const IntArray& ends, - const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, - std::vector x_grad) { +void StridedSliceArrayGradKernel(const Context& dev_ctx, + const TensorArray& x, + const TensorArray& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + TensorArray* x_grad) { funcs::StridedSliceGradCompute(dev_ctx, x, out_grad, diff --git a/paddle/phi/kernels/impl/strided_slice_kernel_impl.h b/paddle/phi/kernels/impl/strided_slice_kernel_impl.h index 5d6c3d8992cb40a28cc74e8a5c15dc9468598b50..f8dc298f47e60e27aeea61d98189b1b0fb337377 100644 --- a/paddle/phi/kernels/impl/strided_slice_kernel_impl.h +++ b/paddle/phi/kernels/impl/strided_slice_kernel_impl.h @@ -55,14 +55,14 @@ void StridedSliceRawKernel(const Context& dev_ctx, template void StridedSliceArrayKernel(const Context& dev_ctx, - const std::vector& x, + const TensorArray& x, const std::vector& axes, const IntArray& starts, const IntArray& ends, const IntArray& strides, const std::vector& infer_flags, const std::vector& decrease_axis, - std::vector out) { + TensorArray* out) { funcs::StridedSliceCompute( dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out); } diff --git a/paddle/phi/kernels/memcpy_kernel.cc b/paddle/phi/kernels/memcpy_kernel.cc index e6307b66d4b0f7d35ae8401836fa7ebcfbef94d1..acc87dc9960d1c78962b0704ba1f34be4e56f49c 100644 --- a/paddle/phi/kernels/memcpy_kernel.cc +++ b/paddle/phi/kernels/memcpy_kernel.cc @@ -110,25 +110,21 @@ void MemcpyD2HKernel(const Context& dev_ctx, template void MemcpyD2HMultiIOKernel(const Context& dev_ctx, - const std::vector& array, + const TensorArray& array, int dst_place_type, - std::vector out_array) { + TensorArray* out_array) { + PADDLE_ENFORCE_NOT_NULL( + out_array, + errors::PreconditionNotMet("output tesnor_array should not be nullptr")); PADDLE_ENFORCE_EQ( array.size(), - out_array.size(), + out_array->size(), errors::PreconditionNotMet( - "input size %d != output size %d", array.size(), out_array.size())); + "input size %d != output size %d", array.size(), out_array->size())); for (size_t i = 0; i < array.size(); i++) { - PADDLE_ENFORCE_NOT_NULL( - array[i], - 
errors::PreconditionNotMet("input tesnor %d should not be nullptr", i)); - PADDLE_ENFORCE_NOT_NULL(out_array[i], - errors::PreconditionNotMet( - "output tesnor %d should not be nullptr", i)); - - const auto& x = *(array[i]); - MemcpyD2HKernel(dev_ctx, x, dst_place_type, out_array[i]); + const auto& x = array[i]; + MemcpyD2HKernel(dev_ctx, x, dst_place_type, &(out_array->at(i))); } } diff --git a/paddle/phi/kernels/memcpy_kernel.h b/paddle/phi/kernels/memcpy_kernel.h index d63881a723ebb851233e2ca8e03931d09879b0f6..72a58982b05c373a487257c09063d64840a598ac 100644 --- a/paddle/phi/kernels/memcpy_kernel.h +++ b/paddle/phi/kernels/memcpy_kernel.h @@ -17,6 +17,7 @@ #include #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/tensor_array.h" namespace phi { @@ -36,9 +37,9 @@ void MemcpyD2HKernel(const Context& dev_ctx, template void MemcpyD2HMultiIOKernel(const Context& dev_ctx, - const std::vector& array, + const TensorArray& array, int dst_place_type, - std::vector out_array); + TensorArray* out_array); template void MemcpyKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/reverse_kernel.cc b/paddle/phi/kernels/reverse_kernel.cc index b42923ac5dde47e036bab227f033988df1229521..b2fe61ad41fc6db84fdf755d7d57a1f4f76cf5c7 100644 --- a/paddle/phi/kernels/reverse_kernel.cc +++ b/paddle/phi/kernels/reverse_kernel.cc @@ -22,29 +22,29 @@ namespace phi { template void ReverseArrayKernel(const Context& dev_ctx, - const std::vector& x, + const TensorArray& x, const IntArray& axis, - std::vector out) { + TensorArray* out) { PADDLE_ENFORCE_EQ( x.size(), - out.size(), + out->size(), phi::errors::InvalidArgument("The input size(%d) and output size(%d) of " "ReverseArrayKernel is different.", x.size(), - out.size())); + out->size())); for (size_t offset = 0; offset < x.size(); ++offset) { - auto* x_tensor = x.at(offset); + auto& x_tensor = x.at(offset); PADDLE_ENFORCE_GT( - x_tensor->memory_size(), + x_tensor.memory_size(), 0, phi::errors::PreconditionNotMet( "The input LoDTensorArray X[%d] holds no memory.", offset)); auto out_offset = x.size() - offset - 1; - auto* out_tensor = out.at(out_offset); + auto& out_tensor = out->at(out_offset); - out_tensor->set_lod(x_tensor->lod()); + out_tensor.set_lod(x_tensor.lod()); phi::Copy( - dev_ctx, *x_tensor, dev_ctx.GetPlace(), false, out_tensor); + dev_ctx, x_tensor, dev_ctx.GetPlace(), false, &out_tensor); } } @@ -60,7 +60,9 @@ PD_REGISTER_KERNEL(reverse_array, bool, float, double) {} + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + PD_REGISTER_KERNEL(reverse_array, GPU, ALL_LAYOUT, @@ -71,4 +73,5 @@ PD_REGISTER_KERNEL(reverse_array, bool, float, double) {} + #endif diff --git a/paddle/phi/kernels/reverse_kernel.h b/paddle/phi/kernels/reverse_kernel.h index 1ccfa344d5c92733c883883dbcc701547ff9bdf4..9e4d5fc512d07ca15bc458b0461b5a67ad9a4006 100644 --- a/paddle/phi/kernels/reverse_kernel.h +++ b/paddle/phi/kernels/reverse_kernel.h @@ -18,6 +18,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/tensor_array.h" namespace phi { @@ -29,8 +30,8 @@ void ReverseKernel(const Context& dev_ctx, template void ReverseArrayKernel(const Context& dev_ctx, - const std::vector& x, + const TensorArray& x, const IntArray& axis, - std::vector out); + TensorArray* out); } // namespace phi diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.h b/paddle/phi/kernels/strided_slice_grad_kernel.h index 21d01310b662f4c919d7008f2ddc7f6d5ea836ff..8dfd3fd5bcc07dd927b43562f9441bd8c96f7bf4 100644 
--- a/paddle/phi/kernels/strided_slice_grad_kernel.h
+++ b/paddle/phi/kernels/strided_slice_grad_kernel.h
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
 
 namespace phi {
 
@@ -42,15 +43,14 @@ void StridedSliceGradKernel(const Context& dev_ctx,
                             DenseTensor* x_grad);
 
 template <typename T, typename Context>
-void StridedSliceArrayGradKernel(
-    const Context& dev_ctx,
-    const std::vector<const DenseTensor*>& x,
-    const std::vector<const DenseTensor*>& out_grad,
-    const std::vector<int>& axes,
-    const IntArray& starts,
-    const IntArray& ends,
-    const IntArray& strides,
-    const std::vector<int>& infer_flags,
-    const std::vector<int>& decrease_axis,
-    std::vector<DenseTensor*> x_grad);
+void StridedSliceArrayGradKernel(const Context& dev_ctx,
+                                 const TensorArray& x,
+                                 const TensorArray& out_grad,
+                                 const std::vector<int>& axes,
+                                 const IntArray& starts,
+                                 const IntArray& ends,
+                                 const IntArray& strides,
+                                 const std::vector<int>& infer_flags,
+                                 const std::vector<int>& decrease_axis,
+                                 TensorArray* x_grad);
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/strided_slice_kernel.h b/paddle/phi/kernels/strided_slice_kernel.h
index 2c8b373bf03a85a73cb4341756ea1f4e51033e65..35ffbeebd4a9a887688737efc6c3671b0a2fec25 100644
--- a/paddle/phi/kernels/strided_slice_kernel.h
+++ b/paddle/phi/kernels/strided_slice_kernel.h
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
 
 namespace phi {
 
@@ -41,12 +42,12 @@ void StridedSliceKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void StridedSliceArrayKernel(const Context& dev_ctx,
-                             const std::vector<const DenseTensor*>& x,
+                             const TensorArray& x,
                              const std::vector<int>& axes,
                              const IntArray& starts,
                              const IntArray& ends,
                              const IntArray& strides,
                              const std::vector<int>& infer_flags,
                              const std::vector<int>& decrease_axis,
-                             std::vector<DenseTensor*> out);
+                             TensorArray* out);
 
 }  // namespace phi
diff --git a/paddle/phi/tests/core/CMakeLists.txt b/paddle/phi/tests/core/CMakeLists.txt
index 3d549aa5f160cd54af92b5fa4b5bf801a9760abb..4a0c99f9878127b861a8d7bd4d70980b45320b9a 100644
--- a/paddle/phi/tests/core/CMakeLists.txt
+++ b/paddle/phi/tests/core/CMakeLists.txt
@@ -60,3 +60,8 @@ cc_test(
   SRCS test_string_tensor.cc
   DEPS string_tensor)
 cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
+
+cc_test(
+  test_tensor_array
+  SRCS test_tensor_array.cc
+  DEPS tensor_array)
diff --git a/paddle/phi/tests/core/test_tensor_array.cc b/paddle/phi/tests/core/test_tensor_array.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4a29629cc2dc35f1c2c5299ca23e3f2423e2646b
--- /dev/null
+++ b/paddle/phi/tests/core/test_tensor_array.cc
@@ -0,0 +1,122 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
+#include "paddle/phi/core/tensor_array.h"
+#include "paddle/phi/tests/core/allocator.h"
+
+namespace phi {
+namespace tests {
+
+using pstring = ::phi::dtype::pstring;
+
+TEST(tensor_array, tensor_array_not_init) {
+  const DDim dims({1, 2});
+  const DataType dtype{DataType::INT8};
+  const DataLayout layout{DataLayout::NHWC};
+  const LoD lod{};
+  DenseTensorMeta meta(dtype, dims, layout, lod);
+  DenseTensor tensor_0;
+  tensor_0.set_meta(meta);
+
+  std::vector<DenseTensor> tensors;
+  tensors.push_back(tensor_0);
+  tensors.push_back(tensor_0);
+  tensors.push_back(tensor_0);
+
+  TensorArray tensor_array(tensors);
+
+  try {
+    tensor_array.dims();
+  } catch (const phi::enforce::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    EXPECT_TRUE(ex_msg.find("dims") != std::string::npos);
+  }
+
+  try {
+    tensor_array.place();
+  } catch (const phi::enforce::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    EXPECT_TRUE(ex_msg.find("place") != std::string::npos);
+  }
+
+  try {
+    tensor_array.dtype();
+  } catch (const phi::enforce::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    EXPECT_TRUE(ex_msg.find("dtype") != std::string::npos);
+  }
+
+  try {
+    tensor_array.layout();
+  } catch (const phi::enforce::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    EXPECT_TRUE(ex_msg.find("layout") != std::string::npos);
+  }
+
+  try {
+    tensor_array.numel();
+  } catch (const phi::enforce::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    EXPECT_TRUE(ex_msg.find("numel") != std::string::npos);
+  }
+
+  try {
+    tensor_array.valid();
+  } catch (const phi::enforce::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    EXPECT_TRUE(ex_msg.find("valid") != std::string::npos);
+  }
+
+  CHECK_EQ(tensor_array.initialized(), false);
+}
+
+TEST(tensor_array, tensor_array_init) {
+  const DDim dims1({1, 2});
+  const DDim dims2({1, 2, 3});
+  const DataType dtype{DataType::INT8};
+  const DataLayout layout{DataLayout::NHWC};
+  const LoD lod{};
+
+  DenseTensorMeta meta1(dtype, dims1, layout, lod);
+  DenseTensorMeta meta2(dtype, dims2, layout, lod);
+
+  auto fancy_allocator = std::unique_ptr<Allocator>(new FancyAllocator);
+  auto* alloc = fancy_allocator.get();
+  DenseTensor tensor_0;
+  tensor_0.set_meta(meta1);
+
+  DenseTensor tensor_1;
+  tensor_1.set_meta(meta2);
+
+  std::vector<DenseTensor> tensors;
+  tensors.push_back(tensor_0);
+  tensors.push_back(tensor_1);
+  tensors.push_back(tensor_0);
+
+  TensorArray tensor_array(tensors);
+  tensor_array.AllocateFrom(alloc, DataType::INT8);
+
+  CHECK_EQ(tensor_array.initialized(), true);
+}
+
+}  // namespace tests
+}  // namespace phi
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index a33cfa055c5b935687ae3ea3dadddb33025cd92f..3df027931ccc58792eebbc37677fe98c3b016760 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -1284,10 +1284,7 @@ def reverse(x, axis):
     if isinstance(axis, int):
         axis = [axis]
     if in_dygraph_mode():
-        if x.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
-            return _C_ops.reverse_array(x, axis)
-        else:
-            return _C_ops.reverse(x, axis)
+        return _C_ops.reverse(x, axis)
     helper = LayerHelper("reverse", **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(type='reverse',
diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py
index f090cf1c8de11e48b57a5c1e3efa8dd51326c77b..7f09d9b70631dcbbbb541f7c6eb379f868678d4e 100644
--- a/python/paddle/fluid/tests/unittests/test_reverse_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py
@@ -267,68 +267,6 @@ class TestReverseAxisListTensor(TestReverseAxisTensor):
         return out
 
 
-class TestAReverseEagerAPI(UnittestBase):
-
-    def test_api(self):
-        paddle.disable_static()
-        x = paddle.randn([4, 10])
-        y = paddle.randn([4, 10])
-
-        out = paddle._C_ops.reverse_array([x, y], [0])
-        np.testing.assert_allclose(x.numpy(), out[1].numpy())
-        np.testing.assert_allclose(y.numpy(), out[0].numpy())
-
-        paddle.enable_static()
-
-
-class TestReverseTensorArrayAxisTensor(UnittestBase):
-
-    def init_info(self):
-        self.shapes = [[2, 3, 4]]
-        self.save_path = os.path.join(self.temp_dir.name,
-                                      'reverse_tensor_array')
-
-    def test_static(self):
-        main_prog = Program()
-        starup_prog = Program()
-        with program_guard(main_prog, starup_prog):
-            fc = paddle.nn.Linear(4, 2)
-            x = paddle.randn([2, 3, 4])
-            x.stop_gradient = False
-            feat = fc(x)  # [2,3,10]
-            # tensor_array.shape: [[2,3,10], [2,3,10]]
-            tensor_array = paddle.fluid.layers.create_array(dtype='float32')
-            idx0 = paddle.full(shape=[1], fill_value=0, dtype="int64")
-            val0 = paddle.randn([2, 3, 2])
-            paddle.fluid.layers.array_write(val0, idx0, tensor_array)
-            idx1 = paddle.full(shape=[1], fill_value=1, dtype="int64")
-            paddle.fluid.layers.array_write(feat, idx1, tensor_array)
-            # axes is a Variable
-            axes = paddle.assign([0])
-            # tensor_array.shape: [[2,3,10], [2,3,10]]
-            reverse_array = paddle.fluid.layers.reverse(tensor_array, axes)
-
-            out, _ = paddle.fluid.layers.tensor_array_to_tensor(reverse_array,
-                                                                axis=0)
-
-            sgd = paddle.optimizer.SGD()
-            sgd.minimize(paddle.mean(out))
-            self.assertTrue("Var[" in str(main_prog))
-
-            exe = paddle.static.Executor()
-            exe.run(starup_prog)
-            res = exe.run(fetch_list=[val0, feat, out])
-            np.testing.assert_allclose(res[1], res[-1][0:2])
-            np.testing.assert_allclose(res[0], res[-1][2:4])
-
-            paddle.static.save_inference_model(self.save_path, [x],
-                                               [val0, feat, out], exe)
-            # Test for Inference Predictor
-            infer_outs = self.infer_prog()
-            np.testing.assert_allclose(infer_outs[1], infer_outs[-1][0:2])
-            np.testing.assert_allclose(infer_outs[0], infer_outs[-1][2:4])
-
-
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
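Taken together, the new class is usable wherever the old `std::vector<LoDTensor>` alias appeared, while also flowing through the PHI kernel machinery as a `TensorBase`. A short usage sketch along the lines of the unit test above (it assumes the Paddle source tree; the meta values are arbitrary):

```cpp
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_array.h"

phi::TensorArray MakeArray() {
  phi::DenseTensor t;
  t.set_meta(phi::DenseTensorMeta(
      phi::DataType::INT8, phi::make_ddim({1, 2}), phi::DataLayout::NHWC));

  phi::TensorArray array;  // default-constructed, holds no tensors
  array.push_back(t);      // vector-like growth, as with the old alias
  array.emplace_back();    // appends a default-constructed DenseTensor
  array.resize(4);         // pads with empty tensors
  // Storage is still unallocated here, so array.initialized() stays false
  // until AllocateFrom() (or a kernel) allocates every element.
  return array;
}
```

Note the design choice visible in the header: `numel()`, `dims()`, `place()`, `dtype()`, `layout()`, and `valid()` are implemented only to satisfy the `TensorBase` interface and deliberately throw, since no single answer is meaningful for a heterogeneous list of tensors.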