未验证 提交 2b6da4de 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Supported auto code gen for sparse kernels (#40276)

上级 a07f19ee
set(api_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml")
set(backward_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/backward.yaml")
set(api_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml,${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml")
set(backward_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/backward.yaml,${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml")
set(tmp_forwards_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.cc")
set(tmp_forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.h")
set(tmp_nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/tmp_nodes.cc")
......
......@@ -23,6 +23,7 @@ core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}
namespace = ""
yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
......@@ -125,6 +126,7 @@ def GetAutoGradMetaVectorName(string):
def ReadFwdFile(filepath):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
f.close()
return contents
......@@ -133,9 +135,13 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {}
for content in contents:
assert 'backward_api' in content.keys()
api_name = content['backward_api']
if 'backward_api' in content.keys():
api_name = content['backward_api']
else:
assert False
ret[api_name] = content
f.close()
return ret
......@@ -608,16 +614,23 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name)
if len(namespace) > 0:
grad_api_namespace = f"paddle::experimental::{namespace}"
else:
grad_api_namespace = f"paddle::experimental"
FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{
// Call grad_api function
auto grad_api_returns = paddle::experimental::{}({});
auto grad_api_returns = {}::{}({});
{}
}}
"""
node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, bwd_api_name, grad_api_args_str, returns_str)
grad_node_name, grad_api_namespace, bwd_api_name, grad_api_args_str,
returns_str)
return node_definition_str
......@@ -850,7 +863,11 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
function_name = fwd_api_name
else:
function_name = fwd_api_name + "_intermediate"
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
if len(namespace) > 0:
forward_call_str = f"auto api_result = paddle::experimental::{namespace}::{function_name}({inputs_call_args_str});"
else:
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) - len(
......@@ -1002,6 +1019,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/phi/api/include/sparse_api.h"
"""
file_contents += node_definition_str
with open(filepath, 'a') as f:
......@@ -1025,6 +1043,7 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
"""
......@@ -1055,134 +1074,184 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
if __name__ == "__main__":
args = ParseArguments()
api_yaml_path = args.api_yaml_path
backward_yaml_path = args.backward_yaml_path
fwd_api_list = ReadFwdFile(api_yaml_path)
grad_api_dict = ReadBwdFile(backward_yaml_path)
api_yaml_paths = args.api_yaml_path.split(",")
backward_yaml_paths = args.backward_yaml_path.split(",")
# Generate per Dygraph API
node_declaration_str = ""
node_definition_str = ""
forward_definition_str = ""
forward_declaration_str = ""
for fwd_api in fwd_api_list:
# We only generate Ops with grad
if 'backward' not in fwd_api.keys():
continue
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
no_need_buffer_set = set()
if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api['no_need_buffer'])
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
bwd_api_name = fwd_api['backward']
assert bwd_api_name in grad_api_dict.keys()
bwd_api = grad_api_dict[bwd_api_name]
assert 'args' in bwd_api.keys()
assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys()
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args']
bwd_returns_str = bwd_api['output']
# Collect Forward Inputs/Outputs
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward(
bwd_forward_str)
print("Parsed Forward Inputs List: ", forward_inputs_list)
print("Prased Forward Attrs List: ", forward_attrs_list)
print("Parsed Forward Returns List: ", forward_returns_list)
intermediate_outputs = []
if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api['intermediate'])
IntermediateValidationCheck(intermediate_outputs, forward_returns_list)
# Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", orig_forward_inputs_list)
print("Prased Original Forward Attrs List: ", orig_forward_attrs_list)
print("Parsed Original Forward Returns List: ",
orig_forward_returns_list)
# Forward Validation Checks
ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
forward_returns_list, orig_forward_inputs_list,
orig_forward_attrs_list,
orig_forward_returns_list)
# Parse Backward Inputs/Outputs
backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward(
bwd_args_str, bwd_returns_str)
print("Parsed Backward Inputs List: ", backward_inputs_list)
print("Prased Backward Attrs List: ", backward_attrs_list)
print("Parsed Backward Returns List: ", backward_returns_list)
# Determine Forward Inputs/Outputs Position
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
# SlotName Matching
backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching(
backward_inputs_list, backward_returns_list,
forward_inputs_position_map, forward_outputs_position_map)
print("Generated Backward Fwd Input Map: ", backward_fwd_input_map)
print("Generated Backward Grad Input Map: ", backward_grad_input_map)
print("Generated Backward Grad Output Map: ", backward_grad_output_map)
# Backward Validation Check
BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map,
backward_attrs_list)
# Node Declaration Generation
node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str)
node_definition_str += GenerateNodeDefinition(
fwd_api_name, bwd_api_name, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list)
print("Generated Node Definition: ", node_definition_str)
# Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map,
forward_attrs_list)
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
backward_yaml_path = backward_yaml_paths[i]
if "sparse" in api_yaml_path:
assert "sparse" in backward_yaml_path
namespace = "sparse"
else:
namespace = ""
fwd_api_list = ReadFwdFile(api_yaml_path)
grad_api_dict = ReadBwdFile(backward_yaml_path)
yaml_forward_definition_str = ""
yaml_forward_declaration_str = ""
yaml_node_declaration_str = ""
yaml_node_definition_str = ""
for fwd_api in fwd_api_list:
# We only generate Ops with grad
if 'backward' not in fwd_api.keys():
continue
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
no_need_buffer_set = set()
if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api[
'no_need_buffer'])
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
bwd_api_name = fwd_api['backward']
assert bwd_api_name in grad_api_dict.keys()
bwd_api = grad_api_dict[bwd_api_name]
assert 'args' in bwd_api.keys()
assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys()
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args']
bwd_returns_str = bwd_api['output']
# Collect Forward Inputs/Outputs
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward(
bwd_forward_str)
print("Parsed Forward Inputs List: ", forward_inputs_list)
print("Prased Forward Attrs List: ", forward_attrs_list)
print("Parsed Forward Returns List: ", forward_returns_list)
intermediate_outputs = []
if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api[
'intermediate'])
IntermediateValidationCheck(intermediate_outputs,
forward_returns_list)
# Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ",
orig_forward_inputs_list)
print("Prased Original Forward Attrs List: ",
orig_forward_attrs_list)
print("Parsed Original Forward Returns List: ",
orig_forward_returns_list)
# Forward Validation Checks
ForwardsValidationCheck(
forward_inputs_list, forward_attrs_list, forward_returns_list,
orig_forward_inputs_list, orig_forward_attrs_list,
orig_forward_returns_list)
# Parse Backward Inputs/Outputs
backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward(
bwd_args_str, bwd_returns_str)
print("Parsed Backward Inputs List: ", backward_inputs_list)
print("Prased Backward Attrs List: ", backward_attrs_list)
print("Parsed Backward Returns List: ", backward_returns_list)
# Determine Forward Inputs/Outputs Position
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
# SlotName Matching
backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching(
backward_inputs_list, backward_returns_list,
forward_inputs_position_map, forward_outputs_position_map)
print("Generated Backward Fwd Input Map: ", backward_fwd_input_map)
print("Generated Backward Grad Input Map: ",
backward_grad_input_map)
print("Generated Backward Grad Output Map: ",
backward_grad_output_map)
# Backward Validation Check
BackwardValidationCheck(backward_fwd_input_map,
backward_grad_input_map,
backward_attrs_list)
# Node Declaration Generation
yaml_node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str)
yaml_node_definition_str += GenerateNodeDefinition(
fwd_api_name, bwd_api_name, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list)
print("Generated Node Definition: ", node_definition_str)
# Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
yaml_forward_definition_str += definition_declaration_pair[0]
yaml_forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map,
forward_attrs_list)
if len(namespace) > 0:
forward_definition_str += f"""namespace {namespace} {{
{yaml_forward_definition_str}
}}
"""
forward_declaration_str += f"""namespace {namespace} {{
{yaml_forward_declaration_str}
}}
"""
node_declaration_str += f"""namespace {namespace} {{
{yaml_node_declaration_str}
}}
"""
node_definition_str += f"""namespace {namespace} {{
{yaml_node_definition_str}
}}
"""
else:
forward_definition_str += yaml_forward_definition_str
forward_declaration_str += yaml_forward_declaration_str
node_declaration_str += yaml_node_declaration_str
node_definition_str += yaml_node_definition_str
# Generate Files
nodes_h_path = args.nodes_h_path
......
......@@ -14,7 +14,7 @@
import os
import argparse
from eager_gen import yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
from eager_gen import namespace, yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
skipped_fwd_api_names = set(["scale"])
......@@ -126,16 +126,20 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
}}
"""
namespace_str = ""
if len(namespace) > 0:
namespace_str = f"{namespace}::"
if is_forward_only:
fwd_function_name = fwd_api_name
fwd_function_name = "paddle::experimental::" + namespace_str + fwd_api_name
else:
fwd_function_name = GetForwardFunctionName(fwd_api_name)
fwd_function_name = namespace_str + GetForwardFunctionName(fwd_api_name)
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
fwd_function_name, dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n"
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void)) {namespace_str}eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n"
return python_c_function_str, python_c_function_reg_str
......@@ -189,7 +193,7 @@ static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
"""
core_ops_infos_registry = """
,{\"get_final_state_core_ops_args_info\",
{\"get_final_state_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_info.\"},
{\"get_final_state_core_ops_args_type_info\",
......@@ -222,6 +226,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
......@@ -254,57 +259,80 @@ def GeneratePythonCFile(filepath, python_c_str):
if __name__ == "__main__":
args = ParseArguments()
api_yaml_path = args.api_yaml_path
fwd_api_list = ReadFwdFile(api_yaml_path)
python_c_function_list = []
python_c_function_reg_list = []
for fwd_api in fwd_api_list:
# We only generate Ops with grad
is_forward_only = False
if 'backward' not in fwd_api.keys():
is_forward_only = True
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
if fwd_api_name in skipped_fwd_api_names:
continue
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", forward_inputs_list)
print("Prased Original Forward Attrs List: ", forward_attrs_list)
print("Parsed Original Forward Returns List: ", forward_returns_list)
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map, optional_inputs, is_forward_only)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
python_c_functions_str = "\n".join(python_c_function_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)
api_yaml_paths = args.api_yaml_path.split(",")
python_c_functions_reg_str = ""
python_c_functions_str = ""
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
if "sparse" in api_yaml_path:
namespace = "sparse"
else:
namespace = ""
fwd_api_list = ReadFwdFile(api_yaml_path)
python_c_function_list = []
python_c_function_reg_list = []
for fwd_api in fwd_api_list:
# We only generate Ops with grad
is_forward_only = False
if 'backward' not in fwd_api.keys():
is_forward_only = True
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
if fwd_api_name in skipped_fwd_api_names:
continue
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", forward_inputs_list)
print("Prased Original Forward Attrs List: ", forward_attrs_list)
print("Parsed Original Forward Returns List: ",
forward_returns_list)
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map, optional_inputs, is_forward_only)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
# Append Namespace
python_c_functions_reg_str += ",\n".join(
python_c_function_reg_list) + ","
python_c_functions = "\n".join(python_c_function_list)
if len(namespace) > 0:
python_c_functions_str += f"""namespace {namespace} {{
{python_c_functions}
}}
"""
else:
python_c_functions_str += python_c_functions
python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str)
......
- sparse_api : conv3d
- api : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups)
output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
kernel :
func : sparse_conv3d
layout : x
- sparse_api : to_dense
- api : to_dense
args : (Tensor x, Backend backend)
output : Tensor(out@DenseTensor)
invoke : to_dense_impl(x, backend)
- sparse_api : to_sparse_coo
- api : to_sparse_coo
args : (Tensor x, Backend backend, int64 sparse_dim)
output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, backend, sparse_dim)
- sparse_api : to_sparse_csr
- api : to_sparse_csr
args : (Tensor x, Backend backend)
output : Tensor(out@SparseCsrTensor)
invoke : to_sparse_csr_impl(x, backend)
......@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI):
def __init__(self, api_item_yaml):
super(SparseAPI, self).__init__(api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['sparse_api']
def get_api_func_name(self):
return self.api
......
- sparse_bw_api : conv3d_grad
- backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups)
output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
......
......@@ -25,9 +25,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def __init__(self, bw_api_item_yaml):
BackwardAPI.__init__(self, bw_api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['sparse_bw_api']
def get_api_func_name(self):
return self.api
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册