From eacfc1ebbdb2f7065b96280a97c097395c332bff Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 8 Feb 2022 10:21:20 +0800 Subject: [PATCH] INFRT/Add infershape schedule (3rd PR) (#39290) --- paddle/infrt/CMakeLists.txt | 3 +- paddle/infrt/host_context/kernel_frame.h | 7 +- paddle/infrt/host_context/value.cc | 2 + paddle/infrt/host_context/value.h | 9 +++ paddle/infrt/naive/CMakeLists.txt | 7 ++ .../infrt/naive/infershaped/elementwise_add.h | 70 +++++++++++++++++++ .../infershaped/infershape_launchers_test.cc | 46 ++++++++++++ .../infershaped_kernel_launcher.cc | 61 ++++++++++++++++ .../infershaped/infershaped_kernel_launcher.h | 56 +++++++++++++++ .../infershaped_kernel_launchers.cc | 28 ++++++++ .../infershaped_kernel_launchers.h | 25 +++++++ .../naive/infershaped/infershaped_registry.cc | 55 +++++++++++++++ .../naive/infershaped/infershaped_registry.h | 56 +++++++++++++++ paddle/infrt/naive/meta_tensor.cc | 31 ++++++++ paddle/infrt/naive/meta_tensor.h | 47 +++++++++++++ paddle/infrt/tensor/dense_host_tensor.cc | 7 ++ paddle/infrt/tensor/dense_host_tensor.h | 9 ++- paddle/infrt/tensor/tensor_shape.cc | 2 + paddle/infrt/tensor/tensor_shape.h | 5 ++ 19 files changed, 522 insertions(+), 4 deletions(-) create mode 100644 paddle/infrt/naive/CMakeLists.txt create mode 100644 paddle/infrt/naive/infershaped/elementwise_add.h create mode 100644 paddle/infrt/naive/infershaped/infershape_launchers_test.cc create mode 100644 paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc create mode 100644 paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h create mode 100644 paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc create mode 100644 paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h create mode 100644 paddle/infrt/naive/infershaped/infershaped_registry.cc create mode 100644 paddle/infrt/naive/infershaped/infershaped_registry.h create mode 100644 paddle/infrt/naive/meta_tensor.cc create mode 100644 paddle/infrt/naive/meta_tensor.h diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index f17ec328f0..5337c423b1 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -76,6 +76,7 @@ add_subdirectory(tensor) add_subdirectory(support) add_subdirectory(external_kernels) add_subdirectory(paddle) +add_subdirectory(naive) add_subdirectory(tests) @@ -93,7 +94,7 @@ set(infrt_mlir_incs ) message(STATUS "infrt srcs:\n${infrt_src}") -cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto) +cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto infrt_naive) cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto) add_dependencies(infrt ${infrt_mlir_incs}) diff --git a/paddle/infrt/host_context/kernel_frame.h b/paddle/infrt/host_context/kernel_frame.h index 298c40322b..35527872e6 100644 --- a/paddle/infrt/host_context/kernel_frame.h +++ b/paddle/infrt/host_context/kernel_frame.h @@ -31,7 +31,7 @@ namespace host_context { class KernelFrame { public: int GetNumArgs() const { return num_arguments_; } - int GetNumResults() const { return num_results_; } + int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; } int GetNumAttributes() const { return value_or_attrs_.size() - num_arguments_ - (num_results_ == -1 ? 0 : num_results_); @@ -45,6 +45,9 @@ class KernelFrame { return value_or_attrs_[index]->template get_or_default(); } + // Get number of elements, either input, attributes or results. + size_t GetNumElements() const { return value_or_attrs_.size(); } + template T& GetArgAt(int index) { CHECK_LT(index, GetNumArgs()); @@ -118,6 +121,8 @@ class KernelFrame { return llvm::makeMutableArrayRef(&value_or_attrs_[from], length); } + bool IsEmpty() const { return value_or_attrs_.empty(); } + protected: int num_arguments_{}; int num_results_{-1}; diff --git a/paddle/infrt/host_context/value.cc b/paddle/infrt/host_context/value.cc index 8c3ccba3d0..1c5a577092 100644 --- a/paddle/infrt/host_context/value.cc +++ b/paddle/infrt/host_context/value.cc @@ -24,6 +24,8 @@ ValueRef::ValueRef(int64_t val) : Shared(new Value(val)) {} ValueRef::ValueRef(float val) : Shared(new Value(val)) {} ValueRef::ValueRef(double val) : Shared(new Value(val)) {} ValueRef::ValueRef(bool val) : Shared(new Value(val)) {} +ValueRef::ValueRef(naive::MetaTensor&& val) + : Shared(new Value(std::move(val))) {} const char* Value::type_info() const { return __type_info__; } diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h index 000ce95b82..904e51f928 100644 --- a/paddle/infrt/host_context/value.h +++ b/paddle/infrt/host_context/value.h @@ -23,6 +23,7 @@ #include "paddle/infrt/common/object.h" #include "paddle/infrt/common/shared.h" #include "paddle/infrt/host_context/function.h" +#include "paddle/infrt/naive/meta_tensor.h" #include "paddle/infrt/support/variant.h" #include "paddle/infrt/tensor/dense_host_tensor.h" #include "paddle/infrt/tensor/dense_tensor_view.h" @@ -50,6 +51,7 @@ using ValueVariantType = Variant, std::vector, std::vector, @@ -82,6 +84,7 @@ class Value : public common::Object { explicit Value(tensor::TensorShape&& x) : data(std::move(x)) {} explicit Value(tensor::DenseHostTensor&& x) : data(std::move(x)) {} explicit Value(MlirFunctionExecutable* x) : data(x) {} + explicit Value(naive::MetaTensor&& x) : data(std::move(x)) {} template const T& get() const { @@ -113,6 +116,11 @@ class Value : public common::Object { bool valid() const { return true; } + template + bool is_type() const { + return data.template is(); + } + const char* type_info() const override; friend void CopyTo(const Value& from, Value* to); @@ -134,6 +142,7 @@ class ValueRef : common::Shared { explicit ValueRef(float val); explicit ValueRef(double val); explicit ValueRef(bool val); + explicit ValueRef(naive::MetaTensor&& val); using common::Shared::get; using common::Shared::Reset; diff --git a/paddle/infrt/naive/CMakeLists.txt b/paddle/infrt/naive/CMakeLists.txt new file mode 100644 index 0000000000..edb7b8a912 --- /dev/null +++ b/paddle/infrt/naive/CMakeLists.txt @@ -0,0 +1,7 @@ +cc_library(infrt_naive SRCS meta_tensor.cc + infershaped/infershaped_kernel_launcher.cc + infershaped/infershaped_registry.cc + infershaped/infershaped_kernel_launchers.cc + ) + +cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt) diff --git a/paddle/infrt/naive/infershaped/elementwise_add.h b/paddle/infrt/naive/infershaped/elementwise_add.h new file mode 100644 index 0000000000..c79929822b --- /dev/null +++ b/paddle/infrt/naive/infershaped/elementwise_add.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include + +#include "paddle/infrt/host_context/kernel_utils.h" +#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h" + +// This file contains a example of the infershape ElementwiseAdd kernel. +// Some of the following code should be generated from PTEN by script. + +namespace infrt { +namespace naive { + +static void ElementwiseAddInferShape(const MetaTensor& a, + const MetaTensor& b, + MetaTensor* c) { + CHECK(a.shape() == b.shape()) + << "ElementwiseAdd, but shapes of a b are not match"; + *c->mutable_shape() = a.shape(); +} + +static void ElementwiseAdd(const tensor::DenseHostTensor& a, + const tensor::DenseHostTensor& b, + tensor::DenseHostTensor* c) {} + +// TODO(zhiqiang) This class should be generated by a script offline. +class ElementwiseAddLauncher : public InferShapedKernelLauncher { + public: + static const uint16_t input_tensor_indices[2]; + static const uint16_t num_input_tensors{2}; + static const bool turn_on_infer_shape_cache{true}; + + void Invoke(host_context::KernelFrame* frame) override { + // Build the infershape KernelFrame if needed. + // TODO(Superjomn) add unlikely here. + if (infershape_kernel_frame_builder.IsEmpty()) { + CreateKernelFrameForInferShape(frame); + } + if (turn_on_infer_shape_cache) { + if (IsShapeChanged(input_tensor_indices, num_input_tensors)) { + INFRT_KERNEL(ElementwiseAddInferShape) + (&infershape_kernel_frame_builder); + BuildInferShapeCache(input_tensor_indices, num_input_tensors); + } + } else { + INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder); + BuildInferShapeCache(input_tensor_indices, num_input_tensors); + } + + INFRT_KERNEL(ElementwiseAdd)(frame); + } +}; + +const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1}; + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershape_launchers_test.cc b/paddle/infrt/naive/infershaped/infershape_launchers_test.cc new file mode 100644 index 0000000000..317323d7c5 --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershape_launchers_test.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h" +#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h" +#include "paddle/infrt/naive/infershaped/infershaped_registry.h" +#include "paddle/infrt/tensor/dense_host_tensor.h" + +namespace infrt { +namespace naive { + +TEST(ElementwiseAdd, registry) { + InferShapedKernelRegistry registry; + RegisterInferShapeLaunchers(®istry); + ASSERT_EQ(registry.size(), 1UL); + auto creator = registry.GetKernel("elementwise_add"); + auto infershape_launcher_handle = creator(); + // fake some tensors + + tensor::DenseHostTensor a({2, 8}, GetDType()); + tensor::DenseHostTensor b({2, 8}, GetDType()); + tensor::DenseHostTensor c({2, 8}, GetDType()); + + host_context::KernelFrameBuilder kernel_frame_builder; + kernel_frame_builder.AddArgument(new host_context::Value(std::move(a))); + kernel_frame_builder.AddArgument(new host_context::Value(std::move(b))); + kernel_frame_builder.SetResults({new host_context::Value(std::move(c))}); + + infershape_launcher_handle->Invoke(&kernel_frame_builder); +} + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc new file mode 100644 index 0000000000..9ef9d9f2b7 --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h" + +namespace infrt { +namespace naive { + +void InferShapedKernelLauncher::CreateKernelFrameForInferShape( + host_context::KernelFrame* frame) { + for (host_context::Value* value : + frame->GetValues(0, frame->GetNumElements())) { + // TODO(Superjomn) To extend this. + if (value->is_type()) { + values.emplace_back(MetaTensor{&value->get()}); + infershape_kernel_frame_builder.AddArgument(values.back().get()); + } else { + infershape_kernel_frame_builder.AddArgument(value); + } + } +} + +void InferShapedKernelLauncher::BuildInferShapeCache( + const uint16_t* input_indices, const uint16_t num_inputs) { + tensor_shape_cache.resize(num_inputs); + for (uint16_t i = 0; i < num_inputs; i++) { + tensor_shape_cache[i] = + infershape_kernel_frame_builder.GetArgAt(input_indices[i]) + ->get() + .shape(); + } +} + +bool InferShapedKernelLauncher::IsShapeChanged( + const uint16_t* input_indices, const uint16_t num_inputs) const { + if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty()) + return true; + + bool changed = false; + for (uint16_t i = 0; i < num_inputs && !changed; i++) { + changed = changed || (tensor_shape_cache[i] != + infershape_kernel_frame_builder + .GetArgAt(input_indices[i]) + .shape()); + } + return changed; +} + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h new file mode 100644 index 0000000000..14c4beaf93 --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/infrt/host_context/kernel_frame.h" +#include "paddle/infrt/host_context/value.h" +#include "paddle/infrt/naive/meta_tensor.h" +#include "paddle/infrt/tensor/dense_host_tensor.h" + +namespace infrt { +namespace naive { + +struct InferShapedKernelLauncher { + virtual void Invoke(host_context::KernelFrame* frame) = 0; + + virtual ~InferShapedKernelLauncher() = default; + + protected: + //! Initialize the kernel frame for InferShape kernel. + // This method will create a new KernelFrame with all the Tensors(currently + // only DenseHostTensor) converted into MetaTensors so that the infer-shape + // function can work with. + // @frame: the frame containing argument list that is same with the ones of + // the corresponding kernel. + void CreateKernelFrameForInferShape(host_context::KernelFrame* frame); + + //! Build or update the infer-shape cache using the latest shape from + //! InferShapeFrame. + void BuildInferShapeCache(const uint16_t* input_indices, + const uint16_t num_inputs); + + //! Compare the latest shape with the shape cache. + bool IsShapeChanged(const uint16_t* input_indices, + const uint16_t num_inputs) const; + + // values to hold the TensorMeta. + llvm::SmallVector values; + llvm::SmallVector tensor_shape_cache; + host_context::KernelFrameBuilder infershape_kernel_frame_builder; +}; + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc b/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc new file mode 100644 index 0000000000..928a43da3e --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h" + +#include "paddle/infrt/naive/infershaped/elementwise_add.h" +#include "paddle/infrt/naive/infershaped/infershaped_registry.h" +namespace infrt { +namespace naive { + +void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) { + registry->AddKernel("elementwise_add", + INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher)); +} + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h b/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h new file mode 100644 index 0000000000..3e83b690bb --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace infrt { +namespace naive { + +struct InferShapedKernelRegistry; + +void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry); + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershaped_registry.cc b/paddle/infrt/naive/infershaped/infershaped_registry.cc new file mode 100644 index 0000000000..94218a9a6f --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershaped_registry.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/naive/infershaped/infershaped_registry.h" + +#include + +#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h" + +namespace infrt { +namespace naive { + +struct InferShapedKernelRegistry::Impl { + std::unordered_map data; +}; + +InferShapedKernelRegistry::InferShapedKernelRegistry() + : impl_(std::make_unique()) {} + +void InferShapedKernelRegistry::AddKernel( + const std::string& key, + InferShapedKernelRegistry::InferShapeLauncherCreator&& creator) { + CHECK(!impl_->data.count(key)) << "Item called " << key << " duplicates"; + impl_->data.emplace(key, std::move(creator)); +} + +const InferShapedKernelRegistry::InferShapeLauncherCreator& +InferShapedKernelRegistry::GetKernel(const std::string& key) const { + auto it = impl_->data.find(key); + CHECK(it != impl_->data.end()) << "No item called " << key << " exists"; + return it->second; +} + +size_t InferShapedKernelRegistry::size() const { return impl_->data.size(); } + +InferShapedKernelRegistry* GetInferShapeRegistry() { + static auto registry = std::make_unique(); + return registry.get(); +} + +InferShapedKernelRegistry::~InferShapedKernelRegistry() {} + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/infershaped/infershaped_registry.h b/paddle/infrt/naive/infershaped/infershaped_registry.h new file mode 100644 index 0000000000..e0e56a148f --- /dev/null +++ b/paddle/infrt/naive/infershaped/infershaped_registry.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace infrt { +namespace naive { + +struct InferShapedKernelLauncher; + +class InferShapedKernelRegistry { + public: + using InferShapeLauncherHandle = std::unique_ptr; + using InferShapeLauncherCreator = std::function; + + InferShapedKernelRegistry(); + + void AddKernel(const std::string& key, InferShapeLauncherCreator&& creator); + + const InferShapeLauncherCreator& GetKernel(const std::string& key) const; + + size_t size() const; + + ~InferShapedKernelRegistry(); + + private: + struct Impl; + + std::unique_ptr impl_; +}; + +//! The global infershape registry. +InferShapedKernelRegistry* GetInferShapeRegistry(); + +} // namespace naive +} // namespace infrt + +#define INFERSHAPED_KERNEL_CREATOR(infershape_launcher_class_) \ + []() \ + -> ::infrt::naive::InferShapedKernelRegistry::InferShapeLauncherHandle { \ + return std::make_unique(); \ + } diff --git a/paddle/infrt/naive/meta_tensor.cc b/paddle/infrt/naive/meta_tensor.cc new file mode 100644 index 0000000000..2f7ee3a69e --- /dev/null +++ b/paddle/infrt/naive/meta_tensor.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/naive/meta_tensor.h" + +#include "paddle/infrt/tensor/dense_host_tensor.h" +#include "paddle/infrt/tensor/tensor_shape.h" + +namespace infrt { +namespace naive { + +const tensor::TensorShape& MetaTensor::shape() const { + return mutable_tensor_->shape(); +} +tensor::TensorShape* MetaTensor::mutable_shape() { + return mutable_tensor_->mutable_shape(); +} + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/naive/meta_tensor.h b/paddle/infrt/naive/meta_tensor.h new file mode 100644 index 0000000000..4b62f3021a --- /dev/null +++ b/paddle/infrt/naive/meta_tensor.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A naive implementation of MetaTensor +#pragma once +#include "paddle/infrt/common/common.h" + +namespace infrt { +namespace tensor { +struct DenseHostTensor; +struct TensorShape; +} // namespace tensor + +namespace naive { + +class MetaTensor { + public: + MetaTensor() = default; + explicit MetaTensor(tensor::DenseHostTensor* tensor) + : mutable_tensor_(tensor) {} + explicit MetaTensor(const tensor::DenseHostTensor* tensor) + : mutable_tensor_(&Reference(tensor)) {} + explicit MetaTensor(MetaTensor&& other) + : mutable_tensor_(other.mutable_tensor_) {} + explicit MetaTensor(const MetaTensor& other) + : mutable_tensor_(other.mutable_tensor_) {} + + const tensor::TensorShape& shape() const; + tensor::TensorShape* mutable_shape(); + + private: + tensor::DenseHostTensor* mutable_tensor_{}; +}; + +} // namespace naive +} // namespace infrt diff --git a/paddle/infrt/tensor/dense_host_tensor.cc b/paddle/infrt/tensor/dense_host_tensor.cc index e54ab0e5c4..639b0f9f51 100644 --- a/paddle/infrt/tensor/dense_host_tensor.cc +++ b/paddle/infrt/tensor/dense_host_tensor.cc @@ -20,6 +20,10 @@ namespace infrt::tensor { +DenseHostTensor::DenseHostTensor(std::initializer_list&& list, + DType dtype) + : DenseHostTensor(TensorShape(list), dtype) {} + DenseHostTensor::DenseHostTensor(const TensorShape& shape, DType dtype) : HostTensor(TensorMetadata{dtype, shape}) { CHECK(metadata().IsValid()) << "Tensor construct get invalid metadata"; @@ -28,6 +32,9 @@ DenseHostTensor::DenseHostTensor(const TensorShape& shape, DType dtype) } const TensorShape& DenseHostTensor::shape() const { return metadata().shape; } +TensorShape* DenseHostTensor::mutable_shape() { + return &mutable_metadata()->shape; +} void DenseHostTensor::Init(const std::vector& shape, DType dtype) { auto shape_array = llvm::ArrayRef(shape.data(), shape.size()); diff --git a/paddle/infrt/tensor/dense_host_tensor.h b/paddle/infrt/tensor/dense_host_tensor.h index 7821395b54..6003c82118 100644 --- a/paddle/infrt/tensor/dense_host_tensor.h +++ b/paddle/infrt/tensor/dense_host_tensor.h @@ -24,7 +24,8 @@ namespace infrt { class Buffer; } // namespace infrt -namespace infrt::tensor { +namespace infrt { +namespace tensor { enum class DeviceKind { kCPU = 0, @@ -36,6 +37,7 @@ class Tensor { virtual ~Tensor() = default; const TensorMetadata& metadata() const { return metadata_; } + TensorMetadata* mutable_metadata() { return &metadata_; } protected: Tensor() = default; @@ -70,9 +72,11 @@ class DenseHostTensor : public HostTensor { public: DenseHostTensor() = default; DenseHostTensor(const TensorShape& shape, DType dtype); + DenseHostTensor(std::initializer_list&& list, DType dtype); void Init(const std::vector& shape, DType dtype); const TensorShape& shape() const; + TensorShape* mutable_shape(); const Buffer* buffer() const; @@ -89,4 +93,5 @@ class DenseHostTensor : public HostTensor { std::shared_ptr buffer_; }; -} // namespace infrt::tensor +} // namespace tensor +} // namespace infrt diff --git a/paddle/infrt/tensor/tensor_shape.cc b/paddle/infrt/tensor/tensor_shape.cc index 1e6d5c107e..4999500984 100644 --- a/paddle/infrt/tensor/tensor_shape.cc +++ b/paddle/infrt/tensor/tensor_shape.cc @@ -26,6 +26,8 @@ namespace tensor { TensorShape::TensorShape(llvm::ArrayRef dims) : dims_(dims.begin(), dims.end()) {} +TensorShape::TensorShape(std::initializer_list dims) : dims_(dims) {} + int TensorShape::GetRank() const { return dims_.size(); } int64_t TensorShape::GetDim(int idx) const { diff --git a/paddle/infrt/tensor/tensor_shape.h b/paddle/infrt/tensor/tensor_shape.h index cce95072f5..e2232c506e 100644 --- a/paddle/infrt/tensor/tensor_shape.h +++ b/paddle/infrt/tensor/tensor_shape.h @@ -27,6 +27,7 @@ class TensorShape { public: TensorShape() = default; explicit TensorShape(llvm::ArrayRef dims); + explicit TensorShape(std::initializer_list dims); int GetRank() const; @@ -40,6 +41,10 @@ class TensorShape { return a.dims_ == b.dims_; } + friend bool operator!=(const TensorShape& a, const TensorShape& b) { + return !(a == b); + } + private: llvm::SmallVector dims_; }; -- GitLab