未验证 提交 2b6da4de 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Supported auto code gen for sparse kernels (#40276)

上级 a07f19ee
set(api_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml") set(api_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml,${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml")
set(backward_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/backward.yaml") set(backward_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/backward.yaml,${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml")
set(tmp_forwards_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.cc") set(tmp_forwards_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.cc")
set(tmp_forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.h") set(tmp_forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.h")
set(tmp_nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/tmp_nodes.cc") set(tmp_nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/tmp_nodes.cc")
......
...@@ -23,6 +23,7 @@ core_ops_returns_info = {} ...@@ -23,6 +23,7 @@ core_ops_returns_info = {}
core_ops_args_info = {} core_ops_args_info = {}
core_ops_args_type_info = {} core_ops_args_type_info = {}
namespace = ""
yaml_types_mapping = { yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \ 'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
...@@ -125,6 +126,7 @@ def GetAutoGradMetaVectorName(string): ...@@ -125,6 +126,7 @@ def GetAutoGradMetaVectorName(string):
def ReadFwdFile(filepath): def ReadFwdFile(filepath):
f = open(filepath, 'r') f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader) contents = yaml.load(f, Loader=yaml.FullLoader)
f.close()
return contents return contents
...@@ -133,9 +135,13 @@ def ReadBwdFile(filepath): ...@@ -133,9 +135,13 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader) contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {} ret = {}
for content in contents: for content in contents:
assert 'backward_api' in content.keys() if 'backward_api' in content.keys():
api_name = content['backward_api'] api_name = content['backward_api']
else:
assert False
ret[api_name] = content ret[api_name] = content
f.close()
return ret return ret
...@@ -608,16 +614,23 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -608,16 +614,23 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str += f"return returns;\n" returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name) grad_node_name = GetGradNodeName(fwd_api_name)
if len(namespace) > 0:
grad_api_namespace = f"paddle::experimental::{namespace}"
else:
grad_api_namespace = f"paddle::experimental"
FUNCTION_TEMPLATE = """ FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{
// Call grad_api function // Call grad_api function
auto grad_api_returns = paddle::experimental::{}({}); auto grad_api_returns = {}::{}({});
{} {}
}} }}
""" """
node_definition_str = FUNCTION_TEMPLATE.format( node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, bwd_api_name, grad_api_args_str, returns_str) grad_node_name, grad_api_namespace, bwd_api_name, grad_api_args_str,
returns_str)
return node_definition_str return node_definition_str
...@@ -850,7 +863,11 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -850,7 +863,11 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
function_name = fwd_api_name function_name = fwd_api_name
else: else:
function_name = fwd_api_name + "_intermediate" function_name = fwd_api_name + "_intermediate"
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
if len(namespace) > 0:
forward_call_str = f"auto api_result = paddle::experimental::{namespace}::{function_name}({inputs_call_args_str});"
else:
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs # Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) - len( num_outputs = len(forward_outputs_position_map.keys()) - len(
...@@ -1002,6 +1019,7 @@ def GenerateNodeCCFile(filepath, node_definition_str): ...@@ -1002,6 +1019,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h" #include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/phi/api/include/sparse_api.h"
""" """
file_contents += node_definition_str file_contents += node_definition_str
with open(filepath, 'a') as f: with open(filepath, 'a') as f:
...@@ -1025,6 +1043,7 @@ def GenerateForwardCCFile(filepath, forward_definition_str): ...@@ -1025,6 +1043,7 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
""" """
...@@ -1055,134 +1074,184 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str): ...@@ -1055,134 +1074,184 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
if __name__ == "__main__": if __name__ == "__main__":
args = ParseArguments() args = ParseArguments()
api_yaml_path = args.api_yaml_path api_yaml_paths = args.api_yaml_path.split(",")
backward_yaml_path = args.backward_yaml_path backward_yaml_paths = args.backward_yaml_path.split(",")
fwd_api_list = ReadFwdFile(api_yaml_path)
grad_api_dict = ReadBwdFile(backward_yaml_path)
# Generate per Dygraph API # Generate per Dygraph API
node_declaration_str = "" node_declaration_str = ""
node_definition_str = "" node_definition_str = ""
forward_definition_str = "" forward_definition_str = ""
forward_declaration_str = "" forward_declaration_str = ""
for fwd_api in fwd_api_list:
# We only generate Ops with grad
if 'backward' not in fwd_api.keys():
continue
assert 'api' in fwd_api.keys() for i in range(len(api_yaml_paths)):
assert 'args' in fwd_api.keys() api_yaml_path = api_yaml_paths[i]
assert 'output' in fwd_api.keys() backward_yaml_path = backward_yaml_paths[i]
assert 'backward' in fwd_api.keys()
if "sparse" in api_yaml_path:
no_need_buffer_set = set() assert "sparse" in backward_yaml_path
if 'no_need_buffer' in fwd_api.keys(): namespace = "sparse"
no_need_buffer_set = ParseNoNeedBuffer(fwd_api['no_need_buffer']) else:
namespace = ""
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args'] fwd_api_list = ReadFwdFile(api_yaml_path)
fwd_returns_str = fwd_api['output'] grad_api_dict = ReadBwdFile(backward_yaml_path)
bwd_api_name = fwd_api['backward'] yaml_forward_definition_str = ""
assert bwd_api_name in grad_api_dict.keys() yaml_forward_declaration_str = ""
bwd_api = grad_api_dict[bwd_api_name] yaml_node_declaration_str = ""
yaml_node_definition_str = ""
assert 'args' in bwd_api.keys() for fwd_api in fwd_api_list:
assert 'output' in bwd_api.keys() # We only generate Ops with grad
assert 'forward' in bwd_api.keys() if 'backward' not in fwd_api.keys():
continue
# Parse Dispensable Inputs
optional_inputs = [] assert 'api' in fwd_api.keys()
if 'optional' in fwd_api.keys(): assert 'args' in fwd_api.keys()
optional_inputs = ParseDispensable(fwd_api['optional']) assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args'] no_need_buffer_set = set()
bwd_returns_str = bwd_api['output'] if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api[
# Collect Forward Inputs/Outputs 'no_need_buffer'])
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward(
bwd_forward_str) fwd_api_name = fwd_api['api']
print("Parsed Forward Inputs List: ", forward_inputs_list) fwd_args_str = fwd_api['args']
print("Prased Forward Attrs List: ", forward_attrs_list) fwd_returns_str = fwd_api['output']
print("Parsed Forward Returns List: ", forward_returns_list)
bwd_api_name = fwd_api['backward']
intermediate_outputs = [] assert bwd_api_name in grad_api_dict.keys()
if 'intermediate' in fwd_api.keys(): bwd_api = grad_api_dict[bwd_api_name]
intermediate_outputs = ParseIntermediate(fwd_api['intermediate'])
assert 'args' in bwd_api.keys()
IntermediateValidationCheck(intermediate_outputs, forward_returns_list) assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys()
# Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( # Parse Dispensable Inputs
fwd_args_str, fwd_returns_str) optional_inputs = []
print("Parsed Original Forward Inputs List: ", orig_forward_inputs_list) if 'optional' in fwd_api.keys():
print("Prased Original Forward Attrs List: ", orig_forward_attrs_list) optional_inputs = ParseDispensable(fwd_api['optional'])
print("Parsed Original Forward Returns List: ",
orig_forward_returns_list) bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args']
# Forward Validation Checks bwd_returns_str = bwd_api['output']
ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
forward_returns_list, orig_forward_inputs_list, # Collect Forward Inputs/Outputs
orig_forward_attrs_list, forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward(
orig_forward_returns_list) bwd_forward_str)
print("Parsed Forward Inputs List: ", forward_inputs_list)
# Parse Backward Inputs/Outputs print("Prased Forward Attrs List: ", forward_attrs_list)
backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward( print("Parsed Forward Returns List: ", forward_returns_list)
bwd_args_str, bwd_returns_str)
print("Parsed Backward Inputs List: ", backward_inputs_list) intermediate_outputs = []
print("Prased Backward Attrs List: ", backward_attrs_list) if 'intermediate' in fwd_api.keys():
print("Parsed Backward Returns List: ", backward_returns_list) intermediate_outputs = ParseIntermediate(fwd_api[
'intermediate'])
# Determine Forward Inputs/Outputs Position
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap( IntermediateValidationCheck(intermediate_outputs,
forward_inputs_list, forward_returns_list) forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map) # Collect Original Forward Inputs/Outputs and then perform validation checks
print("Generated Forward Output Position Map: ", orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
forward_outputs_position_map) fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ",
# SlotName Matching orig_forward_inputs_list)
backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching( print("Prased Original Forward Attrs List: ",
backward_inputs_list, backward_returns_list, orig_forward_attrs_list)
forward_inputs_position_map, forward_outputs_position_map) print("Parsed Original Forward Returns List: ",
print("Generated Backward Fwd Input Map: ", backward_fwd_input_map) orig_forward_returns_list)
print("Generated Backward Grad Input Map: ", backward_grad_input_map)
print("Generated Backward Grad Output Map: ", backward_grad_output_map) # Forward Validation Checks
ForwardsValidationCheck(
# Backward Validation Check forward_inputs_list, forward_attrs_list, forward_returns_list,
BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map, orig_forward_inputs_list, orig_forward_attrs_list,
backward_attrs_list) orig_forward_returns_list)
# Node Declaration Generation # Parse Backward Inputs/Outputs
node_declaration_str += GenerateNodeDeclaration( backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward(
fwd_api_name, backward_fwd_input_map, backward_attrs_list, bwd_args_str, bwd_returns_str)
no_need_buffer_set) print("Parsed Backward Inputs List: ", backward_inputs_list)
print("Generated Node Declaration: ", node_declaration_str) print("Prased Backward Attrs List: ", backward_attrs_list)
print("Parsed Backward Returns List: ", backward_returns_list)
node_definition_str += GenerateNodeDefinition(
fwd_api_name, bwd_api_name, backward_fwd_input_map, # Determine Forward Inputs/Outputs Position
backward_grad_input_map, backward_grad_output_map, forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
backward_attrs_list) forward_inputs_list, forward_returns_list)
print("Generated Node Definition: ", node_definition_str) print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
# Node Definition Generation print("Generated Forward Output Position Map: ",
definition_declaration_pair = GenerateForwardDefinition( forward_outputs_position_map)
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, # SlotName Matching
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching(
backward_grad_output_map, backward_attrs_list, optional_inputs, backward_inputs_list, backward_returns_list,
intermediate_outputs) forward_inputs_position_map, forward_outputs_position_map)
print("Generated Forward Definition: ", forward_definition_str) print("Generated Backward Fwd Input Map: ", backward_fwd_input_map)
print("Generated Forward Declaration: ", forward_declaration_str) print("Generated Backward Grad Input Map: ",
forward_definition_str += definition_declaration_pair[0] backward_grad_input_map)
forward_declaration_str += definition_declaration_pair[1] print("Generated Backward Grad Output Map: ",
backward_grad_output_map)
# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, # Backward Validation Check
forward_outputs_position_map, BackwardValidationCheck(backward_fwd_input_map,
forward_attrs_list) backward_grad_input_map,
backward_attrs_list)
# Node Declaration Generation
yaml_node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str)
yaml_node_definition_str += GenerateNodeDefinition(
fwd_api_name, bwd_api_name, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list)
print("Generated Node Definition: ", node_definition_str)
# Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
yaml_forward_definition_str += definition_declaration_pair[0]
yaml_forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map,
forward_attrs_list)
if len(namespace) > 0:
forward_definition_str += f"""namespace {namespace} {{
{yaml_forward_definition_str}
}}
"""
forward_declaration_str += f"""namespace {namespace} {{
{yaml_forward_declaration_str}
}}
"""
node_declaration_str += f"""namespace {namespace} {{
{yaml_node_declaration_str}
}}
"""
node_definition_str += f"""namespace {namespace} {{
{yaml_node_definition_str}
}}
"""
else:
forward_definition_str += yaml_forward_definition_str
forward_declaration_str += yaml_forward_declaration_str
node_declaration_str += yaml_node_declaration_str
node_definition_str += yaml_node_definition_str
# Generate Files # Generate Files
nodes_h_path = args.nodes_h_path nodes_h_path = args.nodes_h_path
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import argparse import argparse
from eager_gen import yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap from eager_gen import namespace, yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
skipped_fwd_api_names = set(["scale"]) skipped_fwd_api_names = set(["scale"])
...@@ -126,16 +126,20 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj ...@@ -126,16 +126,20 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
}} }}
""" """
namespace_str = ""
if len(namespace) > 0:
namespace_str = f"{namespace}::"
if is_forward_only: if is_forward_only:
fwd_function_name = fwd_api_name fwd_function_name = "paddle::experimental::" + namespace_str + fwd_api_name
else: else:
fwd_function_name = GetForwardFunctionName(fwd_api_name) fwd_function_name = namespace_str + GetForwardFunctionName(fwd_api_name)
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str, fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
fwd_function_name, dygraph_function_call_str) fwd_function_name, dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n" python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void)) {namespace_str}eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n"
return python_c_function_str, python_c_function_reg_str return python_c_function_str, python_c_function_reg_str
...@@ -189,7 +193,7 @@ static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) { ...@@ -189,7 +193,7 @@ static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
""" """
core_ops_infos_registry = """ core_ops_infos_registry = """
,{\"get_final_state_core_ops_args_info\", {\"get_final_state_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS, (PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_info.\"}, \"C++ interface function for eager_get_final_state_core_ops_args_info.\"},
{\"get_final_state_core_ops_args_type_info\", {\"get_final_state_core_ops_args_type_info\",
...@@ -222,6 +226,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): ...@@ -222,6 +226,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/pybind/op_function_common.h" #include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
...@@ -254,57 +259,80 @@ def GeneratePythonCFile(filepath, python_c_str): ...@@ -254,57 +259,80 @@ def GeneratePythonCFile(filepath, python_c_str):
if __name__ == "__main__": if __name__ == "__main__":
args = ParseArguments() args = ParseArguments()
api_yaml_path = args.api_yaml_path api_yaml_paths = args.api_yaml_path.split(",")
fwd_api_list = ReadFwdFile(api_yaml_path)
python_c_functions_reg_str = ""
python_c_function_list = [] python_c_functions_str = ""
python_c_function_reg_list = []
for fwd_api in fwd_api_list: for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
# We only generate Ops with grad
is_forward_only = False if "sparse" in api_yaml_path:
if 'backward' not in fwd_api.keys(): namespace = "sparse"
is_forward_only = True else:
namespace = ""
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys() fwd_api_list = ReadFwdFile(api_yaml_path)
assert 'output' in fwd_api.keys()
python_c_function_list = []
fwd_api_name = fwd_api['api'] python_c_function_reg_list = []
fwd_args_str = fwd_api['args'] for fwd_api in fwd_api_list:
fwd_returns_str = fwd_api['output']
# We only generate Ops with grad
if fwd_api_name in skipped_fwd_api_names: is_forward_only = False
continue if 'backward' not in fwd_api.keys():
is_forward_only = True
# Parse Dispensable Inputs
optional_inputs = [] assert 'api' in fwd_api.keys()
if 'optional' in fwd_api.keys(): assert 'args' in fwd_api.keys()
optional_inputs = ParseDispensable(fwd_api['optional']) assert 'output' in fwd_api.keys()
# Collect Original Forward Inputs/Outputs and then perform validation checks fwd_api_name = fwd_api['api']
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward( fwd_args_str = fwd_api['args']
fwd_args_str, fwd_returns_str) fwd_returns_str = fwd_api['output']
print("Parsed Original Forward Inputs List: ", forward_inputs_list)
print("Prased Original Forward Attrs List: ", forward_attrs_list) if fwd_api_name in skipped_fwd_api_names:
print("Parsed Original Forward Returns List: ", forward_returns_list) continue
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap( # Parse Dispensable Inputs
forward_inputs_list, forward_returns_list) optional_inputs = []
print("Generated Forward Input Position Map: ", if 'optional' in fwd_api.keys():
forward_inputs_position_map) optional_inputs = ParseDispensable(fwd_api['optional'])
print("Generated Forward Output Position Map: ",
forward_outputs_position_map) # Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction( fwd_args_str, fwd_returns_str)
fwd_api_name, forward_inputs_position_map, forward_attrs_list, print("Parsed Original Forward Inputs List: ", forward_inputs_list)
forward_outputs_position_map, optional_inputs, is_forward_only) print("Prased Original Forward Attrs List: ", forward_attrs_list)
python_c_function_list.append(python_c_function_str) print("Parsed Original Forward Returns List: ",
python_c_function_reg_list.append(python_c_function_reg_str) forward_returns_list)
print("Generated Python-C Function: ", python_c_function_str)
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
python_c_functions_str = "\n".join(python_c_function_list) forward_inputs_list, forward_returns_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list) print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map, optional_inputs, is_forward_only)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
# Append Namespace
python_c_functions_reg_str += ",\n".join(
python_c_function_reg_list) + ","
python_c_functions = "\n".join(python_c_function_list)
if len(namespace) > 0:
python_c_functions_str += f"""namespace {namespace} {{
{python_c_functions}
}}
"""
else:
python_c_functions_str += python_c_functions
python_c_str = GeneratePythonCWrappers(python_c_functions_str, python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str) python_c_functions_reg_str)
......
- sparse_api : conv3d - api : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups)
output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
kernel : kernel :
func : sparse_conv3d func : sparse_conv3d
layout : x layout : x
- sparse_api : to_dense - api : to_dense
args : (Tensor x, Backend backend) args : (Tensor x, Backend backend)
output : Tensor(out@DenseTensor) output : Tensor(out@DenseTensor)
invoke : to_dense_impl(x, backend) invoke : to_dense_impl(x, backend)
- sparse_api : to_sparse_coo - api : to_sparse_coo
args : (Tensor x, Backend backend, int64 sparse_dim) args : (Tensor x, Backend backend, int64 sparse_dim)
output : Tensor(out@SparseCooTensor) output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, backend, sparse_dim) invoke : to_sparse_coo_impl(x, backend, sparse_dim)
- sparse_api : to_sparse_csr - api : to_sparse_csr
args : (Tensor x, Backend backend) args : (Tensor x, Backend backend)
output : Tensor(out@SparseCsrTensor) output : Tensor(out@SparseCsrTensor)
invoke : to_sparse_csr_impl(x, backend) invoke : to_sparse_csr_impl(x, backend)
...@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI): ...@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI):
def __init__(self, api_item_yaml): def __init__(self, api_item_yaml):
super(SparseAPI, self).__init__(api_item_yaml) super(SparseAPI, self).__init__(api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['sparse_api']
def get_api_func_name(self): def get_api_func_name(self):
return self.api return self.api
......
- sparse_bw_api : conv3d_grad - backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups) args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups)
output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor) output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
......
...@@ -25,9 +25,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -25,9 +25,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def __init__(self, bw_api_item_yaml): def __init__(self, bw_api_item_yaml):
BackwardAPI.__init__(self, bw_api_item_yaml) BackwardAPI.__init__(self, bw_api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['sparse_bw_api']
def get_api_func_name(self): def get_api_func_name(self):
return self.api return self.api
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册