From 3ac9bc9521a7c0914bdaa1c8b27014153a001f03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Fri, 4 Mar 2022 12:33:10 +0800 Subject: [PATCH] [infrt] add ir for convert pd dilect to phi dialect. test=develop (#40104) --- paddle/infrt/dialect/infrt/infrt_ops.td | 7 ++ paddle/infrt/dialect/phi/CMakeLists.txt | 3 + paddle/infrt/dialect/phi/ir/infrt_phi_base.td | 1 + .../infrt/dialect/phi/pass/kernel_op_desc.cc | 63 ++++++++++- .../infrt/dialect/phi/pass/kernel_op_desc.h | 4 + .../infrt/dialect/phi/pass/phi_op_cvt_pass.cc | 100 ++++++++++++++++-- .../dialect/phi/pass/proto_arg_map_context.h | 2 +- paddle/infrt/dialect/phi/phi_exec.cc | 67 +++++++----- paddle/infrt/dialect/phi/phi_ir_exec.cc | 47 ++++++++ paddle/infrt/host_context/CMakeLists.txt | 3 +- paddle/infrt/pass/CMakeLists.txt | 1 - paddle/infrt/tests/CMakeLists.txt | 2 +- .../infrt/tests/dialect/pten/pten_pass.mlir | 2 +- paddle/infrt/tests/lit.cfg.py.in | 3 +- paddle/scripts/infrt_build.sh | 2 +- tools/infrt/fake_models/multi_fc.py | 1 - tools/infrt/generate_phi_kernel_dialect.py | 5 +- 17 files changed, 263 insertions(+), 50 deletions(-) create mode 100644 paddle/infrt/dialect/phi/phi_ir_exec.cc delete mode 100755 paddle/infrt/pass/CMakeLists.txt diff --git a/paddle/infrt/dialect/infrt/infrt_ops.td b/paddle/infrt/dialect/infrt/infrt_ops.td index 00f94805c7..ecd7093e72 100644 --- a/paddle/infrt/dialect/infrt/infrt_ops.td +++ b/paddle/infrt/dialect/infrt/infrt_ops.td @@ -17,3 +17,10 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> { OptionalAttr:$attrs); let results = (outs Variadic); } + +def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> { + let summary = "convert tensor type op"; + let description = [{convert tensor type op!}]; + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$output); +} diff --git a/paddle/infrt/dialect/phi/CMakeLists.txt b/paddle/infrt/dialect/phi/CMakeLists.txt index d477b6b9bd..a2677a946c 100644 --- a/paddle/infrt/dialect/phi/CMakeLists.txt +++ b/paddle/infrt/dialect/phi/CMakeLists.txt @@ -5,5 +5,8 @@ endif() add_subdirectory(ir) add_subdirectory(pass) +add_executable(phi-ir-exec phi_ir_exec.cc) +target_link_libraries(phi-ir-exec infrt) + add_executable(phi-exec phi_exec.cc) target_link_libraries(phi-exec infrt) diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td index e9591e7f6d..671646b925 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td @@ -3,6 +3,7 @@ include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/infrt_base.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def PHI_Dialect : Dialect { let name = "phi"; diff --git a/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc b/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc index 6c0f6df892..12a6cfcc3e 100644 --- a/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc +++ b/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc @@ -16,8 +16,10 @@ #include #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_registry.h" -namespace infrt { +#include "paddle/phi/kernels/declarations.h" +namespace infrt { +namespace { phi::Backend cvtTarget2Phi(TargetType target) { switch (target) { case TargetType::CPU: @@ -124,19 +126,76 @@ Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg) { cvtLayoutFromPhi(tensor_arg.layout)); } +} // namespace + +std::string getPhiTargetPrefix(TargetType target) { + switch (target) { + case TargetType::CPU: + return "phi_cpu."; + case TargetType::GPU: + return "phi_gpu."; + default: + LOG(FATAL) << "UnSupported target type !"; + return std::string(); + } +} +std::string getPhiPrecisionSuffix(PrecisionType precision) { + switch (precision) { + case PrecisionType::FLOAT32: + return ".float32"; + case PrecisionType::FLOAT16: + return ".float16"; + case PrecisionType::FLOAT64: + return ".float64"; + case PrecisionType::UINT8: + return ".uint8"; + case PrecisionType::INT8: + return ".int8"; + case PrecisionType::INT16: + return ".int16"; + case PrecisionType::INT32: + return ".int32"; + case PrecisionType::INT64: + return ".int64"; + case PrecisionType::COMPLEX64: + return ".complex64"; + case PrecisionType::COMPLEX128: + return ".complex128"; + case PrecisionType::BOOL: + return ".bool"; + default: + LOG(FATAL) << "UnSupported precision type !"; + return std::string(); + } +} +std::string getPhiLayoutSuffix(LayoutType layout) { + switch (layout) { + case LayoutType::NCHW: + return ".nchw"; + case LayoutType::NHWC: + return ".nhwc"; + case LayoutType::ANY: + return ".any"; + default: + LOG(FATAL) << "UnSupported layout type !"; + return std::string(); + } +} + std::vector getCandidateKernels( std::string name, const std::vector& valid_palces) { std::vector candidate_kernels; PhiKernelDesc phi_kernel_desc; phi::KernelKeyMap kernel_key_map = phi::KernelFactory::Instance().SelectKernelMap(name); - for (const Place& place : valid_palces) { + for (Place place : valid_palces) { phi::KernelKey kernel_key = cvtPlace2Phi(place); if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) { kernel_key = phi::KernelKey(kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue; + place.layout = LayoutType::ANY; } phi_kernel_desc.kernelType = place; phi_kernel_desc.inputsType.clear(); diff --git a/paddle/infrt/dialect/phi/pass/kernel_op_desc.h b/paddle/infrt/dialect/phi/pass/kernel_op_desc.h index b74107f674..34fd2f0f62 100644 --- a/paddle/infrt/dialect/phi/pass/kernel_op_desc.h +++ b/paddle/infrt/dialect/phi/pass/kernel_op_desc.h @@ -26,6 +26,10 @@ struct PhiKernelDesc { Place kernelType; // kernel place }; +std::string getPhiTargetPrefix(TargetType target); +std::string getPhiPrecisionSuffix(PrecisionType precision); +std::string getPhiLayoutSuffix(LayoutType layout); + std::vector getCandidateKernels( std::string name, const std::vector& valid_palces); diff --git a/paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc b/paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc index df3472aa01..376ab31938 100644 --- a/paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc +++ b/paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc @@ -18,11 +18,14 @@ #include #include #include +#include +#include #include #include #include #include "paddle/infrt/dialect/infrt/infrt_dialect.h" +#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h" #include "paddle/phi/core/compat/op_utils.h" @@ -58,8 +61,8 @@ void phiOpCvtPass::convertStage() { continue; } - phi::KernelSignature kernel_sign = - phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( + ::phi::KernelSignature kernel_sign = + ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( ProtoArgumentMappingContext(op)); // resort input&output according to kernel_sign ::llvm::SmallVector inputs, ori_output; @@ -104,13 +107,92 @@ void phiOpCvtPass::diapatchStage() { infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null(&op); if (nullptr != kernel_op) worklist.push_back(kernel_op); } - // ToDo: implementation in the next PR - while (!worklist.empty()) { - // infrt::KernelOp kernel_op = worklist.back(); - worklist.pop_back(); - // std::string kernel_name = kernel_op.name().str(); - // std::vector candidates = - // getCandidateKernels(kernel_name, valid_places_); + + mlir::OpBuilder builder(&block, block.begin()); + std::map phi_context; + for (infrt::KernelOp kernel_op : worklist) { + std::string kernel_name = kernel_op.name().str(); + std::vector candidates = + getCandidateKernels(kernel_name, valid_places_); + if (candidates.empty()) { + LOG(FATAL) << "No candidate kernels for op:" << kernel_name; + continue; + } + builder.setInsertionPoint(kernel_op); + + // Todo: Implimentation the concrete pass pick strategy + const PhiKernelDesc &phi_kernel_desc = candidates.front(); + + kernel_name = getPhiTargetPrefix(phi_kernel_desc.kernelType.target) + + kernel_name + + getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout) + + getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision); + + // mlir::OperationName operation_name = kernel_op.getOperation()->getName(); + + mlir::OperationName operation_name(kernel_name, kernel_op.getContext()); + mlir::OperationState operation_state(kernel_op.getLoc(), operation_name); + + if (phi_context.find(phi_kernel_desc.kernelType.target) == + phi_context.end()) { + switch (phi_kernel_desc.kernelType.target) { + case TargetType::CPU: { + auto alloctor_value = + builder + .create( + kernel_op.getLoc(), + phi::AllocatorType::get(kernel_op.getContext(), + TargetType::CPU)) + .output(); + auto context_value = + builder + .create( + kernel_op.getLoc(), + phi::ContextType::get(kernel_op.getContext(), + TargetType::CPU), + alloctor_value) + .output(); + phi_context[TargetType::CPU] = context_value; + } break; + case TargetType::GPU: + case TargetType::UNK: + default: + LOG(FATAL) << "Unsupported TargetType"; + break; + } + } + operation_state.addOperands( + phi_context.at(phi_kernel_desc.kernelType.target)); + for (size_t index = 0; index < phi_kernel_desc.inputsType.size(); ++index) { + mlir::Value input = kernel_op.getOperand(index); + auto cvt_tensor_type_op = builder.create( + kernel_op.getLoc(), + DenseTensorType::get(kernel_op.getContext(), + phi_kernel_desc.inputsType[index].target, + phi_kernel_desc.inputsType[index].precision, + phi_kernel_desc.inputsType[index].layout), + input); + operation_state.addOperands(cvt_tensor_type_op.output()); + } + for (size_t index = 0; index < phi_kernel_desc.outputsType.size(); + ++index) { + operation_state.addTypes( + DenseTensorType::get(kernel_op.getContext(), + phi_kernel_desc.outputsType[index].target, + phi_kernel_desc.outputsType[index].precision, + phi_kernel_desc.outputsType[index].layout)); + } + operation_state.addAttributes(kernel_op.attrsAttr().getValue()); + mlir::Operation *phi_operation = builder.createOperation(operation_state); + for (size_t index = 0; index < phi_kernel_desc.outputsType.size(); + ++index) { + mlir::Value input = phi_operation->getResult(index); + auto cvt_tensor_type_op = builder.create( + kernel_op.getLoc(), kernel_op.getResultTypes()[index], input); + kernel_op.getResult(index).replaceAllUsesWith( + cvt_tensor_type_op.output()); + } + kernel_op.erase(); } } } // namespace infrt diff --git a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h index ca8a22a7e7..e4e9b5c3ff 100644 --- a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h +++ b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/phi/core/compat/arg_map_context.h" namespace infrt { -class ProtoArgumentMappingContext : public phi::ArgumentMappingContext { +class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext { public: // only support op in pd dialect explicit ProtoArgumentMappingContext(mlir::Operation* op) diff --git a/paddle/infrt/dialect/phi/phi_exec.cc b/paddle/infrt/dialect/phi/phi_exec.cc index 4e99661a6a..a2808a00cb 100644 --- a/paddle/infrt/dialect/phi/phi_exec.cc +++ b/paddle/infrt/dialect/phi/phi_exec.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,37 +11,46 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include -#include "paddle/infrt/common/global.h" -#include "paddle/infrt/dialect/mlir_loader.h" -#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h" -int main(int argc, char** argv) { - static llvm::cl::opt input_file( - llvm::cl::Positional, - llvm::cl::desc("Specify input filename"), - llvm::cl::init("-")); - - llvm::cl::ParseCommandLineOptions(argc, argv); +#include "paddle/infrt/host_context/paddle_mlir.h" - mlir::MLIRContext* context = infrt::Global::getMLIRContext(); - auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context); +void print_usage() { + std::cout << "Error inputs format, two kinds of inputs are supported:\n"; + std::cout << " [1] ./paddle-mlir-convert $path_to_model_file " + "$path_to_params_file\n"; + std::cout << " [2] ./paddle-mlir-convert $path_to_model_dir(__model__ + " + "params)\n"; +} - module->dump(); - mlir::PassManager pm(context); +bool parse_inputs(int argc, + char** argv, + std::string* model_file_name, + std::string* params_file_name) { + switch (argc) { + case 1: { + print_usage(); + return false; + } + case 2: { + *model_file_name = std::string(argv[1]) + std::string("/__model__"); + *params_file_name = std::string(argv[1]) + std::string("/params"); + return true; + } + case 3: { + *model_file_name = argv[1]; + *params_file_name = argv[2]; + return true; + } + default: { return false; } + } +} - mlir::OpPassManager& phi_pass_manager = pm.nest(); - std::vector valid_places = {{infrt::TargetType::CPU, - infrt::PrecisionType::FLOAT32, - infrt::LayoutType::NCHW}}; - phi_pass_manager.addPass(std::make_unique(valid_places)); - if (mlir::failed(pm.run(*module))) { - std::cout << "\npass failed!\n" << std::endl; - return 4; +int main(int argc, char** argv) { + std::string model_file_name; + std::string params_file_name; + if (parse_inputs(argc, argv, &model_file_name, ¶ms_file_name)) { + MLIRModelGenImpl myGen; + auto module_ = myGen.ImportPaddleModel(model_file_name, params_file_name); + module_.dump(); } - module->dump(); - return 0; } diff --git a/paddle/infrt/dialect/phi/phi_ir_exec.cc b/paddle/infrt/dialect/phi/phi_ir_exec.cc new file mode 100644 index 0000000000..1df929895b --- /dev/null +++ b/paddle/infrt/dialect/phi/phi_ir_exec.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/mlir_loader.h" +#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h" + +int main(int argc, char** argv) { + static llvm::cl::opt input_file( + llvm::cl::Positional, + llvm::cl::desc("Specify input filename"), + llvm::cl::init("-")); + + llvm::cl::ParseCommandLineOptions(argc, argv); + + mlir::MLIRContext* context = infrt::Global::getMLIRContext(); + auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context); + context->loadAllAvailableDialects(); + module->dump(); + mlir::PassManager pm(context); + + mlir::OpPassManager& phi_pass_manager = pm.nest(); + std::vector valid_places = {{infrt::TargetType::CPU, + infrt::PrecisionType::FLOAT32, + infrt::LayoutType::NCHW}}; + phi_pass_manager.addPass(std::make_unique(valid_places)); + if (mlir::failed(pm.run(*module))) { + std::cout << "\npass failed!\n" << std::endl; + return 4; + } + module->dump(); + return 0; +} diff --git a/paddle/infrt/host_context/CMakeLists.txt b/paddle/infrt/host_context/CMakeLists.txt index 11304742ec..14cbea70ca 100644 --- a/paddle/infrt/host_context/CMakeLists.txt +++ b/paddle/infrt/host_context/CMakeLists.txt @@ -12,6 +12,7 @@ gather_srcs(infrt_src SRCS function.cc mlir_function_executable.cc mlir_program_executor.cc + paddle_mlir.cc ) cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS}) @@ -21,7 +22,7 @@ cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${ML cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS}) cc_test_tiny(test_infrt_mlir_to_runtime_translate SRCS mlir_to_runtime_translate_test.cc DEPS infrt ${MLIR_IR_LIBS}) -add_executable(paddle-mlir-convert paddle_mlir.cc paddle_mlir_converter.cc) +add_executable(paddle-mlir-convert paddle_mlir_converter.cc) target_link_libraries(paddle-mlir-convert infrt ${MLIR_IR_LIBS}) add_executable(infrtexec mlir_exec.cc) target_link_libraries(infrtexec infrt ${MLIR_IR_LIBS}) diff --git a/paddle/infrt/pass/CMakeLists.txt b/paddle/infrt/pass/CMakeLists.txt deleted file mode 100755 index 51fecdf907..0000000000 --- a/paddle/infrt/pass/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(phi) diff --git a/paddle/infrt/tests/CMakeLists.txt b/paddle/infrt/tests/CMakeLists.txt index e5cc1ec112..5ce6d86734 100644 --- a/paddle/infrt/tests/CMakeLists.txt +++ b/paddle/infrt/tests/CMakeLists.txt @@ -1,6 +1,6 @@ configure_file(lit.cfg.py.in "${CMAKE_SOURCE_DIR}/paddle/infrt/tests/lit.cfg.py") add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle/infrt/tests --filter-out \"disabled_*\"" - DEPENDS infrtopt infrtexec) + DEPENDS infrtopt infrtexec phi-ir-exec) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir) diff --git a/paddle/infrt/tests/dialect/pten/pten_pass.mlir b/paddle/infrt/tests/dialect/pten/pten_pass.mlir index 30ff2636ae..61a66cb3d7 100644 --- a/paddle/infrt/tests/dialect/pten/pten_pass.mlir +++ b/paddle/infrt/tests/dialect/pten/pten_pass.mlir @@ -1,4 +1,4 @@ -// RUN: infrtopt %s | FileCheck %s +// RUN: phi-ir-exec %s // CHECK-LABEL: @ops func @ops() { %a = pd.feed() {name="input0"} : !infrt.lod_tensor diff --git a/paddle/infrt/tests/lit.cfg.py.in b/paddle/infrt/tests/lit.cfg.py.in index d47957dac9..fe35dc4b8b 100644 --- a/paddle/infrt/tests/lit.cfg.py.in +++ b/paddle/infrt/tests/lit.cfg.py.in @@ -23,9 +23,10 @@ config.llvm_tools_dir = os.path.join(build_dir, "/third_party/install/llvm/lib") infrtopt_bin = os.path.join(build_dir, "paddle/infrt/dialect/") trtexec_bin = os.path.join(build_dir, "paddle/infrt/dialect/tensorrt/") infrtexec_bin = os.path.join(build_dir, "paddle/infrt/host_context/") +phi_ir_exec_bin = os.path.join(build_dir, "paddle/infrt/dialect/phi") llvm_bin = os.path.join(build_dir, "third_party/install/llvm/bin/") config.environment['PATH'] = os.path.pathsep.join( - (infrtopt_bin, infrtexec_bin, trtexec_bin, llvm_bin, config.environment['PATH'])) + (infrtopt_bin, infrtexec_bin, trtexec_bin, phi_ir_exec_bin, llvm_bin, config.environment['PATH'])) config.suffixes = ['.mlir'] diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index 75b27e4165..fb7be82d1c 100755 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -92,7 +92,7 @@ function infrt_gen_and_build() { exit 7; fi - make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$? + make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-ir-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$? if [ "$build_error" != 0 ];then exit 7; fi diff --git a/tools/infrt/fake_models/multi_fc.py b/tools/infrt/fake_models/multi_fc.py index 03cf6828cc..0d633cfc60 100644 --- a/tools/infrt/fake_models/multi_fc.py +++ b/tools/infrt/fake_models/multi_fc.py @@ -19,7 +19,6 @@ import sys, os import numpy as np import paddle import paddle.fluid as fluid -from paddle.fluid.backward import append_backward size = 2 num_layers = 4 diff --git a/tools/infrt/generate_phi_kernel_dialect.py b/tools/infrt/generate_phi_kernel_dialect.py index 8efa03306f..f3a78a8d4e 100644 --- a/tools/infrt/generate_phi_kernel_dialect.py +++ b/tools/infrt/generate_phi_kernel_dialect.py @@ -16,7 +16,7 @@ import json import sys attr_type_converter = {"i": 'SI32Attr', "b": 'BoolAttr', "l": 'SI64Attr'} -supported_kernels = ['sign', 'dot', 'digamma', 'conj'] +supported_kernels = ['sign', 'dot', 'digamma', 'conj', 'abs', 'add_raw'] target_type_converter = {"CPU": "CPU", "GPU": "GPU"} layout_type_converter = { @@ -66,7 +66,8 @@ def generate_attrs_info(op_name, attrs_info): 'digamma': [], 'lerp': [], 'cast': ['out_dtype', 'in_dtype'], - 'abs': [] + 'abs': [], + 'add_raw': ['axis'], } attrs_args_ = "" if len(kernel_attrs_names[op_name]) == len(attrs_info): -- GitLab