未验证 提交 8f2d14ad 编写于 作者: 石晓伟 提交者: GitHub

change classes to pten, test=develop (#39643)

上级 1035d21f
......@@ -82,7 +82,6 @@ add_subdirectory(tensor)
add_subdirectory(support)
add_subdirectory(external_kernels)
add_subdirectory(paddle)
add_subdirectory(naive)
add_subdirectory(tests)
......@@ -99,14 +98,15 @@ set(infrt_mlir_incs
trt_ops_inc
)
if (INFRT_WITH_PTEN)
set(pten_libs pten)
set(infrt_mlir_incs ${infrt_mlir_incs}
MLIRinfrt_pten_tensorIncGen
MLIRinfrt_pten_baseIncGen
)
endif()
cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto infrt_naive)
cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto)
cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${pten_libs} paddle_framework_proto infrt_naive)
cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${pten_libs} paddle_framework_proto)
add_dependencies(infrt ${infrt_mlir_incs} mlir-headers)
add_custom_target(test_infrt_exec DEPENDS ${INFRT_TEST_TARGETS})
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/allocator.h"
namespace infrt {
namespace backends {
class CpuPtenAllocator : public pten::Allocator {
public:
static void deleter(pten::Allocation* ptr) { ::operator delete(ptr); }
AllocationPtr Allocate(size_t bytes_size) {
return AllocationPtr(
new pten::Allocation(::operator new(bytes_size),
bytes_size,
pten::Place(pten::AllocationType::CPU)),
deleter);
}
};
} // namespace backends
} // namespace infrt
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
namespace infrt {
namespace backends {
class CpuPtenContext : public pten::CPUContext {
public:
using Base = pten::CPUContext;
using pten::CPUContext::SetEigenDevice;
};
} // namespace backends
} // namespace infrt
......@@ -5,6 +5,7 @@ endif()
#mlir_tablegen_on(infrt_pten_base DIALECT pten)
add_mlir_dialect(infrt_pten_base pten)
add_mlir_dialect(infrt_pten_tensor pten_dt)
add_mlir_dialect(infrt_pten_kernel pten_kernel)
#mlir_tablegen_on(infrt_pten_tensor)
gather_srcs(infrt_src SRCS
......
#ifndef PTEN_KERNEL
#define PTEN_KERNEL
include "paddle/infrt/dialect/pten/infrt_pten_tensor.td"
def PTEN_KernelDialect : Dialect {
let name = "pten_kernel";
let description = [{
The PTEN Kernel dialect.
}];
let cppNamespace = "::infrt::pten";
}
// PTEN Kernel related ops.
class PDT_Kernel<string mnemonic, list<OpTrait> traits = []> : Op<PTEN_KernelDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
}
def FakeKernelOp : PDT_Kernel<"pten.matmul.host.fp32"> {
let arguments = (ins CPU_Context:$dev_ctx, TensorType:$x, TensorType:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
let results = (outs TensorType:$output);
}
#endif
......@@ -33,6 +33,7 @@
#include "paddle/infrt/dialect/pten/infrt_pten_tensorTypes.h.inc"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/pten/pten_base.h"
// NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pten/infrt_pten_tensor.h.inc"
......@@ -21,84 +21,36 @@ def PTEN_DenseTensorDialect : Dialect {
class PDT_Op<string mnemonic, list<OpTrait> traits = []> : Op<PTEN_DenseTensorDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
}
class CreateUninitTensorOp<string dtype>
: PDT_Op<"create_uninit_tensor." # dtype, [NoSideEffect]> {
let summary = "pdt.create_uninit_tensor operation";
let description = [{
An operation that creates an uninitialized tensor.
}];
let arguments = (ins I64ArrayAttr:$shape);
let results = (outs TensorType:$output);
}
class CreateInitedTensorOp<string dtype, Attr array_attr>
: PDT_Op<"create_inited_tensor." #dtype, [NoSideEffect]> {
let summary = "pdt.create_inited_tensor operation";
let description = [{
An operation that creates an tensor with shape and values assigned.
}];
let arguments = (ins I64ArrayAttr:$shape, array_attr:$values);
class CreateDenseTensorOp<string place, string dtype, string layout>
: PDT_Op<"create_dense_tensor." # place # "." # dtype # "." # layout, [NoSideEffect]> {
let arguments = (ins CPU_Allocator:$allocator, I64ArrayAttr:$dims, I64ArrayAttr:$lod);
let results = (outs TensorType:$output);
}
def PrintTensorOp : PDT_Op<"print_tensor"> {
let summary = "pdt.print_tensor operation";
let description = [{
An operation that prints a tensor.
}];
let arguments = (ins TensorType:$input);
let results = (outs);
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
}
class FillTensor<string dtype, Attr attr_type> :
PDT_Op<"fill_tensor." # dtype> {
let summary = "dt.fill_tensor operation";
let description = [{
An operation that fills an input tensor with a values.
}];
class FillDenseTensorOp<Attr attr_type, string dtype> :
PDT_Op<"fill_dense_tensor." # dtype> {
let arguments = (ins
TensorType:$input,
attr_type:$value
);
let results = (outs);
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
}
class FillTensorWithConstantOp<string dtype> :
PDT_Op<"fill_tensor_with_constant." # dtype> {
let summary = "dt.fill_tensor_with_constant operation";
let description = [{
An operation that fills an input tensor with a single value.
}];
let arguments = (ins
TensorType:$input,
AnyAttr:$value
);
let results = (outs);
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
class CreateCPUAllocatorOp
: PDT_Op<"create_allocator." # "cpu", [NoSideEffect]> {
let arguments = (ins);
let results = (outs CPU_Allocator:$output);
}
foreach dtype = ["ui8", "ui16", "ui32", "ui64", "i32", "f32", "f64", "i64"] in {
def PDT_CreateUninitTensorOp_#dtype : CreateUninitTensorOp<dtype>;
def PDT_FillTensorWithConstantOp_#dtype : FillTensorWithConstantOp<dtype>;
class CreateCPUContextOp
: PDT_Op<"create_context." # "cpu", [NoSideEffect]> {
let arguments = (ins);
let results = (outs CPU_Context:$output);
}
def PDT_FillTensor_f32: FillTensor<"f32", F32ArrayAttr>;
def PDT_FillTensor_i32: FillTensor<"i32", I32ArrayAttr>;
def PDT_CreateInitedTensorOp_f32 : CreateInitedTensorOp<"f32", F32ArrayAttr>;
def PDT_CreateInitedTensorOp_i32 : CreateInitedTensorOp<"i32", I32ArrayAttr>;
def PDT_CreateDenseTensorOp_cpu_f32_nchw : CreateDenseTensorOp<"cpu", "f32", "nchw">;
def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">;
def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp;
def PDT_CreateContextOp_cpu : CreateCPUContextOp;
#endif
......@@ -29,7 +29,23 @@ namespace pten {
void PTENDialect::printType(::mlir::Type type,
mlir::DialectAsmPrinter& os) const {
Dialect::printType(type, os);
if (type.isa<CPUAllocatorType>()) {
os << "CPU_Allocator";
return;
}
if (type.isa<GPUAllocatorType>()) {
os << "GPU_Allocator";
return;
}
if (type.isa<CPUContextType>()) {
os << "CPU_Context";
return;
}
if (type.isa<GPUContextType>()) {
os << "GPU_Context";
return;
}
llvm_unreachable("unexpected 'allocator/context' type kind");
}
void PTENDialect::initialize() {
......@@ -46,14 +62,16 @@ void PTENDialect::initialize() {
mlir::Type PTENDialect::parseType(mlir::DialectAsmParser& parser) const {
llvm::StringRef keyword;
if (parser.parseKeyword(&keyword)) return mlir::Type();
if (keyword == "allocator_CPU") {
if (keyword == "CPU_allocator") {
return CPUAllocatorType::get(parser.getContext());
} else if (keyword == "allocator_GPU") {
} else if (keyword == "GPU_allocator") {
return GPUAllocatorType::get(parser.getContext());
} else if (keyword == "context_CPU") {
} else if (keyword == "CPU_context") {
return CPUContextType::get(parser.getContext());
} else if (keyword == "context_GPU") {
} else if (keyword == "GPU_context") {
return GPUContextType::get(parser.getContext());
} else {
llvm_unreachable("unexpected 'allocator/context' type kind");
}
return mlir::Type();
......
......@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <vector>
......@@ -23,7 +24,7 @@ namespace host_context {
class KernelFrame;
using KernelImplementation = void (*)(KernelFrame *frame);
using KernelImplementation = std::function<void(KernelFrame *frame)>;
/**
* Hold the kernels registered in the system.
......
......@@ -28,7 +28,7 @@ TEST(KernelRegistry, basic) {
std::string key = "infrt.test.add.i32";
registry.AddKernel(key, INFRT_KERNEL(add_i32));
auto* kernel_impl = registry.GetKernel(key);
const auto& kernel_impl = registry.GetKernel(key);
ASSERT_TRUE(kernel_impl);
ValueRef a(1);
......
......@@ -28,6 +28,9 @@
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
#ifdef INFRT_WITH_PTEN
#include "paddle/infrt/kernel/pten/registry.h"
#endif
static llvm::cl::list<std::string> cl_shared_libs( // NOLINT
"shared_libs",
......@@ -53,6 +56,9 @@ int main(int argc, char** argv) {
kernel::RegisterTensorShapeKernels(&registry);
kernel::RegisterTensorKernels(&registry);
kernel::RegisterControlFlowKernels(&registry);
#ifdef INFRT_WITH_PTEN
kernel::RegisterPtenKernels(&registry);
#endif
// load extra shared library
for (const auto& lib_path : cl_shared_libs) {
......
......@@ -24,7 +24,13 @@ ValueRef::ValueRef(int64_t val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(float val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(double val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(bool val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(naive::MetaTensor&& val)
ValueRef::ValueRef(backends::CpuPtenContext&& val)
: Shared<Value>(new Value(std::move(val))) {}
ValueRef::ValueRef(::pten::CPUContext&& val)
: Shared<Value>(new Value(std::move(val))) {}
ValueRef::ValueRef(::pten::DenseTensor&& val)
: Shared<Value>(new Value(std::move(val))) {}
ValueRef::ValueRef(::pten::MetaTensor&& val)
: Shared<Value>(new Value(std::move(val))) {}
const char* Value::type_info() const { return __type_info__; }
......@@ -36,31 +42,31 @@ void CopyTo(const Value& from, Value* to) {
[&](auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if (std::is_same<T, int16_t>::value)
to->data = arg;
to->data = reinterpret_cast<int16_t const&>(arg);
else if (std::is_same<T, int32_t>::value)
to->data = arg;
to->data = reinterpret_cast<int32_t const&>(arg);
else if (std::is_same<T, float>::value)
to->data = arg;
to->data = reinterpret_cast<float const&>(arg);
else if (std::is_same<T, double>::value)
to->data = arg;
to->data = reinterpret_cast<double const&>(arg);
else if (std::is_same<T, uint32_t>::value)
to->data = arg;
to->data = reinterpret_cast<uint32_t const&>(arg);
else if (std::is_same<T, uint64_t>::value)
to->data = arg;
to->data = reinterpret_cast<uint64_t const&>(arg);
else if (std::is_same<T, bool>::value)
to->data = arg;
to->data = reinterpret_cast<bool const&>(arg);
else if (std::is_same<T, tensor::TensorShape>::value)
to->data = arg;
to->data = reinterpret_cast<tensor::TensorShape const&>(arg);
else if (std::is_same<T, MlirFunctionExecutable*>::value)
to->data = arg;
to->data = reinterpret_cast<MlirFunctionExecutable* const&>(arg);
else if (std::is_same<T, tensor::DenseHostTensor>::value)
to->data = arg;
to->data = reinterpret_cast<tensor::DenseHostTensor const&>(arg);
else if (std::is_same<T, std::vector<int16_t>>::value)
to->data = arg;
to->data = reinterpret_cast<std::vector<int16_t> const&>(arg);
else if (std::is_same<T, std::vector<int64_t>>::value)
to->data = arg;
to->data = reinterpret_cast<std::vector<int64_t> const&>(arg);
else if (std::is_same<T, tensor::TensorMap>::value)
to->data = arg;
to->data = reinterpret_cast<tensor::TensorMap const&>(arg);
else
LOG(FATAL) << "Not supported Value copy: " << typeid(T).name();
},
......
......@@ -23,15 +23,19 @@
#include "paddle/infrt/common/object.h"
#include "paddle/infrt/common/shared.h"
#include "paddle/infrt/host_context/function.h"
#include "paddle/infrt/naive/meta_tensor.h"
#include "paddle/infrt/support/variant.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/tensor/dense_tensor_view.h"
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
// Disabled temporarily for failed compile, will enable latter.
// #include "paddle/pten/backends/cpu/cpu_context.h"
// #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/meta_tensor.h"
#ifdef INFRT_WITH_PTEN
#include "paddle/infrt/backends/host/pten_allocator.h"
#include "paddle/infrt/backends/host/pten_context.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#endif
namespace infrt {
namespace host_context {
......@@ -44,14 +48,20 @@ using ValueVariantType = Variant<int16_t,
float,
double,
bool,
uint32_t,
uint64_t,
std::string,
tensor::TensorShape,
tensor::DenseHostTensor,
MlirFunctionExecutable*,
tensor::TensorMap,
// pten::CPUContext,
// pten::DenseTensor,
naive::MetaTensor,
#ifdef INFRT_WITH_PTEN
::pten::MetaTensor,
::pten::DenseTensor,
backends::CpuPtenAllocator,
backends::CpuPtenContext,
::pten::CPUContext,
#endif
std::vector<int16_t>,
std::vector<int32_t>,
std::vector<int64_t>,
......@@ -84,7 +94,13 @@ class Value : public common::Object {
explicit Value(tensor::TensorShape&& x) : data(std::move(x)) {}
explicit Value(tensor::DenseHostTensor&& x) : data(std::move(x)) {}
explicit Value(MlirFunctionExecutable* x) : data(x) {}
explicit Value(naive::MetaTensor&& x) : data(std::move(x)) {}
#ifdef INFRT_WITH_PTEN
explicit Value(backends::CpuPtenContext&& x) : data(std::move(x)) {}
explicit Value(::pten::CPUContext&& x) : data(std::move(x)) {}
explicit Value(::pten::DenseTensor&& x) : data(std::move(x)) {}
explicit Value(::pten::MetaTensor&& x) : data(std::move(x)) {}
explicit Value(backends::CpuPtenAllocator&& x) : data(std::move(x)) {}
#endif
template <typename T>
const T& get() const {
......@@ -142,7 +158,10 @@ class ValueRef : common::Shared<Value> {
explicit ValueRef(float val);
explicit ValueRef(double val);
explicit ValueRef(bool val);
explicit ValueRef(naive::MetaTensor&& val);
explicit ValueRef(::pten::MetaTensor&& val);
explicit ValueRef(backends::CpuPtenContext&& x);
explicit ValueRef(::pten::CPUContext&& x);
explicit ValueRef(::pten::DenseTensor&& x);
using common::Shared<Value>::get;
using common::Shared<Value>::Reset;
......
add_subdirectory(pten)
core_gather_headers()
gather_srcs(infrt_src SRCS
......
if (NOT INFRT_WITH_PTEN)
return()
endif()
core_gather_headers()
gather_srcs(infrt_src SRCS
registry.cc
dense_tensor_kernels.cc
context_kernels.cc
allocator_kernels.cc
)
cc_library(infrt_naive SRCS infershaped/infershaped_kernel_launcher.cc
infershaped/infershaped_kernel_launchers.cc
)
cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/pten/allocator_kernels.h"
namespace infrt {
namespace kernel {
namespace pten {
backends::CpuPtenAllocator CreateCpuAllocator() { return {}; }
} // namespace pten
} // namespace kernel
} // namespace infrt
......@@ -12,20 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/meta_tensor.h"
#pragma once
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/tensor/tensor_shape.h"
#include "paddle/infrt/backends/host/pten_allocator.h"
#include "paddle/pten/core/dense_tensor.h"
namespace infrt {
namespace naive {
namespace kernel {
namespace pten {
const tensor::TensorShape& MetaTensor::shape() const {
return mutable_tensor_->shape();
}
tensor::TensorShape* MetaTensor::mutable_shape() {
return mutable_tensor_->mutable_shape();
}
backends::CpuPtenAllocator CreateCpuAllocator();
} // namespace naive
} // namespace pten
} // namespace kernel
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/pten/context_kernels.h"
namespace infrt {
namespace kernel {
namespace pten {
backends::CpuPtenContext CreateCpuContext() { return {}; }
} // namespace pten
} // namespace kernel
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/infrt/backends/host/pten_context.h"
#include "paddle/pten/core/dense_tensor.h"
namespace infrt {
namespace kernel {
namespace pten {
backends::CpuPtenContext CreateCpuContext();
} // namespace pten
} // namespace kernel
} // namespace infrt
......@@ -12,23 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/elementwise_add.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/kernel/pten/dense_tensor_kernels.h"
namespace infrt {
namespace naive {
namespace kernel {
namespace pten {
using ElementwiseAddLauncher =
KernelLauncher<decltype(&ElementwiseAdd),
&ElementwiseAdd,
decltype(&ElementwiseAddInferShape),
&ElementwiseAddInferShape>;
void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
registry->AddKernel("elementwise_add",
INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
::pten::DenseTensor CreateDenseTensorCpuF32Nchw(
backends::CpuPtenAllocator* allocator,
host_context::Attribute<std::vector<int64_t>> dims,
host_context::Attribute<std::vector<int64_t>> lod) {
return ::pten::DenseTensor(
allocator,
::pten::DenseTensorMeta(::pten::DataType::FLOAT32,
::pten::framework::make_ddim(dims.get()),
::pten::DataLayout::NCHW,
{}));
}
} // namespace naive
void FillDenseTensorF32(::pten::DenseTensor* dense_tensor,
host_context::Attribute<std::vector<int64_t>> values) {}
} // namespace pten
} // namespace kernel
} // namespace infrt
......@@ -12,29 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/pten_kernels.h"
#pragma once
#include <iostream>
#include <string>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/backends/host/pten_allocator.h"
#include "paddle/infrt/host_context/kernel_utils.h"
// Disable temporarily.
// #include "paddle/pten/backends/cpu/cpu_context.h"
// #include "paddle/pten/kernels/math_kernel.h"
using infrt::host_context::Attribute;
#include "paddle/pten/core/dense_tensor.h"
namespace infrt {
namespace kernel {
namespace pten {
::pten::DenseTensor CreateDenseTensorCpuF32Nchw(
backends::CpuPtenAllocator* allocator,
host_context::Attribute<std::vector<int64_t>> dims,
host_context::Attribute<std::vector<int64_t>> lod);
void RegisterPtenKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("pd_cpu.add.float32",
INFRT_KERNEL(pten::AddKernel<float, pten::CPUContext>));
registry->AddKernel("pd_cpu.add.int32",
INFRT_KERNEL(pten::AddKernel<int, pten::CPUContext>));
}
void FillDenseTensorF32(::pten::DenseTensor* dense_tensor,
host_context::Attribute<std::vector<int64_t>> values);
} // namespace pten
} // namespace kernel
} // namespace infrt
......@@ -16,27 +16,23 @@
#include <llvm/ADT/SmallVector.h>
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_utils.h"
// This file contains a example of the infershape ElementwiseAdd kernel.
// Some of the following code should be generated from PTEN by script.
namespace infrt {
namespace naive {
namespace kernel {
static void ElementwiseAddInferShape(const MetaTensor& a,
const MetaTensor& b,
MetaTensor* c) {
CHECK(a.shape() == b.shape())
<< "ElementwiseAdd, but shapes of a b are not match";
*c->mutable_shape() = a.shape();
}
static void ElementwiseAddInferShape(const ::pten::MetaTensor& a,
const ::pten::MetaTensor& b,
::pten::MetaTensor* c) {}
static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c) {}
static void ElementwiseAdd(const ::pten::CPUContext& /*Context*/,
const ::pten::DenseTensor& a,
const ::pten::DenseTensor& b,
::pten::DenseTensor* c) {}
template <typename KernelFunc,
KernelFunc kernel,
......@@ -64,5 +60,15 @@ class KernelLauncher : public InferShapedKernelLauncher {
}
};
} // namespace naive
template <typename KernelFunc,
KernelFunc kernel,
typename InferShapedFunc,
InferShapedFunc infershape>
void KernelLauncherFunc(
KernelLauncher<KernelFunc, kernel, InferShapedFunc, infershape> launcher,
host_context::KernelFrame* frame) {
launcher.Invoke(frame);
}
} // namespace kernel
} // namespace infrt
......@@ -14,19 +14,17 @@
#include <gtest/gtest.h>
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_utils.h"
namespace infrt {
namespace naive {
namespace kernel {
namespace {
static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c);
static void ElementwiseAddTest(const ::pten::DenseTensor& a,
const ::pten::DenseTensor& b,
::pten::DenseTensor* c);
}
TEST(utils, registry) {
......@@ -35,26 +33,24 @@ TEST(utils, registry) {
CHECK_EQ(count, 2U);
}
TEST(ElementwiseAdd, registry) {
InferShapedKernelRegistry registry;
TEST(ElementwiseAdd, launcher_registry) {
host_context::KernelRegistry registry;
RegisterInferShapeLaunchers(&registry);
ASSERT_EQ(registry.size(), 1UL);
auto creator = registry.GetKernel("elementwise_add");
auto infershape_launcher_handle = creator();
// fake some tensors
tensor::DenseHostTensor a({2, 8}, GetDType<float>());
tensor::DenseHostTensor b({2, 8}, GetDType<float>());
tensor::DenseHostTensor c({2, 8}, GetDType<float>());
::pten::CPUContext ctx{};
::pten::DenseTensor a{};
::pten::DenseTensor b{};
::pten::DenseTensor c{};
host_context::KernelFrameBuilder kernel_frame_builder;
kernel_frame_builder.AddArgument(new host_context::Value(0));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(ctx)));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
infershape_launcher_handle->Invoke(&kernel_frame_builder);
creator(&kernel_frame_builder);
}
} // namespace naive
} // namespace kernel
} // namespace infrt
......@@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.h"
namespace infrt {
namespace naive {
namespace kernel {
void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
host_context::KernelFrame* frame) {
for (host_context::Value* value :
frame->GetValues(1, frame->GetNumElements() - 1)) {
// TODO(Superjomn) To extend this.
if (value->is_type<tensor::DenseHostTensor>()) {
values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
if (value->is_type<::pten::DenseTensor>()) {
values.emplace_back(
::pten::MetaTensor{&value->get<::pten::DenseTensor>()});
infershape_kernel_frame_builder.AddArgument(values.back().get());
} else {
infershape_kernel_frame_builder.AddArgument(value);
......@@ -35,8 +36,9 @@ void InferShapedKernelLauncher::BuildInferShapeCache(
const uint16_t num_inputs) {
tensor_shape_cache.resize(num_inputs);
for (uint16_t i = 0; i < num_inputs; i++) {
tensor_shape_cache[i] =
infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
tensor_shape_cache[i] = infershape_kernel_frame_builder.GetArgAt(i)
->get<::pten::MetaTensor>()
.dims();
}
}
......@@ -49,10 +51,11 @@ bool InferShapedKernelLauncher::IsShapeChanged(
for (uint16_t i = 0; i < num_inputs && !changed; i++) {
changed = changed ||
(tensor_shape_cache[i] !=
infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
infershape_kernel_frame_builder.GetArgAt<::pten::MetaTensor>(i)
.dims());
}
return changed;
}
} // namespace naive
} // namespace kernel
} // namespace infrt
......@@ -17,11 +17,9 @@
#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/naive/meta_tensor.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
namespace kernel {
struct InferShapedKernelLauncher {
virtual void Invoke(host_context::KernelFrame* frame) = 0;
......@@ -46,9 +44,9 @@ struct InferShapedKernelLauncher {
// values to hold the TensorMeta.
llvm::SmallVector<host_context::ValueRef, 3> values;
llvm::SmallVector<tensor::TensorShape, 3> tensor_shape_cache;
llvm::SmallVector<::pten::DDim, 3> tensor_shape_cache;
host_context::KernelFrameBuilder infershape_kernel_frame_builder;
};
} // namespace naive
} // namespace kernel
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/pten/infershaped/elementwise_add.h"
namespace infrt {
namespace kernel {
void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {
registry->AddKernel(
"elementwise_add",
std::bind(&KernelLauncherFunc<decltype(&ElementwiseAdd),
&ElementwiseAdd,
decltype(&ElementwiseAddInferShape),
&ElementwiseAddInferShape>,
KernelLauncher<decltype(&ElementwiseAdd),
&ElementwiseAdd,
decltype(&ElementwiseAddInferShape),
&ElementwiseAddInferShape>(),
std::placeholders::_1));
}
} // namespace kernel
} // namespace infrt
......@@ -14,12 +14,12 @@
#pragma once
namespace infrt {
namespace naive {
#include "paddle/infrt/host_context/kernel_registry.h"
struct InferShapedKernelRegistry;
namespace infrt {
namespace kernel {
void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry);
void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry);
} // namespace naive
} // namespace kernel
} // namespace infrt
......@@ -18,10 +18,10 @@
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace naive {
namespace kernel {
namespace infershaped {
using KeyType = const tensor::DenseHostTensor&;
using KeyType = const ::pten::DenseTensor&;
using CountType = uint8_t;
constexpr CountType value(std::true_type) { return 1; }
......@@ -73,5 +73,5 @@ struct InferShapeHelper<Return (*)(Args...)> {
static constexpr int count = infershaped::count<Args...>();
};
} // namespace naive
} // namespace kernel
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/pten/registry.h"
#include <iostream>
#include <string>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/pten/allocator_kernels.h"
#include "paddle/infrt/kernel/pten/context_kernels.h"
#include "paddle/infrt/kernel/pten/dense_tensor_kernels.h"
#include "paddle/infrt/kernel/pten/infershaped/elementwise_add.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/include/kernels.h"
#include "paddle/pten/kernels/matmul_kernel.h"
using infrt::host_context::Attribute;
namespace infrt {
namespace kernel {
void RegisterPtenKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("pten_dt.create_allocator.cpu",
INFRT_KERNEL(infrt::kernel::pten::CreateCpuAllocator));
registry->AddKernel("pten_dt.create_context.cpu",
INFRT_KERNEL(infrt::kernel::pten::CreateCpuContext));
registry->AddKernel(
"pten_dt.create_dense_tensor.cpu.f32.nchw",
INFRT_KERNEL(infrt::kernel::pten::CreateDenseTensorCpuF32Nchw));
registry->AddKernel("pten_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::pten::FillDenseTensorF32));
registry->AddKernel(
"pten.matmul.host.fp32",
std::bind(&kernel::KernelLauncherFunc<
decltype(&::pten::MatmulKernel<float, ::pten::CPUContext>),
&::pten::MatmulKernel<float, ::pten::CPUContext>,
decltype(&::pten::MatmulInferMeta),
&::pten::MatmulInferMeta>,
kernel::KernelLauncher<
decltype(&::pten::MatmulKernel<float, ::pten::CPUContext>),
&::pten::MatmulKernel<float, ::pten::CPUContext>,
decltype(&::pten::MatmulInferMeta),
&::pten::MatmulInferMeta>(),
std::placeholders::_1));
}
} // namespace kernel
} // namespace infrt
cc_library(infrt_naive SRCS meta_tensor.cc
infershaped/infershaped_kernel_launcher.cc
infershaped/infershaped_registry.cc
infershaped/infershaped_kernel_launchers.cc
)
cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include <unordered_map>
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
namespace infrt {
namespace naive {
struct InferShapedKernelRegistry::Impl {
std::unordered_map<std::string, InferShapeLauncherCreator> data;
};
InferShapedKernelRegistry::InferShapedKernelRegistry()
: impl_(std::make_unique<Impl>()) {}
void InferShapedKernelRegistry::AddKernel(
const std::string& key,
InferShapedKernelRegistry::InferShapeLauncherCreator&& creator) {
CHECK(!impl_->data.count(key)) << "Item called " << key << " duplicates";
impl_->data.emplace(key, std::move(creator));
}
const InferShapedKernelRegistry::InferShapeLauncherCreator&
InferShapedKernelRegistry::GetKernel(const std::string& key) const {
auto it = impl_->data.find(key);
CHECK(it != impl_->data.end()) << "No item called " << key << " exists";
return it->second;
}
size_t InferShapedKernelRegistry::size() const { return impl_->data.size(); }
InferShapedKernelRegistry* GetInferShapeRegistry() {
static auto registry = std::make_unique<InferShapedKernelRegistry>();
return registry.get();
}
InferShapedKernelRegistry::~InferShapedKernelRegistry() {}
} // namespace naive
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
namespace infrt {
namespace naive {
struct InferShapedKernelLauncher;
class InferShapedKernelRegistry {
public:
using InferShapeLauncherHandle = std::unique_ptr<InferShapedKernelLauncher>;
using InferShapeLauncherCreator = std::function<InferShapeLauncherHandle()>;
InferShapedKernelRegistry();
void AddKernel(const std::string& key, InferShapeLauncherCreator&& creator);
const InferShapeLauncherCreator& GetKernel(const std::string& key) const;
size_t size() const;
~InferShapedKernelRegistry();
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
//! The global infershape registry.
InferShapedKernelRegistry* GetInferShapeRegistry();
} // namespace naive
} // namespace infrt
#define INFERSHAPED_KERNEL_CREATOR(infershape_launcher_class_) \
[]() \
-> ::infrt::naive::InferShapedKernelRegistry::InferShapeLauncherHandle { \
return std::make_unique<infershape_launcher_class_>(); \
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// A naive implementation of MetaTensor
#pragma once
#include "paddle/infrt/common/common.h"
namespace infrt {
namespace tensor {
struct DenseHostTensor;
struct TensorShape;
} // namespace tensor
namespace naive {
class MetaTensor {
public:
MetaTensor() = default;
explicit MetaTensor(tensor::DenseHostTensor* tensor)
: mutable_tensor_(tensor) {}
explicit MetaTensor(const tensor::DenseHostTensor* tensor)
: mutable_tensor_(&Reference(tensor)) {}
explicit MetaTensor(MetaTensor&& other)
: mutable_tensor_(other.mutable_tensor_) {}
explicit MetaTensor(const MetaTensor& other)
: mutable_tensor_(other.mutable_tensor_) {}
const tensor::TensorShape& shape() const;
tensor::TensorShape* mutable_shape();
private:
tensor::DenseHostTensor* mutable_tensor_{};
};
} // namespace naive
} // namespace infrt
// RUN: infrtopt %s | FileCheck %s
// CHECK-LABEL: basic_tensor
// CHECK-LABEL: @basic_tensor
func @basic_tensor() {
%a = "pten_dt.create_uninit_tensor.f32" () { shape=[12:i64, 23:i64] } : () -> !infrt.tensor<X86, NCHW, F32>
%b = "pten_dt.create_inited_tensor.f32" () { shape=[2:i64, 2:i64], values=[0.1:f32, 0.2:f32, 0.3:f32, 0.4:f32] } : () -> !infrt.tensor<X86, NCHW, F32>
"pten_dt.fill_tensor_with_constant.f32" (%a) { value=0.1:f32 } : (!infrt.tensor<X86, NCHW, F32>) -> ()
%a = "pten_dt.create_allocator.cpu" (): () -> !pten.CPU_allocator
%b = "pten_dt.create_context.cpu" (): () -> !pten.CPU_context
%c = "pten_dt.create_dense_tensor.cpu.f32.nchw" (%a) {dims=[1:i64], lod=[1:i64]}: (!pten.CPU_allocator) -> (!infrt.tensor<X86, NCHW, F32>)
// "pten_dt.fill_dense_tensor.f32" (%c) {value=[1.0:f32]} : (!infrt.tensor<X86, NCHW, F32>) -> ()
infrt.return
}
......@@ -58,6 +58,10 @@ CPUContext::CPUContext(const Place& place)
CPUContext::~CPUContext() = default;
CPUContext::CPUContext(CPUContext&&) = default;
CPUContext& CPUContext::operator=(CPUContext&&) = default;
void CPUContext::Init() { impl_->Init(); }
Eigen::DefaultDevice* CPUContext::eigen_device() const {
......
......@@ -27,6 +27,8 @@ namespace pten {
class CPUContext : public DeviceContext {
public:
CPUContext();
CPUContext(CPUContext&&);
CPUContext& operator=(CPUContext&&);
explicit CPUContext(const Place&);
virtual ~CPUContext();
Eigen::DefaultDevice* eigen_device() const;
......
......@@ -149,6 +149,8 @@ DeviceContext::DeviceContext(DeviceContext&& other) {
impl_ = std::move(other.impl_);
}
DeviceContext& DeviceContext::operator=(DeviceContext&&) = default;
DeviceContext::~DeviceContext() = default;
void DeviceContext::SetAllocator(const Allocator* allocator) {
......
......@@ -49,6 +49,11 @@ class DeviceContext {
*/
DeviceContext(DeviceContext&&);
/**
* @brief Move assign operator.
*/
DeviceContext& operator=(DeviceContext&&);
/**
* @brief Default destruct.
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册