Unverified commit eacfc1eb, authored by Yan Chunwei and committed by GitHub

INFRT/Add infershape schedule (3rd PR) (#39290)

Parent 65805227
@@ -76,6 +76,7 @@ add_subdirectory(tensor)
 add_subdirectory(support)
 add_subdirectory(external_kernels)
 add_subdirectory(paddle)
+add_subdirectory(naive)
 add_subdirectory(tests)
@@ -93,7 +94,7 @@ set(infrt_mlir_incs
 )
 message(STATUS "infrt srcs:\n${infrt_src}")
-cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto)
+cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto infrt_naive)
 cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto)
 add_dependencies(infrt ${infrt_mlir_incs})
......
@@ -31,7 +31,7 @@ namespace host_context {
 class KernelFrame {
  public:
   int GetNumArgs() const { return num_arguments_; }
-  int GetNumResults() const { return num_results_; }
+  int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; }
   int GetNumAttributes() const {
     return value_or_attrs_.size() - num_arguments_ -
            (num_results_ == -1 ? 0 : num_results_);
@@ -45,6 +45,9 @@ class KernelFrame {
     return value_or_attrs_[index]->template get_or_default<T>();
   }
+  // Get the total number of elements: arguments, attributes, and results.
+  size_t GetNumElements() const { return value_or_attrs_.size(); }
   template <typename T>
   T& GetArgAt(int index) {
     CHECK_LT(index, GetNumArgs());
@@ -118,6 +121,8 @@ class KernelFrame {
     return llvm::makeMutableArrayRef(&value_or_attrs_[from], length);
   }
+  bool IsEmpty() const { return value_or_attrs_.empty(); }
  protected:
   int num_arguments_{};
   int num_results_{-1};
......
@@ -24,6 +24,8 @@ ValueRef::ValueRef(int64_t val) : Shared<Value>(new Value(val)) {}
 ValueRef::ValueRef(float val) : Shared<Value>(new Value(val)) {}
 ValueRef::ValueRef(double val) : Shared<Value>(new Value(val)) {}
 ValueRef::ValueRef(bool val) : Shared<Value>(new Value(val)) {}
+ValueRef::ValueRef(naive::MetaTensor&& val)
+    : Shared<Value>(new Value(std::move(val))) {}
 const char* Value::type_info() const { return __type_info__; }
......
@@ -23,6 +23,7 @@
 #include "paddle/infrt/common/object.h"
 #include "paddle/infrt/common/shared.h"
 #include "paddle/infrt/host_context/function.h"
+#include "paddle/infrt/naive/meta_tensor.h"
 #include "paddle/infrt/support/variant.h"
 #include "paddle/infrt/tensor/dense_host_tensor.h"
 #include "paddle/infrt/tensor/dense_tensor_view.h"
@@ -50,6 +51,7 @@ using ValueVariantType = Variant<int16_t,
                                  tensor::TensorMap,
                                  // pten::CPUContext,
                                  // pten::DenseTensor,
+                                 naive::MetaTensor,
                                  std::vector<int16_t>,
                                  std::vector<int32_t>,
                                  std::vector<int64_t>,
@@ -82,6 +84,7 @@ class Value : public common::Object {
   explicit Value(tensor::TensorShape&& x) : data(std::move(x)) {}
   explicit Value(tensor::DenseHostTensor&& x) : data(std::move(x)) {}
   explicit Value(MlirFunctionExecutable* x) : data(x) {}
+  explicit Value(naive::MetaTensor&& x) : data(std::move(x)) {}
   template <typename T>
   const T& get() const {
@@ -113,6 +116,11 @@ class Value : public common::Object {
   bool valid() const { return true; }
+  template <typename T>
+  bool is_type() const {
+    return data.template is<T>();
+  }
   const char* type_info() const override;
   friend void CopyTo(const Value& from, Value* to);
@@ -134,6 +142,7 @@ class ValueRef : common::Shared<Value> {
   explicit ValueRef(float val);
   explicit ValueRef(double val);
   explicit ValueRef(bool val);
+  explicit ValueRef(naive::MetaTensor&& val);
   using common::Shared<Value>::get;
   using common::Shared<Value>::Reset;
......
cc_library(infrt_naive SRCS meta_tensor.cc
infershaped/infershaped_kernel_launcher.cc
infershaped/infershaped_registry.cc
infershaped/infershaped_kernel_launchers.cc
)
cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/SmallVector.h>
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
// This file contains an example of the infershape ElementwiseAdd kernel.
// Some of the following code should be generated from PTEN by a script.
namespace infrt {
namespace naive {
static void ElementwiseAddInferShape(const MetaTensor& a,
const MetaTensor& b,
MetaTensor* c) {
CHECK(a.shape() == b.shape())
    << "ElementwiseAdd: the shapes of a and b do not match";
*c->mutable_shape() = a.shape();
}
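// Compute kernel stub; only the infer-shape path is exercised in this example.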
static void ElementwiseAdd(const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c) {}
// TODO(zhiqiang) This class should be generated by a script offline.
class ElementwiseAddLauncher : public InferShapedKernelLauncher {
public:
static const uint16_t input_tensor_indices[2];
static const uint16_t num_input_tensors{2};
static const bool turn_on_infer_shape_cache{true};
void Invoke(host_context::KernelFrame* frame) override {
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if (infershape_kernel_frame_builder.IsEmpty()) {
CreateKernelFrameForInferShape(frame);
}
if (turn_on_infer_shape_cache) {
if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
INFRT_KERNEL(ElementwiseAddInferShape)
(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
}
} else {
INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
}
INFRT_KERNEL(ElementwiseAdd)(frame);
}
};
const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
TEST(ElementwiseAdd, registry) {
InferShapedKernelRegistry registry;
RegisterInferShapeLaunchers(&registry);
ASSERT_EQ(registry.size(), 1UL);
auto creator = registry.GetKernel("elementwise_add");
auto infershape_launcher_handle = creator();
// fake some tensors
tensor::DenseHostTensor a({2, 8}, GetDType<float>());
tensor::DenseHostTensor b({2, 8}, GetDType<float>());
tensor::DenseHostTensor c({2, 8}, GetDType<float>());
host_context::KernelFrameBuilder kernel_frame_builder;
kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
infershape_launcher_handle->Invoke(&kernel_frame_builder);
}
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
namespace infrt {
namespace naive {
void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
host_context::KernelFrame* frame) {
for (host_context::Value* value :
frame->GetValues(0, frame->GetNumElements())) {
// TODO(Superjomn) To extend this.
if (value->is_type<tensor::DenseHostTensor>()) {
values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
infershape_kernel_frame_builder.AddArgument(values.back().get());
} else {
infershape_kernel_frame_builder.AddArgument(value);
}
}
}
void InferShapedKernelLauncher::BuildInferShapeCache(
const uint16_t* input_indices, const uint16_t num_inputs) {
tensor_shape_cache.resize(num_inputs);
for (uint16_t i = 0; i < num_inputs; i++) {
tensor_shape_cache[i] =
infershape_kernel_frame_builder.GetArgAt(input_indices[i])
->get<MetaTensor>()
.shape();
}
}
bool InferShapedKernelLauncher::IsShapeChanged(
const uint16_t* input_indices, const uint16_t num_inputs) const {
if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
return true;
bool changed = false;
for (uint16_t i = 0; i < num_inputs && !changed; i++) {
changed = changed || (tensor_shape_cache[i] !=
infershape_kernel_frame_builder
.GetArgAt<MetaTensor>(input_indices[i])
.shape());
}
return changed;
}
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/SmallVector.h>
#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/naive/meta_tensor.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
struct InferShapedKernelLauncher {
virtual void Invoke(host_context::KernelFrame* frame) = 0;
virtual ~InferShapedKernelLauncher() = default;
protected:
//! Initialize the kernel frame for the InferShape kernel.
// This method creates a new KernelFrame in which all tensors (currently only
// DenseHostTensor) are converted into MetaTensors, so that the infer-shape
// function can operate on them.
// @frame: the frame whose argument list matches that of the corresponding
// kernel.
void CreateKernelFrameForInferShape(host_context::KernelFrame* frame);
//! Build or update the infer-shape cache using the latest shape from
//! InferShapeFrame.
void BuildInferShapeCache(const uint16_t* input_indices,
const uint16_t num_inputs);
//! Compare the latest shape with the shape cache.
bool IsShapeChanged(const uint16_t* input_indices,
const uint16_t num_inputs) const;
// Values that hold the MetaTensors.
llvm::SmallVector<host_context::ValueRef, 3> values;
llvm::SmallVector<tensor::TensorShape, 3> tensor_shape_cache;
host_context::KernelFrameBuilder infershape_kernel_frame_builder;
};
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/elementwise_add.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
namespace infrt {
namespace naive {
void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
registry->AddKernel("elementwise_add",
INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
}
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace infrt {
namespace naive {
struct InferShapedKernelRegistry;
void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry);
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include <unordered_map>
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
namespace infrt {
namespace naive {
struct InferShapedKernelRegistry::Impl {
std::unordered_map<std::string, InferShapeLauncherCreator> data;
};
InferShapedKernelRegistry::InferShapedKernelRegistry()
: impl_(std::make_unique<Impl>()) {}
void InferShapedKernelRegistry::AddKernel(
const std::string& key,
InferShapedKernelRegistry::InferShapeLauncherCreator&& creator) {
CHECK(!impl_->data.count(key)) << "Item called " << key << " is already registered";
impl_->data.emplace(key, std::move(creator));
}
const InferShapedKernelRegistry::InferShapeLauncherCreator&
InferShapedKernelRegistry::GetKernel(const std::string& key) const {
auto it = impl_->data.find(key);
CHECK(it != impl_->data.end()) << "No item called " << key << " exists";
return it->second;
}
size_t InferShapedKernelRegistry::size() const { return impl_->data.size(); }
InferShapedKernelRegistry* GetInferShapeRegistry() {
static auto registry = std::make_unique<InferShapedKernelRegistry>();
return registry.get();
}
InferShapedKernelRegistry::~InferShapedKernelRegistry() {}
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
namespace infrt {
namespace naive {
struct InferShapedKernelLauncher;
class InferShapedKernelRegistry {
public:
using InferShapeLauncherHandle = std::unique_ptr<InferShapedKernelLauncher>;
using InferShapeLauncherCreator = std::function<InferShapeLauncherHandle()>;
InferShapedKernelRegistry();
void AddKernel(const std::string& key, InferShapeLauncherCreator&& creator);
const InferShapeLauncherCreator& GetKernel(const std::string& key) const;
size_t size() const;
~InferShapedKernelRegistry();
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
//! The global infershape registry.
InferShapedKernelRegistry* GetInferShapeRegistry();
} // namespace naive
} // namespace infrt
#define INFERSHAPED_KERNEL_CREATOR(infershape_launcher_class_) \
[]() \
-> ::infrt::naive::InferShapedKernelRegistry::InferShapeLauncherHandle { \
return std::make_unique<infershape_launcher_class_>(); \
}
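For readers of this PR, a minimal sketch of how the registry, the INFERSHAPED_KERNEL_CREATOR macro, and a launcher fit together; this is an editor's illustration mirroring RegisterInferShapeLaunchers() and the test above, not a file in this commit, and the function name RegisterAndInvokeSketch is made up:
#include "paddle/infrt/naive/infershaped/elementwise_add.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
void RegisterAndInvokeSketch(infrt::host_context::KernelFrame* frame) {
  infrt::naive::InferShapedKernelRegistry registry;
  // The macro expands to a lambda that returns a unique_ptr to the launcher.
  registry.AddKernel(
      "elementwise_add",
      INFERSHAPED_KERNEL_CREATOR(infrt::naive::ElementwiseAddLauncher));
  // Look the creator up by name, build a launcher, and run it on a frame
  // whose arguments and results are DenseHostTensor values.
  auto launcher = registry.GetKernel("elementwise_add")();
  launcher->Invoke(frame);  // runs infer-shape first, then the compute kernel
}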
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/meta_tensor.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt {
namespace naive {
const tensor::TensorShape& MetaTensor::shape() const {
return mutable_tensor_->shape();
}
tensor::TensorShape* MetaTensor::mutable_shape() {
return mutable_tensor_->mutable_shape();
}
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// A naive implementation of MetaTensor
#pragma once
#include "paddle/infrt/common/common.h"
namespace infrt {
namespace tensor {
struct DenseHostTensor;
struct TensorShape;
} // namespace tensor
namespace naive {
class MetaTensor {
public:
MetaTensor() = default;
explicit MetaTensor(tensor::DenseHostTensor* tensor)
: mutable_tensor_(tensor) {}
explicit MetaTensor(const tensor::DenseHostTensor* tensor)
: mutable_tensor_(&Reference(tensor)) {}
explicit MetaTensor(MetaTensor&& other)
: mutable_tensor_(other.mutable_tensor_) {}
explicit MetaTensor(const MetaTensor& other)
: mutable_tensor_(other.mutable_tensor_) {}
const tensor::TensorShape& shape() const;
tensor::TensorShape* mutable_shape();
private:
tensor::DenseHostTensor* mutable_tensor_{};
};
} // namespace naive
} // namespace infrt
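As a quick illustration of how this interface is meant to be consumed (a hypothetical function following the same pattern as ElementwiseAddInferShape above; UnaryIdentityInferShape is not part of this commit):
#include "paddle/infrt/naive/meta_tensor.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt {
namespace naive {
// Hypothetical unary infer-shape rule: the output shape copies the input.
static void UnaryIdentityInferShape(const MetaTensor& x, MetaTensor* out) {
  *out->mutable_shape() = x.shape();
}
}  // namespace naive
}  // namespace infrt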
@@ -20,6 +20,10 @@
 namespace infrt::tensor {
+DenseHostTensor::DenseHostTensor(std::initializer_list<int64_t>&& list,
+                                 DType dtype)
+    : DenseHostTensor(TensorShape(list), dtype) {}
 DenseHostTensor::DenseHostTensor(const TensorShape& shape, DType dtype)
     : HostTensor(TensorMetadata{dtype, shape}) {
   CHECK(metadata().IsValid()) << "Tensor construct get invalid metadata";
@@ -28,6 +32,9 @@ DenseHostTensor::DenseHostTensor(const TensorShape& shape, DType dtype)
 }
 const TensorShape& DenseHostTensor::shape() const { return metadata().shape; }
+TensorShape* DenseHostTensor::mutable_shape() {
+  return &mutable_metadata()->shape;
+}
 void DenseHostTensor::Init(const std::vector<int64_t>& shape, DType dtype) {
   auto shape_array = llvm::ArrayRef<int64_t>(shape.data(), shape.size());
......
@@ -24,7 +24,8 @@ namespace infrt {
 class Buffer;
 }  // namespace infrt
-namespace infrt::tensor {
+namespace infrt {
+namespace tensor {
 enum class DeviceKind {
   kCPU = 0,
@@ -36,6 +37,7 @@ class Tensor {
   virtual ~Tensor() = default;
   const TensorMetadata& metadata() const { return metadata_; }
+  TensorMetadata* mutable_metadata() { return &metadata_; }
  protected:
   Tensor() = default;
@@ -70,9 +72,11 @@ class DenseHostTensor : public HostTensor {
  public:
   DenseHostTensor() = default;
   DenseHostTensor(const TensorShape& shape, DType dtype);
+  DenseHostTensor(std::initializer_list<int64_t>&& list, DType dtype);
   void Init(const std::vector<int64_t>& shape, DType dtype);
   const TensorShape& shape() const;
+  TensorShape* mutable_shape();
   const Buffer* buffer() const;
@@ -89,4 +93,5 @@ class DenseHostTensor : public HostTensor {
   std::shared_ptr<Buffer> buffer_;
 };
-}  // namespace infrt::tensor
+}  // namespace tensor
+}  // namespace infrt
@@ -26,6 +26,8 @@ namespace tensor {
 TensorShape::TensorShape(llvm::ArrayRef<int64_t> dims)
     : dims_(dims.begin(), dims.end()) {}
+TensorShape::TensorShape(std::initializer_list<int64_t> dims) : dims_(dims) {}
 int TensorShape::GetRank() const { return dims_.size(); }
 int64_t TensorShape::GetDim(int idx) const {
......
@@ -27,6 +27,7 @@ class TensorShape {
  public:
   TensorShape() = default;
   explicit TensorShape(llvm::ArrayRef<int64_t> dims);
+  explicit TensorShape(std::initializer_list<int64_t> dims);
   int GetRank() const;
@@ -40,6 +41,10 @@ class TensorShape {
     return a.dims_ == b.dims_;
   }
+  friend bool operator!=(const TensorShape& a, const TensorShape& b) {
+    return !(a == b);
+  }
  private:
   llvm::SmallVector<int64_t, 4> dims_;
 };
......