未验证 提交 b3e049f8 编写于 作者: 石晓伟 提交者: GitHub

infershaped autogen (PR #1), test=develop (#39405)

上级 1bd7a143
......@@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
infershaped/infershaped_kernel_launchers.cc
)
cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt)
......@@ -17,6 +17,7 @@
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
// This file contains a example of the infershape ElementwiseAdd kernel.
// Some of the following code should be generated from PTEN by script.
......@@ -32,17 +33,19 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
*c->mutable_shape() = a.shape();
}
static void ElementwiseAdd(const tensor::DenseHostTensor& a,
static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c) {}
// TODO(zhiqiang) This class should be generated by a script offline.
class ElementwiseAddLauncher : public InferShapedKernelLauncher {
template <typename KernelFunc,
KernelFunc kernel,
typename InferShapedFunc,
InferShapedFunc infershape>
class KernelLauncher : public InferShapedKernelLauncher {
public:
static const uint16_t input_tensor_indices[2];
static const uint16_t num_input_tensors{2};
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true};
void Invoke(host_context::KernelFrame* frame) override {
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
......@@ -50,21 +53,16 @@ class ElementwiseAddLauncher : public InferShapedKernelLauncher {
CreateKernelFrameForInferShape(frame);
}
if (turn_on_infer_shape_cache) {
if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
INFRT_KERNEL(ElementwiseAddInferShape)
(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
&infershape_kernel_frame_builder);
BuildInferShapeCache(num_input_tensors);
}
} else {
INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
}
INFRT_KERNEL(ElementwiseAdd)(frame);
::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
}
};
const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
} // namespace naive
} // namespace infrt
......@@ -17,11 +17,24 @@
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
namespace {
static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c);
}
TEST(utils, registry) {
constexpr uint8_t count =
InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
CHECK_EQ(count, 2U);
}
TEST(ElementwiseAdd, registry) {
InferShapedKernelRegistry registry;
RegisterInferShapeLaunchers(&registry);
......@@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
tensor::DenseHostTensor c({2, 8}, GetDType<float>());
host_context::KernelFrameBuilder kernel_frame_builder;
kernel_frame_builder.AddArgument(new host_context::Value(0));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
......
......@@ -20,7 +20,7 @@ namespace naive {
void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
host_context::KernelFrame* frame) {
for (host_context::Value* value :
frame->GetValues(0, frame->GetNumElements())) {
frame->GetValues(1, frame->GetNumElements() - 1)) {
// TODO(Superjomn) To extend this.
if (value->is_type<tensor::DenseHostTensor>()) {
values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
......@@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
}
void InferShapedKernelLauncher::BuildInferShapeCache(
const uint16_t* input_indices, const uint16_t num_inputs) {
const uint16_t num_inputs) {
tensor_shape_cache.resize(num_inputs);
for (uint16_t i = 0; i < num_inputs; i++) {
tensor_shape_cache[i] =
infershape_kernel_frame_builder.GetArgAt(input_indices[i])
->get<MetaTensor>()
.shape();
infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
}
}
bool InferShapedKernelLauncher::IsShapeChanged(
const uint16_t* input_indices, const uint16_t num_inputs) const {
const uint16_t num_inputs) const {
if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
return true;
bool changed = false;
for (uint16_t i = 0; i < num_inputs && !changed; i++) {
changed = changed || (tensor_shape_cache[i] !=
infershape_kernel_frame_builder
.GetArgAt<MetaTensor>(input_indices[i])
.shape());
changed = changed ||
(tensor_shape_cache[i] !=
infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
}
return changed;
}
......
......@@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {
//! Build or update the infer-shape cache using the latest shape from
//! InferShapeFrame.
void BuildInferShapeCache(const uint16_t* input_indices,
const uint16_t num_inputs);
void BuildInferShapeCache(const uint16_t num_inputs);
//! Compare the latest shape with the shape cache.
bool IsShapeChanged(const uint16_t* input_indices,
const uint16_t num_inputs) const;
bool IsShapeChanged(const uint16_t num_inputs) const;
// values to hold the TensorMeta.
llvm::SmallVector<host_context::ValueRef, 3> values;
......
......@@ -13,12 +13,18 @@
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/elementwise_add.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
namespace infrt {
namespace naive {
using ElementwiseAddLauncher =
KernelLauncher<decltype(&ElementwiseAdd),
&ElementwiseAdd,
decltype(&ElementwiseAddInferShape),
&ElementwiseAddInferShape>;
void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
registry->AddKernel("elementwise_add",
INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
namespace infershaped {
using KeyType = const tensor::DenseHostTensor&;
using CountType = uint8_t;
constexpr CountType value(std::true_type) { return 1; }
constexpr CountType value(std::false_type) { return 0; }
template <typename T>
constexpr CountType value() {
return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
}
template <typename FirstArg>
constexpr CountType count(CountType num) {
return num;
}
template <typename FirstArg>
constexpr CountType count() {
return 0;
}
template <>
constexpr CountType count<KeyType>(CountType num) {
return num + 1;
}
template <>
constexpr CountType count<KeyType>() {
return 1;
}
template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count(CountType num) {
return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
}
template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count() {
return count<SecondArg, RestOfArgs...>(value<FirstArg>());
}
} // namespace infershaped
template <typename F>
struct InferShapeHelper;
template <typename Return, typename... Args>
struct InferShapeHelper<Return (*)(Args...)> {
static constexpr int count = infershaped::count<Args...>();
};
} // namespace naive
} // namespace infrt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册