Unverified commit 767647ce, authored by huzhiqiang, committed by GitHub

[Infrt]Update kernel dialect (#40141)

Parent aeaf69b3
@@ -56,6 +56,7 @@ paddle/infrt/dialect/pd_ops.td
 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td
 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td
 tools/infrt/kernels.json
+tools/infrt/kernel_signature.json
 paddle/infrt/dialect/pd_ops_info.h
 .lit_test_times.txt
 paddle/infrt/tests/dialect/Output
...
@@ -44,35 +44,41 @@ int main(int argc, char **argv) {
   paddle::framework::InitDefaultKernelSignatureMap();
   auto &kernel_signature_map = phi::DefaultKernelSignatureMap::Instance();
   auto &kernel_factory = phi::KernelFactory::Instance();
-  std::cout << "{";
+  std::string kernel_signature_map_str{"{"};
   for (const auto &op_kernel_pair : kernel_factory.kernels()) {
     if (kernel_signature_map.Has(op_kernel_pair.first)) {
-      std::cout << "\"" << op_kernel_pair.first << "\":{";
+      kernel_signature_map_str =
+          kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{";
       auto &args = kernel_signature_map.Get(op_kernel_pair.first).args;
-      std::cout << "\"inputs\":[";
+      kernel_signature_map_str += "\"inputs\":[";
       auto inputs_ = std::get<0>(args);
-      if (inputs_.size() > 0) std::cout << inputs_[0];
-      for (size_t i = 1; i < inputs_.size(); i++) {
-        std::cout << ",\"" << inputs_[i] << "\"";
+      for (size_t i = 0; i < inputs_.size(); i++) {
+        kernel_signature_map_str =
+            kernel_signature_map_str + "\"" + inputs_[i] + "\",";
       }
-      std::cout << "],\"attrs\":[";
+      if (inputs_.size()) kernel_signature_map_str.pop_back();
+      kernel_signature_map_str += "],\"attrs\":[";
       auto attrs_ = std::get<1>(args);
-      if (attrs_.size() > 0) std::cout << attrs_[0];
-      for (size_t i = 1; i < attrs_.size(); i++) {
-        std::cout << ",\"" << attrs_[i] << "\"";
+      for (size_t i = 0; i < attrs_.size(); i++) {
+        kernel_signature_map_str =
+            kernel_signature_map_str + "\"" + attrs_[i] + "\",";
       }
-      std::cout << "],\"outputs\":[";
+      if (attrs_.size()) kernel_signature_map_str.pop_back();
+      kernel_signature_map_str += "],\"outputs\":[";
       auto outputs_ = std::get<2>(args);
-      for (size_t i = 1; i < outputs_.size(); i++) {
-        std::cout << ",\"" << outputs_[i] << "\"";
+      for (size_t i = 0; i < outputs_.size(); i++) {
+        kernel_signature_map_str =
+            kernel_signature_map_str + "\"" + outputs_[i] + "\",";
      }
-      std::cout << "]},";
+      if (outputs_.size()) kernel_signature_map_str.pop_back();
+      kernel_signature_map_str += "]},";
    }
  }
-  std::cout << "}" << std::endl;
+  kernel_signature_map_str.pop_back();
+  kernel_signature_map_str += "}\n";
+  std::cout << kernel_signature_map_str;
  return 0;
}
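With this change the generator accumulates the whole map in kernel_signature_map_str and trims each trailing comma with pop_back() before closing a list or the outer object, so the dump is now valid JSON. As a reference, a minimal Python sketch of consuming the file; the "sign" entry in the comment is illustrative, not taken from a real dump:

```python
import json

# Path assumed from the update_pd_ops() redirect further below.
with open("tools/infrt/kernel_signature.json") as f:
    signatures = json.load(f)

# Each entry maps an op name to its kernel signature, e.g. (illustrative):
#   "sign": {"inputs": ["X"], "attrs": [], "outputs": ["Out"]}
for op_name, sig in signatures.items():
    print(op_name, sig["inputs"], sig["attrs"], sig["outputs"])
```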
@@ -125,10 +125,8 @@ void phiOpCvtPass::diapatchStage() {
       kernel_name = getPhiTargetPrefix(phi_kernel_desc.kernelType.target) +
                     kernel_name +
-                    getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout) +
-                    getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision);
-      // mlir::OperationName operation_name = kernel_op.getOperation()->getName();
+                    getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision) +
+                    getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout);
       mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
       mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);
...
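The suffix order is swapped so a kernel name is now composed as target prefix, op name, precision, then layout (phi_cpu.sign.float32.any instead of phi_cpu.sign.any.float32). A small sketch of the scheme; the helper below is hypothetical, standing in for the getPhiTargetPrefix/getPhiPrecisionSuffix/getPhiLayoutSuffix chain:

```python
# Hypothetical helper mirroring the prefix/suffix composition after this change.
def phi_kernel_name(target, op, precision, layout):
    return "phi_{}.{}.{}.{}".format(target, op, precision, layout)

# Matches the updated test and MLIR cases further below.
assert phi_kernel_name("cpu", "sign", "float32", "any") == "phi_cpu.sign.float32.any"
assert phi_kernel_name("cpu", "add", "float32", "any") == "phi_cpu.add.float32.any"
```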
@@ -56,6 +56,7 @@ mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
   UpdateModelParams(program, &mainFunc);
   UpdateModelOps(program);
   UpdateModelOutputs(program);
+
   return module_;
 }
@@ -143,13 +144,14 @@ void MLIRModelGenImpl::UpdateModelParams(
     const infrt::paddle::framework_proto::ProgramDesc &program,
     mlir::FuncOp *mainFunc) {
   // update input vars
+  int input_index = 1;
   for (auto &op_desc : main_block_.ops()) {
     if (op_desc.type() == "feed") {
       for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
         // update input variables
         auto &in = op_desc.outputs()[var_idx];
         std::string input_var_name = in.arguments(0);
-        ::mlir::Value input_ = mainFunc->getArgument(1);
+        ::mlir::Value input_ = mainFunc->getArgument(input_index++);
         params_map_.insert(
             std::pair<std::string, mlir::Value>(input_var_name, input_));
       }
@@ -211,7 +213,6 @@ void MLIRModelGenImpl::buildOperation(
     const infrt::paddle::framework_proto::OpDesc &op_) {
   const std::string &op_name = "pd." + op_.type();
   mlir::Location loc = mlir::UnknownLoc::get(context_);
-
   llvm::SmallVector<mlir::Value, 4> operands = GetOpInputValue(op_);
   llvm::SmallVector<mlir::Type, 4> resultTypes = GetOpOutputType(op_);
   llvm::SmallVector<mlir::NamedAttribute, 4> attrs = GetOpAttributes(op_);
@@ -227,7 +228,6 @@ llvm::SmallVector<mlir::Value, 4> MLIRModelGenImpl::GetOpInputValue(
   std::unordered_map<std::string, uint8_t> inputs_info = {};
   if (pd_dialect_inputs_info_map_.count(op_.type()))
     inputs_info = pd_dialect_inputs_info_map_.at(op_.type());
-
   for (int var_idx = 0; var_idx < op_.inputs_size(); ++var_idx) {
     auto &var = op_.inputs(var_idx);
     if (!var.arguments().empty()) {
@@ -249,10 +249,8 @@ llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetOpOutputType(
   // update op outputs info
   for (int var_idx = 0; var_idx < op_.outputs_size(); ++var_idx) {
     auto &var_name = op_.outputs(var_idx).arguments()[0];
-
     if (!pd_dialect_outputs_info.count(op_.outputs(var_idx).parameter()))
       continue;
-
     // update persistable tensors
     for (int i = 0; i < main_block_.vars_size(); i++) {
       auto var_desc = main_block_.vars(i);
@@ -315,7 +313,6 @@ llvm::SmallVector<mlir::NamedAttribute, 4> MLIRModelGenImpl::GetOpAttributes(
   llvm::ArrayRef<mlir::StringAttr> attr_names_ =
       registered_op_name_.getAttributeNames();
   std::vector<mlir::StringAttr> attr_names_vec_ = attr_names_.vec();
-
   // update attrs
   for (int attrs_num = 0; attrs_num < op_.attrs_size(); attrs_num++) {
     auto attr_name_ = op_.attrs(attrs_num).name();
@@ -351,11 +348,17 @@ llvm::SmallVector<mlir::NamedAttribute, 4> MLIRModelGenImpl::GetOpAttributes(
 void MLIRModelGenImpl::RegisterOpOutputVars(
     const infrt::paddle::framework_proto::OpDesc &op_,
     mlir::Operation *mlir_op_) {
+  std::unordered_map<std::string, uint8_t> pd_dialect_outputs_info =
+      pd_dialect_outputs_info_map_.at(op_.type());
   // op outputs
   for (int var_idx = 0; var_idx < op_.outputs_size(); ++var_idx) {
+    if (!pd_dialect_outputs_info.count(op_.outputs(var_idx).parameter()))
+      continue;
     auto &var_name = op_.outputs(var_idx).arguments()[0];
+    int index = pd_dialect_outputs_info[op_.outputs(var_idx).parameter()];
     // output name
-    auto var_ = mlir_op_->getResult(var_idx);
+    auto var_ = mlir_op_->getResult(index);
     params_map_.insert(std::pair<std::string, mlir::Value>(var_name, var_));
   }
 }
...
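RegisterOpOutputVars now looks up each proto output parameter in pd_dialect_outputs_info_map_ and uses the mapped result slot, skipping outputs the dialect op does not declare, rather than assuming the proto order matches the op's results. A rough Python analogue; the map contents and variable names are made up for illustration:

```python
# Illustrative parameter-name -> result-index map for one dialect op.
pd_dialect_outputs_info = {"Out": 0}

# Proto outputs can include entries (e.g. "XShape") the dialect op lacks.
proto_outputs = [("XShape", "xshape_var"), ("Out", "out_var")]

for param, var_name in proto_outputs:
    if param not in pd_dialect_outputs_info:
        continue  # skip undeclared outputs instead of misindexing results
    index = pd_dialect_outputs_info[param]  # mapped slot, not the loop index
    print(var_name, "-> result", index)
```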
@@ -54,7 +54,7 @@ TEST(ElementwiseAdd, launcher_registry) {
   host_context::KernelRegistry registry;
   RegisterInferShapeLaunchers(&registry);
   ASSERT_GE(registry.size(), 1UL);
-  auto creator = registry.GetKernel("phi_cpu.add.any.float32");
+  auto creator = registry.GetKernel("phi_cpu.add.float32.any");
   const phi::DDim dims({1, 2});
   const phi::DataType dtype{phi::DataType::FLOAT32};
...
@@ -6,7 +6,7 @@ func @sign_any_float32_execute() {
   %ctx = "phi_dt.create_context.cpu" (%allocator): (!phi.allocator<CPU>) -> !phi.context<CPU>
   %t = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!phi.allocator<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
   "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
-  %e = "phi_cpu.sign.any.float32"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
+  %e = "phi_cpu.sign.float32.any"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
   // CHECK: dense_tensor: shape=shape[1], values=[1]
   "phi_dt.print_tensor" (%e) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
...
@@ -33,16 +33,17 @@ function update_pd_ops() {
     rm -rf ${PADDLE_ROOT}/build && mkdir -p ${PADDLE_ROOT}/build
     cd ${PADDLE_ROOT}/build
     cmake .. -DWITH_PYTHON=ON -DWITH_GPU=OFF -DPYTHON_EXECUTABLE=`which python3` -DWITH_XBYAK=OFF -DWITH_NCCL=OFF -DWITH_RCCL=OFF -DWITH_CRYPTO=OFF
-    make -j8 paddle_python print_pten_kernels
+    make -j8 paddle_python print_pten_kernels kernel_signature_generator
     cd ${PADDLE_ROOT}/build
     ./paddle/phi/tools/print_pten_kernels > ../tools/infrt/kernels.json
+    ./paddle/fluid/pybind/kernel_signature_generator > ../tools/infrt/kernel_signature.json
     cd python/dist/
     python3 -m pip uninstall -y paddlepaddle
     python3 -m pip install *whl
     # update pd_ops.td
     cd ${PADDLE_ROOT}/tools/infrt/
     python3 generate_pd_op_dialect_from_paddle_op_maker.py
-    python3 generate_phi_kernel_dialect.py ./kernels.json
+    python3 generate_phi_kernel_dialect.py
 }
 function init() {
...
...@@ -14,9 +14,16 @@ ...@@ -14,9 +14,16 @@
import json import json
import sys import sys
import os
attr_type_converter = {"i": 'SI32Attr', "b": 'BoolAttr', "l": 'SI64Attr'} from get_compat_kernel_signature import get_compat_kernels_info
supported_kernels = ['sign', 'dot', 'digamma', 'conj', 'abs', 'add_raw']
#TODO @DannyIsFunny: more attr types need to be supported.
attr_type_converter = {
"i": 'SI32Attr',
"b": 'BoolAttr',
"l": 'SI64Attr',
"f": 'F32Attr'
}
target_type_converter = {"CPU": "CPU", "GPU": "GPU"} target_type_converter = {"CPU": "CPU", "GPU": "GPU"}
layout_type_converter = { layout_type_converter = {
@@ -39,40 +46,34 @@ precision_type_converter = {
     "bool": "BOOL"
 }

+kernel_types_info_file = "./kernels.json"
+kernel_signature_info_file = "./kernel_signature.json"
+

 def generate_kernel_name(op_name, place_str):
     [target_, layout_, precision_] = place_str[1:-1].split(',')
     target_ = target_type_converter[target_.strip()]
     layout_ = layout_type_converter[layout_.strip()]
     precision_ = precision_type_converter[precision_.strip()]
+    class_name_ = "{}{}".format(
+        op_name.replace("_", "").title(), "".join([
+            target_.strip().title(), precision_.strip(), layout_.strip().title()
+            .title()
+        ]))
     alias_ = "{}.{}".format(op_name, ".".join(
-        [target_.strip(), layout_.strip(), precision_.strip()]))
-    return alias_
+        [target_.strip(), precision_.strip(), layout_.strip()]))
+    return alias_, class_name_


 def generate_attrs_info(op_name, attrs_info):
-    kernel_attrs_names = {
-        'split': ['sections', 'num', 'axis', 'mkldnn_data_type'],
-        'sign': [],
-        'masked_select': [],
-        'trace': ['offset', 'axis1', 'axis2'],
-        'concat': ['axis'],
-        'empty': ['shape', 'dtype'],
-        'conj': [],
-        'norm': ['axis', 'epsilon', 'is_test'],
-        'histogram': ['bins', 'min', 'max'],
-        'dot': [],
-        'scale': ['scale', 'bias', 'bias_after_scale'],
-        'digamma': [],
-        'lerp': [],
-        'cast': ['out_dtype', 'in_dtype'],
-        'abs': [],
-        'add_raw': ['axis'],
-    }
+    kernel_attrs_names = {}
     attrs_args_ = ""
-    if len(kernel_attrs_names[op_name]) == len(attrs_info):
+    with open(kernel_signature_info_file) as f:
+        kernel_attrs_names = json.load(f)
+    kernel_attrs_names.update(get_compat_kernels_info())
+    if len(kernel_attrs_names[op_name]["attrs"]) == len(attrs_info):
         for index in range(len(attrs_info)):
-            attr_name = kernel_attrs_names[op_name][index]
+            attr_name = kernel_attrs_names[op_name]["attrs"][index]
             attr_type = attr_type_converter[attrs_info[index]]
             attrs_args_ += '{type_}:${name_},'.format(
                 type_=attr_type, name_=attr_name)
@@ -97,7 +98,11 @@ def generate_arguments_info(op_name, input_info, attr_info):
     input_args = generate_inputs_info(input_info)
     attr_args = generate_attrs_info(op_name, attr_info)
     context_args = "Context:$dev_ctx"
-    argument_ = "{},{},{}".format(context_args, input_args, attr_args)
+    argument_list = [context_args] + input_args.split(",") + attr_args.split(
+        ",")
+    while ("" in argument_list):
+        argument_list.remove("")
+    argument_ = ",".join(argument_list)
     return (("let arguments = (ins {});".format(argument_.strip(","))))
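Assembling the argument list from parts and deleting empty segments fixes the empty slot the old single format string left behind when a kernel had no inputs or no attributes (an interior ",," that .strip(",") could not remove). A quick sketch of the same logic with illustrative values:

```python
context_args = "Context:$dev_ctx"
input_args = ""                # e.g. a kernel with no tensor inputs
attr_args = "F32Attr:$scale"

# Old: "{},{},{}".format(...) -> "Context:$dev_ctx,,F32Attr:$scale",
# leaving an interior empty segment.
argument_list = [context_args] + input_args.split(",") + attr_args.split(",")
while "" in argument_list:
    argument_list.remove("")

assert ",".join(argument_list) == "Context:$dev_ctx,F32Attr:$scale"
```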
@@ -116,6 +121,10 @@ def generate_results_info(output_info):
 def generate_supported_kernel_list(load_dict):
     supported_kernels_list_ = []
+    kernel_attrs_names = {}
+    with open(kernel_signature_info_file) as f:
+        kernel_attrs_names = json.load(f)
+    kernel_attrs_names.update(get_compat_kernels_info())
     for op_name in load_dict:
         kernel_list = load_dict[op_name]
         for kernel_info in kernel_list:
@@ -125,13 +134,10 @@ def generate_supported_kernel_list(load_dict):
             for attribute in attributes:
                 if attribute not in attr_type_converter:
                     flag = False
-            if flag:
+            if flag and op_name in kernel_attrs_names:
                 supported_kernels_list_.append(op_name)
-                alias_ = generate_kernel_dialect(op_name, kernel_alias_,
-                                                 kernel_info[kernel_alias_])
     supported_kernels_list_ = list(set(supported_kernels_list_))
-    print(supported_kernels_list_)
+    return supported_kernels_list_

 def scan_kernel_info(load_dict):
@@ -156,16 +162,14 @@ def scan_kernel_info(load_dict):
 def generate_cpu_kernel_dialect(op_name, kernel_alias_, kernel_info):
-    alias = generate_kernel_name(op_name, kernel_alias_)
+    alias, class_name = generate_kernel_name(op_name, kernel_alias_)
     summary = 'let summary = "{name}";'.format(name=alias)
     dialect_name = alias.split(".")
     dialect_name = dialect_name[0] + "." + dialect_name[2] + "." + dialect_name[
         3]
     header = 'def {kernel_name} : PDTCPU_Kernel<"{name}",[NoSideEffect]> {left_brace}'.format(
-        kernel_name=alias.replace(".", ""),
-        name=dialect_name.lower(),
-        left_brace="{")
+        kernel_name=class_name, name=dialect_name.lower(), left_brace="{")
     inputs_ = kernel_info["input"]
     attributes = kernel_info["attribute"]
@@ -185,16 +189,14 @@ def generate_cpu_kernel_dialect(op_name, kernel_alias_, kernel_info):
 def generate_gpu_kernel_dialect(op_name, kernel_alias_, kernel_info):
-    alias = generate_kernel_name(op_name, kernel_alias_)
+    alias, class_name = generate_kernel_name(op_name, kernel_alias_)
     summary = 'let summary = "{name}";'.format(name=alias)
     dialect_name = alias.split(".")
     dialect_name = dialect_name[0] + "." + dialect_name[2] + "." + dialect_name[
         3]
     header = 'def {kernel_name} : PDTGPU_Kernel<"{name}",[NoSideEffect]> {left_brace}'.format(
-        kernel_name=alias.replace(".", ""),
-        name=dialect_name.lower(),
-        left_brace="{")
+        kernel_name=class_name, name=dialect_name.lower(), left_brace="{")
     inputs_ = kernel_info["input"]
     attributes = kernel_info["attribute"]
     arguments = generate_arguments_info(op_name, inputs_, attributes)
@@ -236,14 +238,17 @@ def get_kernel_target(kernel_alias_):
     return target[0]

-def main(path_):
-    with open(path_, "r") as f:
+def main():
+    with open(kernel_types_info_file, "r") as f:
         load_dict = json.load(f)
     head = generate_dialect_head()

     cpu_registry_ = ""
     gpu_registry_ = ""
+    supported_kernels = generate_supported_kernel_list(load_dict)
+    print("Supported kernels:")
+    print(supported_kernels)
     for op_name in load_dict:
         if op_name not in supported_kernels:
             continue
@@ -273,5 +278,12 @@ def main(path_):
 if __name__ == '__main__':
-    path = sys.argv[1]
-    main(path)
+    if not os.path.exists(kernel_types_info_file):
+        print("Error: '{file_name}' not exist!".format(
+            file_name=kernel_types_info_file))
+    if not os.path.exists(kernel_signature_info_file):
+        print("Error: '{file_name}' not exist!".format(
+            file_name=kernel_signature_info_file))
+    if os.path.exists(kernel_types_info_file) and os.path.exists(
+            kernel_signature_info_file):
+        main()
New file: tools/infrt/get_compat_kernel_signature.py (path inferred from the import above):

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import json


def parse_compat_registry(kernel_info):
    name, inputs_str, attrs_str, outputs_str = kernel_info.split(",{")
    kernel_info = {}
    kernel_info["inputs"] = inputs_str[:-1].split(",")
    kernel_info["attrs"] = attrs_str[:-1].split(",")
    kernel_info["outputs"] = outputs_str[:-1].split(",")
    return name, kernel_info


def remove_grad_registry(kernels_registry):
    clean_kernel_registry = {}
    for registry in kernels_registry:
        if (not "_grad" in registry):
            clean_kernel_registry[registry] = kernels_registry[registry]
    return clean_kernel_registry


def get_compat_kernels_info():
    kernels_info = {}
    compat_files = os.listdir("../../paddle/phi/ops/compat")
    for file_ in compat_files:
        if not ".cc" in file_:
            compat_files.remove(file_)

    for file_ in compat_files:
        with open("../../paddle/phi/ops/compat/" + file_) as in_file:
            txt = in_file.readlines()
            content = ""
            registry = False
            for line in txt:
                if ("KernelSignature(" in line):
                    content = ""
                    registry = True
                if (registry):
                    content += line
                if (registry and ";" in line):
                    data = content.replace("\n", "").replace(
                        " ", "").strip("return").strip(
                            "KernelSignature(").strip("\);").replace("\"", "")
                    registry = False
                    name, registry_info = parse_compat_registry(data)

                    if name in kernels_info:
                        cur_reg = kernels_info[name]
                        kernels_info[name]["inputs"] = list(
                            set(registry_info["inputs"] + kernels_info[name][
                                "inputs"]))
                        kernels_info[name]["attrs"] = list(
                            set(registry_info["attrs"] + kernels_info[name][
                                "attrs"]))
                        kernels_info[name]["outputs"] = list(
                            set(registry_info["outputs"] + kernels_info[name][
                                "outputs"]))
                    else:
                        kernels_info[name] = registry_info

    compat_registry_ = remove_grad_registry(kernels_info)
    return compat_registry_
@@ -219,8 +219,8 @@ def gen_register_info(resources: List[List[str]]):
         for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
             kernel_func = gen_kernel_func(update_item[3], ctx_name,
                                           origin_dtype)
-            ir_name = 'phi_cpu.' + update_item[0].lower() + '.' + update_item[
-                2].lower() + '.' + ir_dtype
+            ir_name = 'phi_cpu.' + update_item[0].lower(
+            ) + '.' + ir_dtype + '.' + update_item[2].lower()
             res += f"""
 registry->AddKernel("{ir_name}","""
...