未验证 提交 2b6da4de 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Supported auto code gen for sparse kernels (#40276)

上级 a07f19ee
set(api_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml") set(api_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml,${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml")
set(backward_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/backward.yaml") set(backward_yaml_path "${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/backward.yaml,${PADDLE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml")
set(tmp_forwards_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.cc") set(tmp_forwards_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.cc")
set(tmp_forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.h") set(tmp_forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.h")
set(tmp_nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/tmp_nodes.cc") set(tmp_nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/tmp_nodes.cc")
......
...@@ -23,6 +23,7 @@ core_ops_returns_info = {} ...@@ -23,6 +23,7 @@ core_ops_returns_info = {}
core_ops_args_info = {} core_ops_args_info = {}
core_ops_args_type_info = {} core_ops_args_type_info = {}
namespace = ""
yaml_types_mapping = { yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \ 'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
...@@ -125,6 +126,7 @@ def GetAutoGradMetaVectorName(string): ...@@ -125,6 +126,7 @@ def GetAutoGradMetaVectorName(string):
def ReadFwdFile(filepath): def ReadFwdFile(filepath):
f = open(filepath, 'r') f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader) contents = yaml.load(f, Loader=yaml.FullLoader)
f.close()
return contents return contents
...@@ -133,9 +135,13 @@ def ReadBwdFile(filepath): ...@@ -133,9 +135,13 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader) contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {} ret = {}
for content in contents: for content in contents:
assert 'backward_api' in content.keys() if 'backward_api' in content.keys():
api_name = content['backward_api'] api_name = content['backward_api']
else:
assert False
ret[api_name] = content ret[api_name] = content
f.close()
return ret return ret
...@@ -608,16 +614,23 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -608,16 +614,23 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str += f"return returns;\n" returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name) grad_node_name = GetGradNodeName(fwd_api_name)
if len(namespace) > 0:
grad_api_namespace = f"paddle::experimental::{namespace}"
else:
grad_api_namespace = f"paddle::experimental"
FUNCTION_TEMPLATE = """ FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{
// Call grad_api function // Call grad_api function
auto grad_api_returns = paddle::experimental::{}({}); auto grad_api_returns = {}::{}({});
{} {}
}} }}
""" """
node_definition_str = FUNCTION_TEMPLATE.format( node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, bwd_api_name, grad_api_args_str, returns_str) grad_node_name, grad_api_namespace, bwd_api_name, grad_api_args_str,
returns_str)
return node_definition_str return node_definition_str
...@@ -850,6 +863,10 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -850,6 +863,10 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
function_name = fwd_api_name function_name = fwd_api_name
else: else:
function_name = fwd_api_name + "_intermediate" function_name = fwd_api_name + "_intermediate"
if len(namespace) > 0:
forward_call_str = f"auto api_result = paddle::experimental::{namespace}::{function_name}({inputs_call_args_str});"
else:
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});" forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs # Get return type list & outputs
...@@ -1002,6 +1019,7 @@ def GenerateNodeCCFile(filepath, node_definition_str): ...@@ -1002,6 +1019,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h" #include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/phi/api/include/sparse_api.h"
""" """
file_contents += node_definition_str file_contents += node_definition_str
with open(filepath, 'a') as f: with open(filepath, 'a') as f:
...@@ -1025,6 +1043,7 @@ def GenerateForwardCCFile(filepath, forward_definition_str): ...@@ -1025,6 +1043,7 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
""" """
...@@ -1055,17 +1074,32 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str): ...@@ -1055,17 +1074,32 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
if __name__ == "__main__": if __name__ == "__main__":
args = ParseArguments() args = ParseArguments()
api_yaml_path = args.api_yaml_path api_yaml_paths = args.api_yaml_path.split(",")
backward_yaml_path = args.backward_yaml_path backward_yaml_paths = args.backward_yaml_path.split(",")
fwd_api_list = ReadFwdFile(api_yaml_path)
grad_api_dict = ReadBwdFile(backward_yaml_path)
# Generate per Dygraph API # Generate per Dygraph API
node_declaration_str = "" node_declaration_str = ""
node_definition_str = "" node_definition_str = ""
forward_definition_str = "" forward_definition_str = ""
forward_declaration_str = "" forward_declaration_str = ""
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
backward_yaml_path = backward_yaml_paths[i]
if "sparse" in api_yaml_path:
assert "sparse" in backward_yaml_path
namespace = "sparse"
else:
namespace = ""
fwd_api_list = ReadFwdFile(api_yaml_path)
grad_api_dict = ReadBwdFile(backward_yaml_path)
yaml_forward_definition_str = ""
yaml_forward_declaration_str = ""
yaml_node_declaration_str = ""
yaml_node_definition_str = ""
for fwd_api in fwd_api_list: for fwd_api in fwd_api_list:
# We only generate Ops with grad # We only generate Ops with grad
if 'backward' not in fwd_api.keys(): if 'backward' not in fwd_api.keys():
...@@ -1078,7 +1112,8 @@ if __name__ == "__main__": ...@@ -1078,7 +1112,8 @@ if __name__ == "__main__":
no_need_buffer_set = set() no_need_buffer_set = set()
if 'no_need_buffer' in fwd_api.keys(): if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api['no_need_buffer']) no_need_buffer_set = ParseNoNeedBuffer(fwd_api[
'no_need_buffer'])
fwd_api_name = fwd_api['api'] fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args'] fwd_args_str = fwd_api['args']
...@@ -1110,22 +1145,26 @@ if __name__ == "__main__": ...@@ -1110,22 +1145,26 @@ if __name__ == "__main__":
intermediate_outputs = [] intermediate_outputs = []
if 'intermediate' in fwd_api.keys(): if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api['intermediate']) intermediate_outputs = ParseIntermediate(fwd_api[
'intermediate'])
IntermediateValidationCheck(intermediate_outputs, forward_returns_list) IntermediateValidationCheck(intermediate_outputs,
forward_returns_list)
# Collect Original Forward Inputs/Outputs and then perform validation checks # Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", orig_forward_inputs_list) print("Parsed Original Forward Inputs List: ",
print("Prased Original Forward Attrs List: ", orig_forward_attrs_list) orig_forward_inputs_list)
print("Prased Original Forward Attrs List: ",
orig_forward_attrs_list)
print("Parsed Original Forward Returns List: ", print("Parsed Original Forward Returns List: ",
orig_forward_returns_list) orig_forward_returns_list)
# Forward Validation Checks # Forward Validation Checks
ForwardsValidationCheck(forward_inputs_list, forward_attrs_list, ForwardsValidationCheck(
forward_returns_list, orig_forward_inputs_list, forward_inputs_list, forward_attrs_list, forward_returns_list,
orig_forward_attrs_list, orig_forward_inputs_list, orig_forward_attrs_list,
orig_forward_returns_list) orig_forward_returns_list)
# Parse Backward Inputs/Outputs # Parse Backward Inputs/Outputs
...@@ -1148,20 +1187,23 @@ if __name__ == "__main__": ...@@ -1148,20 +1187,23 @@ if __name__ == "__main__":
backward_inputs_list, backward_returns_list, backward_inputs_list, backward_returns_list,
forward_inputs_position_map, forward_outputs_position_map) forward_inputs_position_map, forward_outputs_position_map)
print("Generated Backward Fwd Input Map: ", backward_fwd_input_map) print("Generated Backward Fwd Input Map: ", backward_fwd_input_map)
print("Generated Backward Grad Input Map: ", backward_grad_input_map) print("Generated Backward Grad Input Map: ",
print("Generated Backward Grad Output Map: ", backward_grad_output_map) backward_grad_input_map)
print("Generated Backward Grad Output Map: ",
backward_grad_output_map)
# Backward Validation Check # Backward Validation Check
BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map, BackwardValidationCheck(backward_fwd_input_map,
backward_grad_input_map,
backward_attrs_list) backward_attrs_list)
# Node Declaration Generation # Node Declaration Generation
node_declaration_str += GenerateNodeDeclaration( yaml_node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list, fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set) no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str) print("Generated Node Declaration: ", node_declaration_str)
node_definition_str += GenerateNodeDefinition( yaml_node_definition_str += GenerateNodeDefinition(
fwd_api_name, bwd_api_name, backward_fwd_input_map, fwd_api_name, bwd_api_name, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map, backward_grad_input_map, backward_grad_output_map,
backward_attrs_list) backward_attrs_list)
...@@ -1176,14 +1218,41 @@ if __name__ == "__main__": ...@@ -1176,14 +1218,41 @@ if __name__ == "__main__":
intermediate_outputs) intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str) print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str) print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0] yaml_forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1] yaml_forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch # For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_outputs_position_map,
forward_attrs_list) forward_attrs_list)
if len(namespace) > 0:
forward_definition_str += f"""namespace {namespace} {{
{yaml_forward_definition_str}
}}
"""
forward_declaration_str += f"""namespace {namespace} {{
{yaml_forward_declaration_str}
}}
"""
node_declaration_str += f"""namespace {namespace} {{
{yaml_node_declaration_str}
}}
"""
node_definition_str += f"""namespace {namespace} {{
{yaml_node_definition_str}
}}
"""
else:
forward_definition_str += yaml_forward_definition_str
forward_declaration_str += yaml_forward_declaration_str
node_declaration_str += yaml_node_declaration_str
node_definition_str += yaml_node_definition_str
# Generate Files # Generate Files
nodes_h_path = args.nodes_h_path nodes_h_path = args.nodes_h_path
nodes_cc_path = args.nodes_cc_path nodes_cc_path = args.nodes_cc_path
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import argparse import argparse
from eager_gen import yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap from eager_gen import namespace, yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
skipped_fwd_api_names = set(["scale"]) skipped_fwd_api_names = set(["scale"])
...@@ -126,16 +126,20 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj ...@@ -126,16 +126,20 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
}} }}
""" """
namespace_str = ""
if len(namespace) > 0:
namespace_str = f"{namespace}::"
if is_forward_only: if is_forward_only:
fwd_function_name = fwd_api_name fwd_function_name = "paddle::experimental::" + namespace_str + fwd_api_name
else: else:
fwd_function_name = GetForwardFunctionName(fwd_api_name) fwd_function_name = namespace_str + GetForwardFunctionName(fwd_api_name)
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str, fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
fwd_function_name, dygraph_function_call_str) fwd_function_name, dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n" python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void)) {namespace_str}eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n"
return python_c_function_str, python_c_function_reg_str return python_c_function_str, python_c_function_reg_str
...@@ -189,7 +193,7 @@ static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) { ...@@ -189,7 +193,7 @@ static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
""" """
core_ops_infos_registry = """ core_ops_infos_registry = """
,{\"get_final_state_core_ops_args_info\", {\"get_final_state_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS, (PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_info.\"}, \"C++ interface function for eager_get_final_state_core_ops_args_info.\"},
{\"get_final_state_core_ops_args_type_info\", {\"get_final_state_core_ops_args_type_info\",
...@@ -222,6 +226,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): ...@@ -222,6 +226,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/pybind/op_function_common.h" #include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
...@@ -254,7 +259,19 @@ def GeneratePythonCFile(filepath, python_c_str): ...@@ -254,7 +259,19 @@ def GeneratePythonCFile(filepath, python_c_str):
if __name__ == "__main__": if __name__ == "__main__":
args = ParseArguments() args = ParseArguments()
api_yaml_path = args.api_yaml_path api_yaml_paths = args.api_yaml_path.split(",")
python_c_functions_reg_str = ""
python_c_functions_str = ""
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
if "sparse" in api_yaml_path:
namespace = "sparse"
else:
namespace = ""
fwd_api_list = ReadFwdFile(api_yaml_path) fwd_api_list = ReadFwdFile(api_yaml_path)
python_c_function_list = [] python_c_function_list = []
...@@ -287,7 +304,8 @@ if __name__ == "__main__": ...@@ -287,7 +304,8 @@ if __name__ == "__main__":
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", forward_inputs_list) print("Parsed Original Forward Inputs List: ", forward_inputs_list)
print("Prased Original Forward Attrs List: ", forward_attrs_list) print("Prased Original Forward Attrs List: ", forward_attrs_list)
print("Parsed Original Forward Returns List: ", forward_returns_list) print("Parsed Original Forward Returns List: ",
forward_returns_list)
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap( forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list) forward_inputs_list, forward_returns_list)
...@@ -303,8 +321,18 @@ if __name__ == "__main__": ...@@ -303,8 +321,18 @@ if __name__ == "__main__":
python_c_function_reg_list.append(python_c_function_reg_str) python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str) print("Generated Python-C Function: ", python_c_function_str)
python_c_functions_str = "\n".join(python_c_function_list) # Append Namespace
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list) python_c_functions_reg_str += ",\n".join(
python_c_function_reg_list) + ","
python_c_functions = "\n".join(python_c_function_list)
if len(namespace) > 0:
python_c_functions_str += f"""namespace {namespace} {{
{python_c_functions}
}}
"""
else:
python_c_functions_str += python_c_functions
python_c_str = GeneratePythonCWrappers(python_c_functions_str, python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str) python_c_functions_reg_str)
......
- sparse_api : conv3d - api : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups)
output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
kernel : kernel :
func : sparse_conv3d func : sparse_conv3d
layout : x layout : x
- sparse_api : to_dense - api : to_dense
args : (Tensor x, Backend backend) args : (Tensor x, Backend backend)
output : Tensor(out@DenseTensor) output : Tensor(out@DenseTensor)
invoke : to_dense_impl(x, backend) invoke : to_dense_impl(x, backend)
- sparse_api : to_sparse_coo - api : to_sparse_coo
args : (Tensor x, Backend backend, int64 sparse_dim) args : (Tensor x, Backend backend, int64 sparse_dim)
output : Tensor(out@SparseCooTensor) output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, backend, sparse_dim) invoke : to_sparse_coo_impl(x, backend, sparse_dim)
- sparse_api : to_sparse_csr - api : to_sparse_csr
args : (Tensor x, Backend backend) args : (Tensor x, Backend backend)
output : Tensor(out@SparseCsrTensor) output : Tensor(out@SparseCsrTensor)
invoke : to_sparse_csr_impl(x, backend) invoke : to_sparse_csr_impl(x, backend)
...@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI): ...@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI):
def __init__(self, api_item_yaml): def __init__(self, api_item_yaml):
super(SparseAPI, self).__init__(api_item_yaml) super(SparseAPI, self).__init__(api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['sparse_api']
def get_api_func_name(self): def get_api_func_name(self):
return self.api return self.api
......
- sparse_bw_api : conv3d_grad - backward_api : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups) args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups)
output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor) output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
......
...@@ -25,9 +25,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -25,9 +25,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def __init__(self, bw_api_item_yaml): def __init__(self, bw_api_item_yaml):
BackwardAPI.__init__(self, bw_api_item_yaml) BackwardAPI.__init__(self, bw_api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['sparse_bw_api']
def get_api_func_name(self): def get_api_func_name(self):
return self.api return self.api
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册