Unverified commit 93f7a02a authored by hong, committed by GitHub

Support tensor attribute runtime (#54692)

* add kernel dialect

* change DenseTensorTypeStorage to DenseTensorType

* add test case

* add first pd_op to kernel dialect

* lower pd op to kernel dialect

* update

* update

* remove useless code

* add attribute print test

* fix bug

* update

* update

* update

* update

* polish code

* fix bug

* polish code and add python test

* add test

* fix test error

* add env flag

* fix bug

* revert test env

* change cc_test_old to cc_test

* fix build_static bug

* fix type test error

* update cmake

* disable test in windows

* fix inference compile

* update

* support tensor attribute runtime

* add result check

* polish test code

* fix test error

* add scalar test & polish code

* re-open test case
Parent 752670e2
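
In short, the change defers reading "mutable attribute" tensors: when an op's Scalar/IntArray attribute is produced by another op at runtime, BuildInferMetaContext/BuildPhiKernelContext no longer read the DenseTensor's data while the program is lowered. They wrap a pointer to it in phi::TensorRef, store that as the phi::Attribute, and the call helpers materialize the Scalar/IntArray only when the op actually runs. A minimal sketch of that flow (helper names and the attribute header path are assumptions, not code from this PR):

// Sketch only; assumes the headers below exist at these paths.
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/tensor_ref.h"
#include "paddle/phi/core/attribute.h"     // defines phi::Attribute (assumed path)
#include "paddle/phi/core/dense_tensor.h"

// Build time: keep only a pointer; the tensor may not hold data yet.
phi::Attribute DeferIntArray(const phi::DenseTensor* shape_tensor) {
  return phi::TensorRef(shape_tensor);
}

// Run time: if the attribute holds a TensorRef, read the tensor now;
// otherwise it already is a plain IntArray.
phi::IntArray ResolveIntArray(const phi::Attribute& attr) {
  static phi::Attribute cmp_t = phi::TensorRef(nullptr);
  if (cmp_t.index() == attr.index()) {
    return phi::IntArray(*paddle::get<phi::TensorRef>(attr).Get());
  }
  return paddle::get<phi::IntArray>(attr);
}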
......@@ -965,6 +965,8 @@ void BuildOpFuncList(
op_func_node.infer_shape_interface_ =
op_info.GetInterfaceImpl<paddle::dialect::InferShapeInterface>();
VLOG(6) << "op name" << op_func_node.phi_op_name_;
::ir::BuildInferMetaContext((*it),
value_2_name_map,
scope,
......@@ -976,6 +978,8 @@ void BuildOpFuncList(
auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
VLOG(6) << "finish process infer meta context";
auto t1 =
phi::KernelFactory::Instance().SelectKernel(kernel_name, kernel_key);
op_func_node.phi_kernel_ = new phi::Kernel(t1);
......@@ -992,6 +996,7 @@ void BuildOpFuncList(
&(op_func_node.input_index),
&(op_func_node.output_index));
VLOG(6) << "finish process kernel context";
op_func_node.kernel_context_.SetDeviceContext(
phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())));
......
......@@ -293,8 +293,10 @@ paddle::framework::FetchList InterpreterCore::Run(
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
if (FLAGS_enable_new_ir_in_executor) {
VLOG(6) << "begin to build scope";
::ir::BuildScope(
ir_program_->block(), local_scope_, &value_2_var_name_map_);
VLOG(6) << "build ccope finshed";
} else {
interpreter::BuildVariableScope(block_, execution_config_, &var_scope_);
}
......
......@@ -43,9 +43,14 @@ phi::KernelKey GetKernelKey(
// only support non-vector input for now
std::map<std::string, int> input_map;
int index = 0;
int tensor_input_number = 0;
for (auto& t : input_info) {
// todo filter attribute tensor
input_map[t.name] = index++;
if (!t.is_mutable_attribute) {
tensor_input_number += 1;
}
}
std::map<std::string, std::string> attr_type_map;
......@@ -70,12 +75,12 @@ phi::KernelKey GetKernelKey(
// parse from input
int in_index = input_map.at(slot_name);
dialect::AllocatedDenseTensorType type =
dialect::DenseTensorType type =
op->operand(in_index)
.source()
.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
kernel_data_type = type.dyn_cast<dialect::DataTypeAttribute>().data();
.dyn_cast<paddle::dialect::DenseTensorType>();
kernel_data_type = TransToPhiDataType(type.dtype());
} else {
PADDLE_ENFORCE_EQ(
attr_type_map.count(slot_name),
......@@ -89,7 +94,7 @@ phi::KernelKey GetKernelKey(
// parse all the input tensor
if (input_map.size() == 0 || op->name() == "pd.full_") {
if (tensor_input_number == 0 || op->name() == "pd.full_") {
// all the information has to come from attributes and context
kernel_backend = paddle::experimental::ParseBackend(place);
......@@ -98,6 +103,9 @@ phi::KernelKey GetKernelKey(
for (size_t i = 0; i < input_info.size(); ++i) {
// todo filter attribute tensor
if (input_info[i].is_mutable_attribute) {
continue;
}
auto input_tmp = op->operand(i).source();
auto new_input_tmp = map_value_pair.at(input_tmp);
dialect::AllocatedDenseTensorType type =
......@@ -154,6 +162,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
for (auto it = block->begin(); it != block->end(); ++it) {
VLOG(6) << "op name " << (*it)->name();
auto kernel_key = GetKernelKey(*it, cpu_place, map_value_pair);
// create new Op
......@@ -163,6 +172,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
std::vector<ir::Type> op_output_types;
if ((*it)->num_results() > 0) {
// filter tensor attribute
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
......@@ -172,7 +182,8 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
}
// construct input
std::vector<ir::OpResult> vec_inputs;
if ((*it)->name() != "pd.full_" && (*it)->num_operands() > 0) {
if ((*it)->name() != "pd.full" && (*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i).source();
auto new_in = map_value_pair.at(cur_in);
......
......@@ -148,12 +148,15 @@ void BuildInferMetaContext(
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(phi::IntArray(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(phi::Scalar(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
......@@ -177,6 +180,10 @@ void BuildInferMetaContext(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
......@@ -247,12 +254,15 @@ void BuildPhiKernelContext(
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(phi::IntArray(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(phi::Scalar(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
......@@ -282,6 +292,10 @@ void BuildPhiKernelContext(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <limits>
#include <sstream>
#include <vector>
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
class TensorBase;
// In static-model pre-analysis, we can't get the data from the tensor
class TensorRef {
public:
// Default constructor holds nullptr; the pointer constructor is explicit
TensorRef() = default;
explicit TensorRef(const DenseTensor* base) : tensor_base_(base) {}
const DenseTensor* Get() const {
PADDLE_ENFORCE_NOT_NULL(tensor_base_,
"Cannot get a null pointer from TensorRef");
return tensor_base_;
}
private:
const DenseTensor* tensor_base_{nullptr};
};
} // namespace phi
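
A quick usage sketch for the new wrapper (hypothetical variable names; assumes dense_tensor.h is included): TensorRef is non-owning, so the referenced DenseTensor must stay alive in the Scope, and Get() refuses to hand back a null pointer.

phi::DenseTensor shape_tensor;          // e.g. the output of full_int_array
phi::TensorRef ref(&shape_tensor);      // non-owning, cheap to copy
const phi::DenseTensor* t = ref.Get();  // OK: pointer is non-null

phi::TensorRef empty;                   // default-constructed, holds nullptr
// empty.Get();                         // would trip PADDLE_ENFORCE_NOT_NULL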
......@@ -21,6 +21,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/tensor_ref.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/variant.h"
......@@ -46,7 +47,8 @@ using Attribute = paddle::variant<bool,
IntArray,
DataType,
DataLayout,
Place>;
Place,
TensorRef>;
using AttributeMap = paddle::flat_hash_map<std::string, Attribute>;
......
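A consequence of adding TensorRef to the variant (an aside, not text from the PR): a plain paddle::get<Scalar>(attr) throws when the attribute actually holds a TensorRef, which is why the new call helpers compare index() before extracting. Illustrative snippet:

phi::Attribute attr = phi::TensorRef(nullptr);
// paddle::get<phi::Scalar>(attr);      // would throw: wrong alternative
bool deferred =
    attr.index() == phi::Attribute(phi::TensorRef(nullptr)).index();  // true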
......@@ -135,6 +135,10 @@ const AttrType& InferMetaContext::AttrAt(size_t idx) const {
}
}
const Attribute& InferMetaContext::AttrAt(size_t idx) const {
return attrs_.at(idx);
}
template const bool& InferMetaContext::AttrAt(size_t idx) const;
template const int& InferMetaContext::AttrAt(size_t idx) const;
template const int64_t& InferMetaContext::AttrAt(size_t idx) const;
......@@ -154,6 +158,7 @@ template const IntArray& InferMetaContext::AttrAt(size_t idx) const;
template const DataType& InferMetaContext::AttrAt(size_t idx) const;
template const DataLayout& InferMetaContext::AttrAt(size_t idx) const;
template const Place& InferMetaContext::AttrAt(size_t idx) const;
template const TensorRef& InferMetaContext::AttrAt(size_t idx) const;
MetaFnFactory& MetaFnFactory::Instance() {
static MetaFnFactory g_meta_fn_map;
......
......@@ -63,6 +63,8 @@ class InferMetaContext {
template <typename AttrType>
const AttrType& AttrAt(size_t idx) const;
const Attribute& AttrAt(size_t idx) const;
const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
......@@ -116,6 +118,29 @@ class InferMetaContext {
} \
}
#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR_INTARRAY( \
attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<const attr_type&, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
const Attribute& t = ctx->AttrAt(attr_idx); \
static Attribute cmp_t = phi::TensorRef(nullptr); \
attr_type attr1; \
if (cmp_t.index() == t.index()) { \
attr1 = attr_type((*paddle::get<phi::TensorRef>(t).Get())); \
} else { \
attr1 = paddle::get<attr_type>(t); \
} \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
attr1); \
} \
}
template <typename T>
struct InferMetaTypeTag {};
......@@ -197,8 +222,8 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR_INTARRAY(Scalar);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR_INTARRAY(IntArray);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<bool>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
......
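Putting the two sides together, a small end-to-end sketch (variable names are illustrative; assumes the InferMetaContext API shown in this diff): the deferred attribute travels through the context unchanged and is only materialized inside the specialized call helper.

phi::DenseTensor axis_tensor;   // filled by another kernel at run time
phi::InferMetaContext meta_ctx;
meta_ctx.EmplaceBackAttr(phi::Attribute(phi::TensorRef(&axis_tensor)));

// Later, inside InferMetaFnCallHelper<const Scalar&, ...>::Call:
const phi::Attribute& t = meta_ctx.AttrAt(0);   // new non-template overload
phi::Scalar axis(*paddle::get<phi::TensorRef>(t).Get());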
......@@ -123,6 +123,10 @@ const AttrType& KernelContext::AttrAt(size_t idx) const {
}
}
const Attribute& KernelContext::AttrAt(size_t idx) const {
return attrs_.at(idx);
}
template const bool& KernelContext::AttrAt(size_t idx) const;
template const int& KernelContext::AttrAt(size_t idx) const;
template const int64_t& KernelContext::AttrAt(size_t idx) const;
......@@ -142,5 +146,6 @@ template const IntArray& KernelContext::AttrAt(size_t idx) const;
template const DataType& KernelContext::AttrAt(size_t idx) const;
template const DataLayout& KernelContext::AttrAt(size_t idx) const;
template const Place& KernelContext::AttrAt(size_t idx) const;
template const TensorRef& KernelContext::AttrAt(size_t idx) const;
} // namespace phi
......@@ -139,6 +139,7 @@ class KernelContext {
template <typename AttrType>
const AttrType& AttrAt(size_t idx) const;
const Attribute& AttrAt(size_t idx) const;
size_t InputsSize() const { return inputs_.size(); }
size_t OutputsSize() const { return outputs_.size(); }
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/tensor_ref.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/extended_tensor.h"
......@@ -220,6 +221,31 @@ namespace phi {
} \
}
#define PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR_INTARRAY(attr_type) \
template <typename... Tail> \
struct KernelCallHelper<const attr_type&, Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"Kernel's Attributes should appear before Outputs."); \
const Attribute& t = ctx->AttrAt(attr_idx); \
static Attribute cmp_t = phi::TensorRef(nullptr); \
attr_type attr1; \
if (cmp_t.index() == t.index()) { \
attr1 = attr_type(*paddle::get<phi::TensorRef>(t).Get()); \
} else { \
attr1 = paddle::get<attr_type>(t); \
} \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>( \
ctx, pargs..., attr1); \
} \
}
template <typename T>
struct TypeTag {};
......@@ -299,8 +325,8 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR_INTARRAY(Scalar);
PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR_INTARRAY(IntArray);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<bool>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int64_t>);
......
......@@ -29,6 +29,8 @@
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/platform/init_phi.h"
DECLARE_FILE_SYMBOLS(kernel_dialect);
......@@ -38,12 +40,12 @@ PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
namespace paddle {
namespace framework {
TEST(StandaloneExecutor, run) {
std::cerr << "here" << std::endl;
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
......@@ -59,8 +61,6 @@ TEST(StandaloneExecutor, run) {
builder.Build<paddle::dialect::AddOp>(op1->result(0), op2->result(0));
program.Print(std::cout);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
kernel_program->Print(std::cout);
......@@ -74,9 +74,90 @@ TEST(StandaloneExecutor, run) {
test_core.Run({});
auto tensor = scope.Var("inner_var_2")->Get<phi::DenseTensor>();
auto out_tensor = scope.Var("inner_var_2")->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
TEST(StandaloneExecutor, run_2) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder(ctx, program.block());
ir::Block* block = program.block();
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->result(0), uniform2->result(0));
EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
paddle::dialect::ScaleOp scale =
builder.Build<paddle::dialect::ScaleOp>(add->result(0), 1.0, 0.0, true);
EXPECT_EQ(scale->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
auto place = platform::CPUPlace();
Scope scope;
ProgramDesc prog_desc;
InterpreterCore test_core(
place, prog_desc.Block(0), &scope, std::move(kernel_program));
test_core.Run({});
std::cerr << "uot" << tensor << std::endl;
auto out_tensor = scope.Var("inner_var_10")->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.80721);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 1.70047);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 1.56764);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 1.85063);
std::cerr << out_tensor.data<float>()[0] << "\t"
<< out_tensor.data<float>()[1] << "\t"
<< out_tensor.data<float>()[2] << "\t"
<< out_tensor.data<float>()[3] << std::endl;
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
} // namespace framework
......