Unverified commit 3ac9bc95 authored by 王明冬, committed by GitHub

[infrt] add ir for converting pd dialect to phi dialect. test=develop (#40104)

Parent abacc4cb
...@@ -17,3 +17,10 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
OptionalAttr<DictionaryAttr>:$attrs);
let results = (outs Variadic<AnyType>);
}
def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> {
let summary = "convert tensor type op";
let description = [{convert tensor type op!}];
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
}
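The new cvt_tensor op takes a tensor value of one (target, precision, layout) type and yields the same tensor with a converted type. As a minimal sketch (not part of this commit) of how a pass materializes it with an OpBuilder, mirroring the usage added in phi_op_cvt_pass.cc further down; `builder`, `loc`, `dst_type`, and `input` are assumed to be in scope:

// Sketch only: wrap a value in an infrt.cvt_tensor op; dst_type is the
// DenseTensorType expected by the consuming phi kernel op.
auto cvt_op = builder.create<infrt::CvtTensorOp>(loc, dst_type, input);
mlir::Value converted = cvt_op.output();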
...@@ -5,5 +5,8 @@ endif()
add_subdirectory(ir)
add_subdirectory(pass)
add_executable(phi-ir-exec phi_ir_exec.cc)
target_link_libraries(phi-ir-exec infrt)
add_executable(phi-exec phi_exec.cc)
target_link_libraries(phi-exec infrt)
...@@ -3,6 +3,7 @@
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
def PHI_Dialect : Dialect {
let name = "phi";
......
...@@ -16,8 +16,10 @@
#include <glog/logging.h>
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/declarations.h"
namespace infrt {
namespace {
phi::Backend cvtTarget2Phi(TargetType target) {
switch (target) {
case TargetType::CPU:
...@@ -124,19 +126,76 @@ Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg) {
cvtLayoutFromPhi(tensor_arg.layout));
}
} // namespace
std::string getPhiTargetPrefix(TargetType target) {
switch (target) {
case TargetType::CPU:
return "phi_cpu.";
case TargetType::GPU:
return "phi_gpu.";
default:
LOG(FATAL) << "UnSupported target type !";
return std::string();
}
}
std::string getPhiPrecisionSuffix(PrecisionType precision) {
switch (precision) {
case PrecisionType::FLOAT32:
return ".float32";
case PrecisionType::FLOAT16:
return ".float16";
case PrecisionType::FLOAT64:
return ".float64";
case PrecisionType::UINT8:
return ".uint8";
case PrecisionType::INT8:
return ".int8";
case PrecisionType::INT16:
return ".int16";
case PrecisionType::INT32:
return ".int32";
case PrecisionType::INT64:
return ".int64";
case PrecisionType::COMPLEX64:
return ".complex64";
case PrecisionType::COMPLEX128:
return ".complex128";
case PrecisionType::BOOL:
return ".bool";
default:
LOG(FATAL) << "UnSupported precision type !";
return std::string();
}
}
std::string getPhiLayoutSuffix(LayoutType layout) {
switch (layout) {
case LayoutType::NCHW:
return ".nchw";
case LayoutType::NHWC:
return ".nhwc";
case LayoutType::ANY:
return ".any";
default:
LOG(FATAL) << "UnSupported layout type !";
return std::string();
}
}
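Together, these three helpers mangle a kernel name plus a Place into the symbol of a concrete phi kernel op, in the order target prefix, name, layout suffix, precision suffix (the same order diapatchStage composes below). A minimal worked sketch, assuming the declarations this commit adds to kernel_op_desc.h and an illustrative "add" kernel:

#include <string>
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"

// Sketch: mangles to "phi_cpu.add.nchw.float32" for an "add" kernel
// placed on CPU with NCHW layout and float32 precision.
std::string mangled_add_kernel() {
  return infrt::getPhiTargetPrefix(infrt::TargetType::CPU) + "add" +
         infrt::getPhiLayoutSuffix(infrt::LayoutType::NCHW) +
         infrt::getPhiPrecisionSuffix(infrt::PrecisionType::FLOAT32);
}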
std::vector<PhiKernelDesc> getCandidateKernels(
std::string name, const std::vector<Place>& valid_palces) {
std::vector<PhiKernelDesc> candidate_kernels;
PhiKernelDesc phi_kernel_desc;
phi::KernelKeyMap kernel_key_map =
phi::KernelFactory::Instance().SelectKernelMap(name);
for (Place place : valid_palces) {
phi::KernelKey kernel_key = cvtPlace2Phi(place);
if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) {
kernel_key = phi::KernelKey(kernel_key.backend(),
phi::DataLayout::ALL_LAYOUT,
kernel_key.dtype());
if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue;
place.layout = LayoutType::ANY;
}
phi_kernel_desc.kernelType = place;
phi_kernel_desc.inputsType.clear();
......
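getCandidateKernels first looks up the exact (backend, layout, dtype) key in the phi registry and, failing that, retries with phi::DataLayout::ALL_LAYOUT, downgrading the place's layout to LayoutType::ANY. A hedged usage sketch (the kernel name "add" is illustrative, not taken from the commit):

#include <vector>
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"

// Sketch: ask the phi registry which "add" kernels can serve a
// CPU/float32/NCHW place; descriptors come back with a concrete layout,
// or with LayoutType::ANY when only an ALL_LAYOUT registration exists.
void query_example() {
  std::vector<infrt::Place> places = {{infrt::TargetType::CPU,
                                       infrt::PrecisionType::FLOAT32,
                                       infrt::LayoutType::NCHW}};
  std::vector<infrt::PhiKernelDesc> kernels =
      infrt::getCandidateKernels("add", places);
}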
...@@ -26,6 +26,10 @@ struct PhiKernelDesc {
Place kernelType;  // kernel place
};
std::string getPhiTargetPrefix(TargetType target);
std::string getPhiPrecisionSuffix(PrecisionType precision);
std::string getPhiLayoutSuffix(LayoutType layout);
std::vector<PhiKernelDesc> getCandidateKernels(
std::string name, const std::vector<Place>& valid_palces);
......
...@@ -18,11 +18,14 @@
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <list>
#include <unordered_set>
#include <vector>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h" #include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
...@@ -58,8 +61,8 @@ void phiOpCvtPass::convertStage() { ...@@ -58,8 +61,8 @@ void phiOpCvtPass::convertStage() {
continue; continue;
} }
phi::KernelSignature kernel_sign = ::phi::KernelSignature kernel_sign =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
ProtoArgumentMappingContext(op)); ProtoArgumentMappingContext(op));
// resort input&output according to kernel_sign // resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output; ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
...@@ -104,13 +107,92 @@ void phiOpCvtPass::diapatchStage() {
infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
if (nullptr != kernel_op) worklist.push_back(kernel_op);
}
mlir::OpBuilder builder(&block, block.begin());
std::map<TargetType, mlir::Value> phi_context;
for (infrt::KernelOp kernel_op : worklist) {
std::string kernel_name = kernel_op.name().str();
std::vector<PhiKernelDesc> candidates =
getCandidateKernels(kernel_name, valid_places_);
if (candidates.empty()) {
LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
continue;
}
builder.setInsertionPoint(kernel_op);
// TODO: Implement the concrete kernel-pick strategy
const PhiKernelDesc &phi_kernel_desc = candidates.front();
kernel_name = getPhiTargetPrefix(phi_kernel_desc.kernelType.target) +
kernel_name +
getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout) +
getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision);
// mlir::OperationName operation_name = kernel_op.getOperation()->getName();
mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);
if (phi_context.find(phi_kernel_desc.kernelType.target) ==
phi_context.end()) {
switch (phi_kernel_desc.kernelType.target) {
case TargetType::CPU: {
auto alloctor_value =
builder
.create<infrt::phi::CreateAllocatorOp_cpu>(
kernel_op.getLoc(),
phi::AllocatorType::get(kernel_op.getContext(),
TargetType::CPU))
.output();
auto context_value =
builder
.create<infrt::phi::CreateContextOp_cpu>(
kernel_op.getLoc(),
phi::ContextType::get(kernel_op.getContext(),
TargetType::CPU),
alloctor_value)
.output();
phi_context[TargetType::CPU] = context_value;
} break;
case TargetType::GPU:
case TargetType::UNK:
default:
LOG(FATAL) << "Unsupported TargetType";
break;
}
}
operation_state.addOperands(
phi_context.at(phi_kernel_desc.kernelType.target));
for (size_t index = 0; index < phi_kernel_desc.inputsType.size(); ++index) {
mlir::Value input = kernel_op.getOperand(index);
auto cvt_tensor_type_op = builder.create<CvtTensorOp>(
kernel_op.getLoc(),
DenseTensorType::get(kernel_op.getContext(),
phi_kernel_desc.inputsType[index].target,
phi_kernel_desc.inputsType[index].precision,
phi_kernel_desc.inputsType[index].layout),
input);
operation_state.addOperands(cvt_tensor_type_op.output());
}
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
++index) {
operation_state.addTypes(
DenseTensorType::get(kernel_op.getContext(),
phi_kernel_desc.outputsType[index].target,
phi_kernel_desc.outputsType[index].precision,
phi_kernel_desc.outputsType[index].layout));
}
operation_state.addAttributes(kernel_op.attrsAttr().getValue());
mlir::Operation *phi_operation = builder.createOperation(operation_state);
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
++index) {
mlir::Value input = phi_operation->getResult(index);
auto cvt_tensor_type_op = builder.create<CvtTensorOp>(
kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
kernel_op.getResult(index).replaceAllUsesWith(
cvt_tensor_type_op.output());
}
kernel_op.erase();
}
}
}  // namespace infrt
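The pass's two stages (op conversion, then kernel dispatch) are driven from a function-level pass manager. A minimal driver sketch, mirroring the new phi_ir_exec.cc later in this diff:

#include <memory>
#include <vector>
#include <mlir/Pass/PassManager.h>
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h"

// Sketch: nest phiOpCvtPass under FuncOp and run it over a loaded module.
mlir::LogicalResult run_phi_op_cvt(mlir::MLIRContext* context,
                                   mlir::ModuleOp module) {
  mlir::PassManager pm(context);
  mlir::OpPassManager& fpm = pm.nest<mlir::FuncOp>();
  std::vector<infrt::Place> valid_places = {{infrt::TargetType::CPU,
                                             infrt::PrecisionType::FLOAT32,
                                             infrt::LayoutType::NCHW}};
  fpm.addPass(std::make_unique<infrt::phiOpCvtPass>(valid_places));
  return pm.run(module);
}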
...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/arg_map_context.h"
namespace infrt {
class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
public:
// only support op in pd dialect
explicit ProtoArgumentMappingContext(mlir::Operation* op)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...@@ -11,37 +11,46 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/paddle_mlir.h"
void print_usage() {
std::cout << "Error inputs format, two kinds of inputs are supported:\n";
std::cout << "    [1] ./paddle-mlir-convert $path_to_model_file "
"$path_to_params_file\n";
std::cout << "    [2] ./paddle-mlir-convert $path_to_model_dir(__model__ + "
"params)\n";
}
bool parse_inputs(int argc,
char** argv,
std::string* model_file_name,
std::string* params_file_name) {
switch (argc) {
case 1: {
print_usage();
return false;
}
case 2: {
*model_file_name = std::string(argv[1]) + std::string("/__model__");
*params_file_name = std::string(argv[1]) + std::string("/params");
return true;
}
case 3: {
*model_file_name = argv[1];
*params_file_name = argv[2];
return true;
}
default: { return false; }
}
}
int main(int argc, char** argv) {
std::string model_file_name;
std::string params_file_name;
if (parse_inputs(argc, argv, &model_file_name, &params_file_name)) {
MLIRModelGenImpl myGen;
auto module_ = myGen.ImportPaddleModel(model_file_name, params_file_name);
module_.dump();
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <mlir/Pass/PassManager.h>
#include <iostream>
#include <string>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h"
int main(int argc, char** argv) {
static llvm::cl::opt<std::string> input_file(
llvm::cl::Positional,
llvm::cl::desc("Specify input filename"),
llvm::cl::init("-"));
llvm::cl::ParseCommandLineOptions(argc, argv);
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context);
context->loadAllAvailableDialects();
module->dump();
mlir::PassManager pm(context);
mlir::OpPassManager& phi_pass_manager = pm.nest<mlir::FuncOp>();
std::vector<infrt::Place> valid_places = {{infrt::TargetType::CPU,
infrt::PrecisionType::FLOAT32,
infrt::LayoutType::NCHW}};
phi_pass_manager.addPass(std::make_unique<infrt::phiOpCvtPass>(valid_places));
if (mlir::failed(pm.run(*module))) {
std::cout << "\npass failed!\n" << std::endl;
return 4;
}
module->dump();
return 0;
}
...@@ -12,6 +12,7 @@ gather_srcs(infrt_src SRCS
function.cc
mlir_function_executable.cc
mlir_program_executor.cc
paddle_mlir.cc
)
cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS})
...@@ -21,7 +22,7 @@ cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${ML
cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_mlir_to_runtime_translate SRCS mlir_to_runtime_translate_test.cc DEPS infrt ${MLIR_IR_LIBS})
add_executable(paddle-mlir-convert paddle_mlir_converter.cc)
target_link_libraries(paddle-mlir-convert infrt ${MLIR_IR_LIBS})
add_executable(infrtexec mlir_exec.cc)
target_link_libraries(infrtexec infrt ${MLIR_IR_LIBS})
configure_file(lit.cfg.py.in "${CMAKE_SOURCE_DIR}/paddle/infrt/tests/lit.cfg.py")
add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle/infrt/tests --filter-out \"disabled_*\""
DEPENDS infrtopt infrtexec phi-ir-exec)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
// RUN: phi-ir-exec %s
// CHECK-LABEL: @ops
func @ops() {
%a = pd.feed() {name="input0"} : !infrt.lod_tensor<?xf32,0>
......
...@@ -23,9 +23,10 @@ config.llvm_tools_dir = os.path.join(build_dir, "/third_party/install/llvm/lib")
infrtopt_bin = os.path.join(build_dir, "paddle/infrt/dialect/")
trtexec_bin = os.path.join(build_dir, "paddle/infrt/dialect/tensorrt/")
infrtexec_bin = os.path.join(build_dir, "paddle/infrt/host_context/")
phi_ir_exec_bin = os.path.join(build_dir, "paddle/infrt/dialect/phi")
llvm_bin = os.path.join(build_dir, "third_party/install/llvm/bin/")
config.environment['PATH'] = os.path.pathsep.join(
(infrtopt_bin, infrtexec_bin, trtexec_bin, phi_ir_exec_bin, llvm_bin, config.environment['PATH']))
config.suffixes = ['.mlir']
...@@ -92,7 +92,7 @@ function infrt_gen_and_build() {
exit 7;
fi
make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec phi-ir-exec phi-exec infrt_lib_dist paddle-mlir-convert;build_error=$?
if [ "$build_error" != 0 ];then
exit 7;
fi
......
...@@ -19,7 +19,6 @@ import sys, os
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
size = 2
num_layers = 4
......
...@@ -16,7 +16,7 @@ import json
import sys
attr_type_converter = {"i": 'SI32Attr', "b": 'BoolAttr', "l": 'SI64Attr'}
supported_kernels = ['sign', 'dot', 'digamma', 'conj', 'abs', 'add_raw']
target_type_converter = {"CPU": "CPU", "GPU": "GPU"}
layout_type_converter = {
...@@ -66,7 +66,8 @@ def generate_attrs_info(op_name, attrs_info):
'digamma': [],
'lerp': [],
'cast': ['out_dtype', 'in_dtype'],
'abs': [],
'add_raw': ['axis'],
}
attrs_args_ = ""
if len(kernel_attrs_names[op_name]) == len(attrs_info):
......