未验证 提交 1035d21f 编写于 作者: H huzhiqiang 提交者: GitHub

refine data loader api in infrt (#39580)

* update generate_pd_op_dialect_from_paddle_op_maker.py

* update mlir tensor load interface

* refine

* fix bug

* fix

* refine

* fix

* 3

* fix

* codestyle
Co-authored-by: 圣颖君's avatarweishengying <1343838695@qq.com>
上级 1c9b2483
...@@ -112,6 +112,7 @@ def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> { ...@@ -112,6 +112,7 @@ def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> {
let verifier = ?; let verifier = ?;
} }
def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> { def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> {
let summary = "dt.tensor_map_get_tensor operation"; let summary = "dt.tensor_map_get_tensor operation";
...@@ -122,7 +123,7 @@ def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> { ...@@ -122,7 +123,7 @@ def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> {
// input path of model params. // input path of model params.
let arguments = (ins let arguments = (ins
TensorMapType:$map, TensorMapType:$map,
StringType:$name StrAttr:$name
); );
let results = (outs TensorType:$output); let results = (outs TensorType:$output);
let assemblyFormat = "`(` operands `)` attr-dict `->` type($output)"; let assemblyFormat = "`(` operands `)` attr-dict `->` type($output)";
......
...@@ -171,7 +171,7 @@ void MLIRModelGenImpl::UpdateModelParams( ...@@ -171,7 +171,7 @@ void MLIRModelGenImpl::UpdateModelParams(
builder_, builder_,
&precision_); &precision_);
mlir::Type type_ = mlir::RankedTensorType::get(dims, precision_); mlir::Type type_ = mlir::RankedTensorType::get(dims, precision_);
auto op = builder_.create<infrt::dt::GetParamOp>( auto op = builder_.create<infrt::dt::TensorMapGetTensorOp>(
mlir::UnknownLoc::get(context_), type_, map, name); mlir::UnknownLoc::get(context_), type_, map, name);
params_map_.insert(std::pair<std::string, mlir::Value>( params_map_.insert(std::pair<std::string, mlir::Value>(
var_desc.name(), op.getOperation()->getResult(0))); var_desc.name(), op.getOperation()->getResult(0)));
...@@ -224,15 +224,14 @@ llvm::SmallVector<mlir::Value, 4> MLIRModelGenImpl::GetOpInputValue( ...@@ -224,15 +224,14 @@ llvm::SmallVector<mlir::Value, 4> MLIRModelGenImpl::GetOpInputValue(
const infrt::paddle::framework_proto::OpDesc &op_) { const infrt::paddle::framework_proto::OpDesc &op_) {
llvm::SmallVector<mlir::Value, 4> operands; llvm::SmallVector<mlir::Value, 4> operands;
std::vector<std::string> inputs_info = {}; std::unordered_map<std::string, uint8_t> inputs_info = {};
if (pd_dialect_inputs_info_map_.count(op_.type())) if (pd_dialect_inputs_info_map_.count(op_.type()))
inputs_info = pd_dialect_inputs_info_map_.at(op_.type()); inputs_info = pd_dialect_inputs_info_map_.at(op_.type());
for (int var_idx = 0; var_idx < op_.inputs_size(); ++var_idx) { for (int var_idx = 0; var_idx < op_.inputs_size(); ++var_idx) {
auto &var = op_.inputs(var_idx); auto &var = op_.inputs(var_idx);
if (!var.arguments().empty()) { if (!var.arguments().empty()) {
if (!std::count(inputs_info.begin(), inputs_info.end(), var.parameter())) if (!inputs_info.count(var.parameter())) continue;
continue;
operands.push_back((params_map_[var.arguments()[0]])); operands.push_back((params_map_[var.arguments()[0]]));
} }
} }
...@@ -243,7 +242,7 @@ llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetOpOutputType( ...@@ -243,7 +242,7 @@ llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetOpOutputType(
const infrt::paddle::framework_proto::OpDesc &op_) { const infrt::paddle::framework_proto::OpDesc &op_) {
llvm::SmallVector<mlir::Type, 4> resultTypes; llvm::SmallVector<mlir::Type, 4> resultTypes;
std::vector<std::string> pd_dialect_outputs_info = {}; std::unordered_map<std::string, uint8_t> pd_dialect_outputs_info = {};
if (pd_dialect_outputs_info_map_.count(op_.type())) if (pd_dialect_outputs_info_map_.count(op_.type()))
pd_dialect_outputs_info = pd_dialect_outputs_info_map_.at(op_.type()); pd_dialect_outputs_info = pd_dialect_outputs_info_map_.at(op_.type());
...@@ -251,9 +250,7 @@ llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetOpOutputType( ...@@ -251,9 +250,7 @@ llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetOpOutputType(
for (int var_idx = 0; var_idx < op_.outputs_size(); ++var_idx) { for (int var_idx = 0; var_idx < op_.outputs_size(); ++var_idx) {
auto &var_name = op_.outputs(var_idx).arguments()[0]; auto &var_name = op_.outputs(var_idx).arguments()[0];
if (!std::count(pd_dialect_outputs_info.begin(), if (!pd_dialect_outputs_info.count(op_.outputs(var_idx).parameter()))
pd_dialect_outputs_info.end(),
op_.outputs(var_idx).parameter()))
continue; continue;
// update persistable tensors // update persistable tensors
......
...@@ -54,10 +54,11 @@ TensorMap LoadParams(const std::string &path) { ...@@ -54,10 +54,11 @@ TensorMap LoadParams(const std::string &path) {
} }
void TensorMapGetTensor(TensorMap map, void TensorMapGetTensor(TensorMap map,
const std::string &name, DenseHostTensor *out,
DenseHostTensor *out) { Attribute<std::string> name) {
auto it = map.find(name); auto it = map.find(name.get());
CHECK(it != map.end()) << "No tensor called " << name << " in the TensorMap"; CHECK(it != map.end()) << "No tensor called " << name.get()
<< " in the TensorMap";
*out = *it->second; *out = *it->second;
} }
......
...@@ -6,8 +6,7 @@ func @load_tensor_map() { ...@@ -6,8 +6,7 @@ func @load_tensor_map() {
%size = dt.tensor_map_get_size(%map) -> i32 %size = dt.tensor_map_get_size(%map) -> i32
infrt.print.i32 %size infrt.print.i32 %size
%tensor_name = infrt.get_string("fc_bias") %a = dt.tensor_map_get_tensor(%map) {name="fc_bias"} -> !infrt.tensor<X86, NCHW, F32>
%a = dt.tensor_map_get_tensor(%map, %tensor_name) -> !infrt.tensor<X86, NCHW, F32>
// CHECK: tensor: shape=shape[2], values=[0, 0] // CHECK: tensor: shape=shape[2], values=[0, 0]
dt.print_tensor (%a : !infrt.tensor<X86, NCHW, F32>) dt.print_tensor (%a : !infrt.tensor<X86, NCHW, F32>)
......
...@@ -90,7 +90,7 @@ function infrt_gen_and_build() { ...@@ -90,7 +90,7 @@ function infrt_gen_and_build() {
exit 7; exit 7;
fi fi
make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec infrt_lib_dist;build_error=$? make -j ${parallel_number} infrt infrtopt infrtexec test_infrt_exec trt-exec infrt_lib_dist paddle-mlir-convert;build_error=$?
if [ "$build_error" != 0 ];then if [ "$build_error" != 0 ];then
exit 7; exit 7;
fi fi
......
...@@ -110,10 +110,92 @@ def get_all_ops_desc(): ...@@ -110,10 +110,92 @@ def get_all_ops_desc():
return all_op_protos_dict return all_op_protos_dict
def generate_all_ops_inputs_outputs_map(op_descs):
# 1. Collect input and output name information of each Op
original_ops_ = get_original_ops()
ops_inputs_map = {}
ops_outputs_map = {}
for op_type, op_proto in op_descs.items():
if op_type not in original_ops_:
continue
inputs = list()
outpus = list()
for input_ in op_proto[INPUTS]:
if op_proto[INPUTS][input_][EXTRA] != True and op_proto[INPUTS][
input_][INTERMEDIATE] != True:
inputs.append(input_)
for output_ in op_proto[OUTPUTS]:
if op_proto[OUTPUTS][output_][EXTRA] != True and op_proto[OUTPUTS][
output_][INTERMEDIATE] != True:
outpus.append(output_)
ops_inputs_map[op_type] = inputs
ops_outputs_map[op_type] = outpus
# 2. Generate Cpp style map str
cpp_style_ops_inputs_map_str = ""
start_ = "#include <unordered_map>\n#include <vector>\n#include <string>\n" + \
"const std::unordered_map<std::string, std::unordered_map<std::string, uint8_t>> pd_dialect_inputs_info_map_ = {\n"
ops_inputs_str = ""
for ele in ops_inputs_map.items():
op_name = ele[0]
op_inputs = ele[1]
op_inputs_str = "{"
input_idx = 0
for op_input in op_inputs:
op_input_str = '{left_brace}"{op_input}", {input_idx}{right_brace}, '.format(
left_brace="{",
op_input=op_input,
input_idx=input_idx,
right_brace="}")
input_idx = input_idx + 1
op_inputs_str = op_inputs_str + op_input_str
op_inputs_str = op_inputs_str[:-2] + "}"
pair = '{left_brace}"{op_name}", {op_inputs}{right_brace},\n'.format(
left_brace="{",
op_name=op_name,
op_inputs=op_inputs_str,
right_brace="}")
ops_inputs_str = ops_inputs_str + " " + pair
ops_inputs_str = ops_inputs_str[:-2]
cpp_style_ops_inputs_map_str = start_ + ops_inputs_str + "\n};"
cpp_style_ops_outputs_map_str = ""
start_ = "const std::unordered_map<std::string, std::unordered_map<std::string, uint8_t>> pd_dialect_outputs_info_map_ = {\n"
ops_outputs_str = ""
for ele in ops_outputs_map.items():
op_name = ele[0]
op_outputs = ele[1]
op_outputs_str = "{"
output_idx = 0
for op_output in op_outputs:
op_output_str = '{left_brace}"{op_output}", {output_idx}{right_brace}, '.format(
left_brace="{",
op_output=op_output,
output_idx=output_idx,
right_brace="}")
output_idx = output_idx + 1
op_outputs_str = op_outputs_str + op_output_str
op_outputs_str = op_outputs_str[:-2] + "}"
pair = '{left_brace}"{op_name}", {op_outputs}{right_brace},\n'.format(
left_brace="{",
op_name=op_name,
op_outputs=op_outputs_str,
right_brace="}")
ops_outputs_str = ops_outputs_str + " " + pair
ops_outputs_str = ops_outputs_str[:-2]
cpp_style_ops_outputs_map_str = start_ + ops_outputs_str + "\n};"
# 3. Write to header file
dst_head_file = "../../paddle/infrt/dialect/pd_ops_info.h"
with open(dst_head_file, 'w') as ops_inputs_outputs_head_file:
ops_inputs_outputs_head_file.write(cpp_style_ops_inputs_map_str)
ops_inputs_outputs_head_file.write("\n\n")
ops_inputs_outputs_head_file.write(cpp_style_ops_outputs_map_str)
# funtion to generate paddle op dialect file # funtion to generate paddle op dialect file
def convert_op_proto_into_mlir(op_descs): def convert_op_proto_into_mlir(op_descs):
dst_dialect_file = "../../paddle/infrt/dialect/pd_ops.td" dst_dialect_file = "../../paddle/infrt/dialect/pd_ops.td"
dialect_info_file = "../../paddle/infrt/dialect/pd_ops_info.h"
custom_dialect_file = "custom_pdop.td" custom_dialect_file = "custom_pdop.td"
# 1. Head files # 1. Head files
...@@ -153,41 +235,38 @@ def convert_op_proto_into_mlir(op_descs): ...@@ -153,41 +235,38 @@ def convert_op_proto_into_mlir(op_descs):
original_ops_ = get_original_ops() original_ops_ = get_original_ops()
automatically_generated_op_dialect = [] automatically_generated_op_dialect = []
ops_inputs_map_ = {}
ops_outputs_map_ = {}
for op_type, op_proto in op_descs.items(): for op_type, op_proto in op_descs.items():
if (op_type in skipped_op_list) or (op_type not in original_ops_): if (op_type in skipped_op_list) or (op_type not in original_ops_):
continue continue
automatically_generated_op_dialect.append(op_type) automatically_generated_op_dialect.append(op_type)
# 2.1 OpDef # 2.1 OpDef
HEAD = "def PD_" + op_type.capitalize( HEAD = 'def PD_{op_type_capitalize}Op : PD_Op<"{op_type}", [NoSideEffect]> {left_brace}\n'.format(
) + "Op : PD_Op<\"" + op_type + "\", [NoSideEffect]> {\n" op_type_capitalize=op_type.capitalize(),
SUMMARY = " let summary = \"" + op_type + " op\";\n" op_type=op_type,
left_brace="{")
SUMMARY = ' let summary = "{} op";\n'.format(op_type)
CANONICALIZATION = "let hasCanonicalizer = 1;" if op_type in ops_having_canonicalization else "" CANONICALIZATION = "let hasCanonicalizer = 1;" if op_type in ops_having_canonicalization else ""
# 2.2 Description # 2.2 Description
DESCRIPTION = " let description = [{\n" contents = ""
contents = (op_proto[COMMENT]).split("\n") origin_contents = (op_proto[COMMENT]).split("\n")
for line_ in contents: for line_ in origin_contents:
DESCRIPTION = DESCRIPTION + " " + line_ + "\n" contents = contents + " {}\n".format(line_)
DESCRIPTION += " }];\n" DESCRIPTION = " let description = [{left_brace}\n{description} {right_brace}];\n".format(
left_brace="{", description=contents, right_brace="}")
# 2.3 arguments info # 2.3 arguments info
ARGUMENTS = "" ARGUMENTS = ""
if (len(op_proto[INPUTS]) > 0 or len(op_proto[ATTRS]) > 0): if (len(op_proto[INPUTS]) > 0 or len(op_proto[ATTRS]) > 0):
ARGUMENTS = " let arguments = (ins " ARGUMENTS = " let arguments = (ins "
# 2.3.1 inputs # 2.3.1 inputs
ins_cache_list_ = []
for input_ in op_proto[INPUTS]: for input_ in op_proto[INPUTS]:
if op_proto[INPUTS][input_][EXTRA] != True and op_proto[INPUTS][ if op_proto[INPUTS][input_][EXTRA] != True and op_proto[INPUTS][
input_][INTERMEDIATE] != True: input_][INTERMEDIATE] != True:
ins_cache_list_.append(input_)
if op_proto[INPUTS][input_][DUPLICABLE] != "true": if op_proto[INPUTS][input_][DUPLICABLE] != "true":
ARGUMENTS = ARGUMENTS + " PD_Tensor:$" + input_ + "," ARGUMENTS = ARGUMENTS + " PD_Tensor:$" + input_ + ","
else: else:
ARGUMENTS = ARGUMENTS + " PD_Tensor_Array:$" + input_ + "," ARGUMENTS = ARGUMENTS + " PD_Tensor_Array:$" + input_ + ","
ops_inputs_map_[op_type] = ins_cache_list_
# unsupported: BLOCK = 8; BLOCKS = 10; # unsupported: BLOCK = 8; BLOCKS = 10;
attr_mlir_converter = { attr_mlir_converter = {
0: 'SI32Attr', 0: 'SI32Attr',
...@@ -252,19 +331,17 @@ def convert_op_proto_into_mlir(op_descs): ...@@ -252,19 +331,17 @@ def convert_op_proto_into_mlir(op_descs):
# 2.4 results info # 2.4 results info
RESULTS = "" RESULTS = ""
if (len(op_proto[OUTPUTS]) > 0): if (len(op_proto[OUTPUTS]) > 0):
RESULTS = "\n let results = (outs " outputs = ""
outs_cache_list_ = []
for output_ in op_proto[OUTPUTS]: for output_ in op_proto[OUTPUTS]:
if op_proto[OUTPUTS][output_][EXTRA] != True and op_proto[ if op_proto[OUTPUTS][output_][EXTRA] != True and op_proto[
OUTPUTS][output_][INTERMEDIATE] != True: OUTPUTS][output_][INTERMEDIATE] != True:
outs_cache_list_.append(output_)
if op_proto[OUTPUTS][output_][DUPLICABLE] != "true": if op_proto[OUTPUTS][output_][DUPLICABLE] != "true":
RESULTS = RESULTS + "PD_Tensor:$" + output_ + "," outputs = outputs + "PD_Tensor:${},".format(output_)
else: else:
RESULTS = RESULTS + "PD_Tensor_Array:$" + output_ + "," outputs = outputs + "PD_Tensor_Array:${},".format(
print(HEAD + " PD_Tensor_Array:$" + output_ + ",") output_)
ops_outputs_map_[op_type] = outs_cache_list_ RESULTS = "\n let results = (outs {});\n".format(outputs[:-1])
RESULTS = RESULTS[:-1] + ");\n"
with open(dst_dialect_file, 'a') as ops_mlir_file: with open(dst_dialect_file, 'a') as ops_mlir_file:
ops_mlir_file.write(HEAD) ops_mlir_file.write(HEAD)
ops_mlir_file.write(SUMMARY) ops_mlir_file.write(SUMMARY)
...@@ -278,29 +355,6 @@ def convert_op_proto_into_mlir(op_descs): ...@@ -278,29 +355,6 @@ def convert_op_proto_into_mlir(op_descs):
print("Automatically generated op dialects num: " + str( print("Automatically generated op dialects num: " + str(
len(automatically_generated_op_dialect))) len(automatically_generated_op_dialect)))
with open(dialect_info_file, 'w') as pd_ops_info_file:
pd_ops_info_file.write(
"#include<map>\n#include<string>\n#include<vector>\n")
pd_ops_info_file.write(
"const std::map<std::string, std::vector<std::string>> pd_dialect_inputs_info_map_ = {\n"
)
for data_ in ops_inputs_map_:
pd_ops_info_file.write(" {\"" + data_ + "\", {")
for var_ in ops_inputs_map_[data_]:
pd_ops_info_file.write("\"" + var_ + "\",")
pd_ops_info_file.write("}},\n")
pd_ops_info_file.write("};\n")
pd_ops_info_file.write(
"const std::map<std::string, std::vector<std::string>> pd_dialect_outputs_info_map_ = {\n"
)
for data_ in ops_outputs_map_:
pd_ops_info_file.write(" {\"" + data_ + "\", {")
for var_ in ops_outputs_map_[data_]:
pd_ops_info_file.write("\"" + var_ + "\",")
pd_ops_info_file.write("}},\n")
pd_ops_info_file.write("};\n")
# 3. custom op dialect and end of file # 3. custom op dialect and end of file
with open(dst_dialect_file, 'a') as ops_mlir_file: with open(dst_dialect_file, 'a') as ops_mlir_file:
with open(custom_dialect_file, 'r') as custom_ops_file: with open(custom_dialect_file, 'r') as custom_ops_file:
...@@ -313,4 +367,5 @@ def convert_op_proto_into_mlir(op_descs): ...@@ -313,4 +367,5 @@ def convert_op_proto_into_mlir(op_descs):
if __name__ == "__main__": if __name__ == "__main__":
all_op_protos_dict = get_all_ops_desc() all_op_protos_dict = get_all_ops_desc()
generate_all_ops_inputs_outputs_map(all_op_protos_dict)
convert_op_proto_into_mlir(all_op_protos_dict) convert_op_proto_into_mlir(all_op_protos_dict)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册