Unverified · Commit b3e049f8 · Author: 石晓伟 · Committer: GitHub

infershaped autogen (PR #1), test=develop (#39405)

Parent 1bd7a143
paddle/infrt/naive/CMakeLists.txt
@@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
     infershaped/infershaped_kernel_launchers.cc
     )
-cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
+cc_test_tiny(test_infrt_infershape_launchers SRCS
+    infershaped/infershape_launchers_test.cc DEPS infrt)
paddle/infrt/naive/infershaped/elementwise_add.h
@@ -17,6 +17,7 @@
 #include "paddle/infrt/host_context/kernel_utils.h"
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
+#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
 // This file contains a example of the infershape ElementwiseAdd kernel.
 // Some of the following code should be generated from PTEN by script.
@@ -32,17 +33,19 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
   *c->mutable_shape() = a.shape();
 }
-static void ElementwiseAdd(const tensor::DenseHostTensor& a,
+static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
+                           const tensor::DenseHostTensor& a,
                            const tensor::DenseHostTensor& b,
                            tensor::DenseHostTensor* c) {}
-// TODO(zhiqiang) This class should be generated by a script offline.
-class ElementwiseAddLauncher : public InferShapedKernelLauncher {
+template <typename KernelFunc,
+          KernelFunc kernel,
+          typename InferShapedFunc,
+          InferShapedFunc infershape>
+class KernelLauncher : public InferShapedKernelLauncher {
  public:
-  static const uint16_t input_tensor_indices[2];
-  static const uint16_t num_input_tensors{2};
+  static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
   static const bool turn_on_infer_shape_cache{true};
   void Invoke(host_context::KernelFrame* frame) override {
     // Build the infershape KernelFrame if needed.
     // TODO(Superjomn) add unlikely here.
@@ -50,21 +53,16 @@ class ElementwiseAddLauncher : public InferShapedKernelLauncher {
       CreateKernelFrameForInferShape(frame);
     }
     if (turn_on_infer_shape_cache) {
-      if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
-        INFRT_KERNEL(ElementwiseAddInferShape)
-        (&infershape_kernel_frame_builder);
-        BuildInferShapeCache(input_tensor_indices, num_input_tensors);
+      if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
+        ::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
+            &infershape_kernel_frame_builder);
+        BuildInferShapeCache(num_input_tensors);
       }
-    } else {
-      INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
-      BuildInferShapeCache(input_tensor_indices, num_input_tensors);
     }
-    INFRT_KERNEL(ElementwiseAdd)(frame);
+    ::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
   }
 };
-const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
 } // namespace naive
 } // namespace infrt
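A note on the explicit ::infrt::host_context::KernelImpl<...>::Invoke calls above: the old launcher went through the INFRT_KERNEL macro, which needs a concrete function name, while in the generic KernelLauncher the kernel and its infershape function are non-type template parameters, so the KernelImpl instantiation has to be spelled out. A minimal sketch of the assumed equivalence (the macro body below is an assumption about paddle/infrt/host_context/kernel_utils.h, shown only for orientation, not copied from that header):

// Assumed rough shape of the macro (hypothetical reconstruction):
//   #define INFRT_KERNEL(fn) \
//     ::infrt::host_context::KernelImpl<decltype(&fn), &fn>::Invoke
//
// Under that assumption, the old and new spellings invoke the same thing:
//   INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
//   ::infrt::host_context::KernelImpl<decltype(&ElementwiseAddInferShape),
//                                     &ElementwiseAddInferShape>::Invoke(
//       &infershape_kernel_frame_builder);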
paddle/infrt/naive/infershaped/infershape_launchers_test.cc
@@ -17,11 +17,24 @@
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
 #include "paddle/infrt/naive/infershaped/infershaped_registry.h"
+#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
 #include "paddle/infrt/tensor/dense_host_tensor.h"
 namespace infrt {
 namespace naive {
+namespace {
+static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
+                               const tensor::DenseHostTensor& b,
+                               tensor::DenseHostTensor* c);
+}
+TEST(utils, registry) {
+  constexpr uint8_t count =
+      InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
+  CHECK_EQ(count, 2U);
+}
 TEST(ElementwiseAdd, registry) {
   InferShapedKernelRegistry registry;
   RegisterInferShapeLaunchers(&registry);
@@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
   tensor::DenseHostTensor c({2, 8}, GetDType<float>());
   host_context::KernelFrameBuilder kernel_frame_builder;
+  kernel_frame_builder.AddArgument(new host_context::Value(0));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
   kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
......
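For readers tracking the new leading argument: an annotated view of the frame the updated test builds (the comments are explanatory and are not part of the patch):

// Frame layout consumed by the new ElementwiseAdd signature:
//   arg 0   : host_context::Value(0)           // placeholder for the context pointer
//   arg 1   : Value holding DenseHostTensor a  // first input
//   arg 2   : Value holding DenseHostTensor b  // second input
//   results : Value holding DenseHostTensor c  // output
//
// This is why CreateKernelFrameForInferShape (next hunk) now iterates
// frame->GetValues(1, frame->GetNumElements() - 1): the infershape frame
// should only see the tensor arguments, not the context slot.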
paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
@@ -20,7 +20,7 @@ namespace naive {
 void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
     host_context::KernelFrame* frame) {
   for (host_context::Value* value :
-       frame->GetValues(0, frame->GetNumElements())) {
+       frame->GetValues(1, frame->GetNumElements() - 1)) {
     // TODO(Superjomn) To extend this.
     if (value->is_type<tensor::DenseHostTensor>()) {
       values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
@@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
 }
 void InferShapedKernelLauncher::BuildInferShapeCache(
-    const uint16_t* input_indices, const uint16_t num_inputs) {
+    const uint16_t num_inputs) {
   tensor_shape_cache.resize(num_inputs);
   for (uint16_t i = 0; i < num_inputs; i++) {
     tensor_shape_cache[i] =
-        infershape_kernel_frame_builder.GetArgAt(input_indices[i])
-            ->get<MetaTensor>()
-            .shape();
+        infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
   }
 }
 bool InferShapedKernelLauncher::IsShapeChanged(
-    const uint16_t* input_indices, const uint16_t num_inputs) const {
+    const uint16_t num_inputs) const {
   if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
     return true;
   bool changed = false;
   for (uint16_t i = 0; i < num_inputs && !changed; i++) {
-    changed = changed || (tensor_shape_cache[i] !=
-                          infershape_kernel_frame_builder
-                              .GetArgAt<MetaTensor>(input_indices[i])
-                              .shape());
+    changed = changed ||
+              (tensor_shape_cache[i] !=
+               infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
   }
   return changed;
 }
......
paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
@@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {
   //! Build or update the infer-shape cache using the latest shape from
   //! InferShapeFrame.
-  void BuildInferShapeCache(const uint16_t* input_indices,
-                            const uint16_t num_inputs);
+  void BuildInferShapeCache(const uint16_t num_inputs);
   //! Compare the latest shape with the shape cache.
-  bool IsShapeChanged(const uint16_t* input_indices,
-                      const uint16_t num_inputs) const;
+  bool IsShapeChanged(const uint16_t num_inputs) const;
   // values to hold the TensorMeta.
   llvm::SmallVector<host_context::ValueRef, 3> values;
......
paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc
@@ -13,12 +13,18 @@
 // limitations under the License.
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
 #include "paddle/infrt/naive/infershaped/elementwise_add.h"
 #include "paddle/infrt/naive/infershaped/infershaped_registry.h"
 namespace infrt {
 namespace naive {
+using ElementwiseAddLauncher =
+    KernelLauncher<decltype(&ElementwiseAdd),
+                   &ElementwiseAdd,
+                   decltype(&ElementwiseAddInferShape),
+                   &ElementwiseAddInferShape>;
 void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
   registry->AddKernel("elementwise_add",
                       INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
......
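The using-alias plus INFERSHAPED_KERNEL_CREATOR pair above is the pattern the offline script is meant to generate for each kernel. As a hedged illustration only, a second registration might look like the sketch below, assuming a hypothetical Relu/ReluInferShape pair written in the same style as ElementwiseAdd (neither function exists in this patch):

// Hypothetical kernel pair, mirroring the ElementwiseAdd signatures above.
static void ReluInferShape(const MetaTensor& x, MetaTensor* out) {
  *out->mutable_shape() = x.shape();
}
static void Relu(tensor::DenseHostTensor* /*Context*/,
                 const tensor::DenseHostTensor& x,
                 tensor::DenseHostTensor* out) {}

// Launcher type assembled from the generic KernelLauncher template.
using ReluLauncher = KernelLauncher<decltype(&Relu),
                                    &Relu,
                                    decltype(&ReluInferShape),
                                    &ReluInferShape>;

// Added next to elementwise_add inside RegisterInferShapeLaunchers:
//   registry->AddKernel("relu", INFERSHAPED_KERNEL_CREATOR(ReluLauncher));

The remaining file in the patch, paddle/infrt/naive/infershaped/infershaped_utils.h, is new and appears in full below.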
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
namespace infershaped {
using KeyType = const tensor::DenseHostTensor&;
using CountType = uint8_t;
constexpr CountType value(std::true_type) { return 1; }
constexpr CountType value(std::false_type) { return 0; }
template <typename T>
constexpr CountType value() {
return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
}
template <typename FirstArg>
constexpr CountType count(CountType num) {
return num;
}
template <typename FirstArg>
constexpr CountType count() {
return 0;
}
template <>
constexpr CountType count<KeyType>(CountType num) {
return num + 1;
}
template <>
constexpr CountType count<KeyType>() {
return 1;
}
template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count(CountType num) {
return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
}
template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count() {
return count<SecondArg, RestOfArgs...>(value<FirstArg>());
}
} // namespace infershaped
template <typename F>
struct InferShapeHelper;
template <typename Return, typename... Args>
struct InferShapeHelper<Return (*)(Args...)> {
static constexpr int count = infershaped::count<Args...>();
};
} // namespace naive
} // namespace infrt
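What the new header computes: InferShapeHelper<F>::count is the number of parameters of F whose type is exactly const tensor::DenseHostTensor&, i.e. the number of input tensors, which KernelLauncher uses as num_input_tensors for the shape cache; TEST(utils, registry) above checks that the count for a two-input signature is 2. A self-contained sketch of the same counting idea, using a C++17 fold expression and a stand-in Tensor type (hypothetical names, not the infrt implementation):

#include <cstdint>
#include <type_traits>

namespace demo {

struct Tensor {};  // stand-in for tensor::DenseHostTensor

// Count the parameters whose type is exactly `const Tensor&`.
template <typename F>
struct InputCount;

template <typename Return, typename... Args>
struct InputCount<Return (*)(Args...)> {
  static constexpr uint8_t count =
      ((std::is_same<Args, const Tensor&>::value ? 1 : 0) + ... + 0);
};

// A kernel in the post-patch style: context pointer, then inputs, then output.
void ElementwiseAddDemo(Tensor* /*Context*/,
                        const Tensor& a,
                        const Tensor& b,
                        Tensor* c);

static_assert(InputCount<decltype(&ElementwiseAddDemo)>::count == 2,
              "two input tensors expected");

}  // namespace demo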