diff --git a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc index 13cba6eeabb669cf93deb9a37d87d2ddff66e5c0..18d40ce57649da79cbefa6c7c81cb54da1f226da 100644 --- a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc +++ b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc @@ -97,12 +97,12 @@ void PhiOpConvertPass::convertStage() { } auto loc = getFunction().getLoc(); builder.setInsertionPoint(op); - if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_name)) { - std::string kernel_name = phi::TransToPhiKernelName(op_name); + op_name = phi::TransToPhiKernelName(op_name); + if (!::phi::OpUtilsMap::Instance().Contains(op_name)) { auto kernel_op = builder.create(loc, op->getResultTypes(), op->getOperands(), - kernel_name, + op_name, op->getAttrDictionary()); op->replaceAllUsesWith(kernel_op.getResults()); } else { diff --git a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc index 1cd5b5a85511fe20e8029185caf4c93d95979b72..070867853ad3e427f62c825727de2d15f0442c96 100644 --- a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc +++ b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc @@ -32,17 +32,24 @@ bool ProtoArgumentMappingContext::HasOutput(const std::string& name) const { } bool ProtoArgumentMappingContext::HasAttr(const std::string& name) const { + if (name == "is_test") return true; return op_->hasAttr(name); } paddle::any ProtoArgumentMappingContext::Attr(const std::string& name) const { - mlir::Attribute attrs = op_->getAttr(name); - if (mlir::StringAttr str_attr = attrs.dyn_cast_or_null()) { + if (name == "is_test") { + return paddle::any(true); + } + mlir::Attribute attr = op_->getAttr(name); + if (!attr) { + return paddle::any(); + } + if (mlir::StringAttr str_attr = attr.dyn_cast()) { return paddle::any(str_attr.str()); - } else { - // ToDO: implementation in the ext PR. - return paddle::any(0); } + + // ToDO: implementation in the ext PR. + return paddle::any(0); } size_t ProtoArgumentMappingContext::InputSize(const std::string& name) const { diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h index 5b92d183b79da21cf9552e8a2f238928962f5832..b0f56f020f4866053d99b01cdc721dfd9f295a36 100644 --- a/paddle/infrt/host_context/value.h +++ b/paddle/infrt/host_context/value.h @@ -147,6 +147,7 @@ class Value : public common::Object { #endif explicit Value(::phi::DenseTensor&& x) : data(std::move(x)) {} explicit Value(::phi::MetaTensor&& x) : data(std::move(x)) {} + explicit Value(::phi::MetaConfig&& x) : data(std::move(x)) {} #ifdef INFRT_WITH_TRT explicit Value(::infrt::backends::tensorrt::TrtEngine&& x) : data(std::move(x)) {} diff --git a/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.cc b/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.cc index 75e3ebbf00ca54ed3fb2d0ca22bb7819300d0b2b..2e40261f27386717deee886494ef047c2f7166d7 100644 --- a/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.cc +++ b/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.cc @@ -14,6 +14,7 @@ #include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/meta_tensor.h" namespace infrt { namespace kernel { @@ -31,6 +32,10 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape( infershape_kernel_frame_builder.AddArgument(value); } } + if (infershape_kernel_frame_builder.GetNumArgs() < arg_size_) { + infershape_kernel_frame_builder.AddArgument( + new host_context::Value(::phi::MetaConfig())); + } } void InferShapedKernelLauncher::BuildInferShapeCache( diff --git a/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h b/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h index 380b45ea5be09903a7d48e436bb9cc8122df7959..770078115321bd1981a3958ac3c63a0e4dc9bdd3 100644 --- a/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h +++ b/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h @@ -22,11 +22,8 @@ namespace infrt { namespace kernel { struct InferShapedKernelLauncher { - virtual void Invoke(host_context::KernelFrame* frame) = 0; - - virtual ~InferShapedKernelLauncher() = default; - - protected: + explicit InferShapedKernelLauncher(int arg_size) : arg_size_(arg_size) {} + ~InferShapedKernelLauncher() = default; //! Initialize the kernel frame for InferShape kernel. // This method will create a new KernelFrame with all the Tensors(currently // only DenseHostTensor) converted into MetaTensors so that the infer-shape @@ -46,6 +43,7 @@ struct InferShapedKernelLauncher { llvm::SmallVector values; llvm::SmallVector<::phi::DDim, 3> tensor_shape_cache; host_context::KernelFrameBuilder infershape_kernel_frame_builder; + const int arg_size_; }; } // namespace kernel diff --git a/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h b/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h index 75c9e554778dcf1488289c6e9e46fb9652f677dd..2dab7f2324d756967c891a214d1c11d186a2b8e9 100644 --- a/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h +++ b/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h @@ -24,46 +24,44 @@ namespace infrt { namespace kernel { +template +struct FuncArgStatics {}; + +template +struct FuncArgStatics { + constexpr static int arg_size = sizeof...(Args); +}; + template -class KernelLauncher : public InferShapedKernelLauncher { - public: +void KernelLauncherFunc(host_context::KernelFrame* frame) { + static InferShapedKernelLauncher launcher( + FuncArgStatics::arg_size); static const uint16_t num_input_tensors{InferShapeHelper::count}; static const bool turn_on_infer_shape_cache{true}; - void Invoke(host_context::KernelFrame* frame) override { + #ifndef NDEBUG - LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes(); + LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes(); #endif - // Build the infershape KernelFrame if needed. - // TODO(Superjomn) add unlikely here. - if (infershape_kernel_frame_builder.IsEmpty()) { - CreateKernelFrameForInferShape(frame); + // Build the infershape KernelFrame if needed. + // TODO(Superjomn) add unlikely here. + if (launcher.infershape_kernel_frame_builder.IsEmpty()) { + launcher.CreateKernelFrameForInferShape(frame); #ifndef NDEBUG - LOG(INFO) << "infershape.frame: " - << infershape_kernel_frame_builder.DumpArgTypes(); + LOG(INFO) << "infershape.frame: " + << launcher.infershape_kernel_frame_builder.DumpArgTypes(); #endif + } + if (turn_on_infer_shape_cache) { + if (launcher.IsShapeChanged(num_input_tensors)) { + ::infrt::host_context::KernelImpl::Invoke( + &launcher.infershape_kernel_frame_builder); + launcher.BuildInferShapeCache(num_input_tensors); } - if (turn_on_infer_shape_cache) { - if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) { - ::infrt::host_context::KernelImpl::Invoke( - &infershape_kernel_frame_builder); - BuildInferShapeCache(num_input_tensors); - } - } - ::infrt::host_context::KernelImpl::Invoke(frame); } -}; - -template -void KernelLauncherFunc( - KernelLauncher launcher, - host_context::KernelFrame* frame) { - launcher.Invoke(frame); + ::infrt::host_context::KernelImpl::Invoke(frame); } } // namespace kernel diff --git a/paddle/infrt/tests/dialect/phi/phi_test.mlir b/paddle/infrt/tests/dialect/phi/phi_test.mlir index 21ee8ebf0b705894446192b0d5d0bfeb9f10f326..4dda2b7a79d30513c8c80168990f50f71d83e2cc 100644 --- a/paddle/infrt/tests/dialect/phi/phi_test.mlir +++ b/paddle/infrt/tests/dialect/phi/phi_test.mlir @@ -1,14 +1,27 @@ // RUN: infrtexec -i %s module { - func @predict(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + func @predict(%arg0: !infrt.dense_tensor, %arg1: !infrt.dense_tensor, %arg2: !infrt.dense_tensor, %arg3: !infrt.dense_tensor, %arg4: !infrt.dense_tensor) -> !infrt.dense_tensor { %2 = "pd.abs"(%arg0) : (!infrt.dense_tensor) -> !infrt.dense_tensor - infrt.return %2 : !infrt.dense_tensor + %3 = "pd.matmul_v2"(%arg0, %2) {trans_x = false, trans_y = false} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %Y, %MeanOut, %VarianceOut = "pd.batch_norm"(%3, %arg1, %arg2, %arg3, %arg4) {data_layout = "NCHW", epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) + infrt.return %Y : !infrt.dense_tensor } func @main() { %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context - %t = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) + %t = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[1:i64, 3:i64, 8:i64, 8:i64]}: (!phi.context) -> (!infrt.dense_tensor) "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor) -> () - %2 = infrt.call@predict(%t) : (!infrt.dense_tensor) -> !infrt.dense_tensor + %bias = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[3:i64]}: (!phi.context) -> (!infrt.dense_tensor) + "phi_dt.fill_dense_tensor.f32"(%bias) {value=[1.5:f32]} : (!infrt.dense_tensor) -> () + %mean = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[3:i64]}: (!phi.context) -> (!infrt.dense_tensor) + "phi_dt.fill_dense_tensor.f32"(%mean) {value=[3.5:f32]} : (!infrt.dense_tensor) -> () + %scale = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[3:i64]}: (!phi.context) -> (!infrt.dense_tensor) + "phi_dt.fill_dense_tensor.f32"(%scale) {value=[1.0:f32]} : (!infrt.dense_tensor) -> () + %var = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[3:i64]}: (!phi.context) -> (!infrt.dense_tensor) + "phi_dt.fill_dense_tensor.f32"(%var) {value=[0.0:f32]} : (!infrt.dense_tensor) -> () + + %2 = infrt.call@predict(%t, %bias, %mean, %scale, %var) : (!infrt.dense_tensor, !infrt.dense_tensor,!infrt.dense_tensor,!infrt.dense_tensor,!infrt.dense_tensor) -> !infrt.dense_tensor + + //phi_dt.print_tensor(%t : !infrt.dense_tensor) phi_dt.print_tensor(%2 : !infrt.dense_tensor) infrt.return } diff --git a/tools/infrt/generate_phi_kernel_dialect.py b/tools/infrt/generate_phi_kernel_dialect.py index f632c9a9dba504d209946e494e55eb970e727629..bfe1e7e88bec4ea806bb6fbd7cd55af54d642a50 100644 --- a/tools/infrt/generate_phi_kernel_dialect.py +++ b/tools/infrt/generate_phi_kernel_dialect.py @@ -22,7 +22,9 @@ attr_type_converter = { "i": 'SI32Attr', "b": 'BoolAttr', "l": 'SI64Attr', - "f": 'F32Attr' + "f": 'F32Attr', + "NSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE": 'StrAttr', + "St6vectorIiSaIiEE": 'I32ArrayAttr' } target_type_converter = {"CPU": "CPU", "GPU": "GPU"} diff --git a/tools/infrt/get_phi_kernel_function.sh b/tools/infrt/get_phi_kernel_function.sh index febfe5d04762a43da0710b34e21252ffdf4611ea..612620979674934ff8aa70abdf4967200f20b492 100644 --- a/tools/infrt/get_phi_kernel_function.sh +++ b/tools/infrt/get_phi_kernel_function.sh @@ -38,35 +38,36 @@ python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \ --wrapped_infermeta_header_path ${temp_path}/generate.h \ --wrapped_infermeta_source_path ${temp_path}/generate.cc -grep PD_REGISTER_INFER_META_FN ${temp_path}/generate.cc \ +find ${PADDLE_ROOT}/paddle/phi/ -name "*.cc" | xargs grep PD_REGISTER_INFER_META_FN ${temp_path}/generate.cc \ | awk -F "\(|,|::|\)" '{print $2, $4}' > ${temp_path}/wrap_info.txt - #step 3:get ir's attr_name. ir_attr_name_info_file=`mktemp` # phi_cpu attr -all_ir_name=`grep -Eo "PDTCPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'` +all_ir_name=`grep -Eo "PDTCPU_Kernel<.*\"" ${PADDLE_ROOT}/paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'` for ir in $all_ir_name do - attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | grep -Eo "Attr:.*)" \ + attr_name=`grep "<\"$ir" -A 3 ${PADDLE_ROOT}/paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | grep -Eo "Attr:.*)" \ | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BoolAttr/,""); \ gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \ gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \ gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \ + gsub(/I32ArrayAttr/,"");gsub(/SI32ArrayAttr/,""); \ gsub(/Attr/,"");gsub(/\)/,""); \ gsub(/[,:]/,"");print $a}'` echo phi_cpu.$ir $attr_name >> $ir_attr_name_info_file done # phi_gpu attr -all_ir_name=`grep -Eo "PDTGPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'` +all_ir_name=`grep -Eo "PDTGPU_Kernel<.*\"" ${PADDLE_ROOT}/paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'` for ir in $all_ir_name do - attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | grep -Eo "Attr:.*)" \ + attr_name=`grep "<\"$ir" -A 3 ${PADDLE_ROOT}/paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | grep -Eo "Attr:.*)" \ | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BoolAttr/,""); \ gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \ gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \ gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \ - gsub(/Attr/,"");gsub(/\)/,""); \ + gsub(/I32ArrayAttr/,"");gsub(/SI32ArrayAttr/,""); \ + gsub(/Attr/,"");gsub(/\)/,"") \ gsub(/[,:]/,"");print $a}'` echo phi_gpu.$ir $attr_name >> $ir_attr_name_info_file done diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py index 8b752f928719bcc7ebef4792c29af02261dbd551..23d9a8ffdd225b06301ee3f4c38f3fa7fb8c4c1c 100644 --- a/tools/infrt/get_phi_kernel_info.py +++ b/tools/infrt/get_phi_kernel_info.py @@ -91,11 +91,10 @@ def merge(infer_meta_data, kernel_data, wrap_data): full_kernel_data = [] for l in kernel_data: key = l.split()[0] - if key in meta_map: - if key in meta_map: - full_kernel_data.append((l + " " + wrap_map[key]).split()) - else: - full_kernel_data.append((l + " " + meta_map[key]).split()) + if key in wrap_map: + full_kernel_data.append((l + " " + wrap_map[key]).split()) + elif key in meta_map: + full_kernel_data.append((l + " " + meta_map[key]).split()) else: full_kernel_data.append((l + " unknown").split()) @@ -246,15 +245,10 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]): registry->AddKernelWithAttrs("{ir_name}",""" res += f""" - std::bind(&KernelLauncherFunc, - KernelLauncher(), - std::placeholders::_1), {{{attr_names}}}); """ @@ -263,15 +257,10 @@ registry->AddKernelWithAttrs("{ir_name}",""" registry->AddKernel("{ir_name}",""" res += f""" - std::bind(&KernelLauncherFunc, - KernelLauncher(), - std::placeholders::_1)); + {infer_shape_func}>); """ return res