diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index a93facbc34a5ba197eb1b0291cd8212d2dcafc9f..8049a8b8741b1a5a9588026a2600536daf23424d 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -904,10 +904,8 @@ def monkey_patch_varbase(): #[1, 2, 3, 4, 5] """ - if self.is_sparse_coo(): - return _C_ops.final_state_sparse_coo_values(self) - elif self.is_sparse_csr(): - return _C_ops.final_state_sparse_csr_values(self) + if self.is_sparse_coo() or self.is_sparse_csr(): + return _C_ops.final_state_sparse_values(self) else: raise ValueError( "only SparseCooTensor and SparseCsrTensor have method values") diff --git a/python/paddle/sparse/functional/unary.py b/python/paddle/sparse/functional/unary.py index 860b4025d89e0eb085eea3f0f81608b2dd2fd0cd..550e6a2a39261938ce24fd722d68238273db9e86 100644 --- a/python/paddle/sparse/functional/unary.py +++ b/python/paddle/sparse/functional/unary.py @@ -47,10 +47,8 @@ def relu(x, name=None): assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode" - if x.is_sparse_coo(): - return _C_ops.final_state_sparse_coo_relu(x) - elif x.is_sparse_csr(): - return _C_ops.final_state_sparse_csr_relu(x) + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_relu(x) else: raise ValueError( "Currently, sparse.relu only support the input of SparseCooTensor or SparseCsrTensor" @@ -87,10 +85,8 @@ def tanh(x, name=None): assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode" - if x.is_sparse_coo(): - return _C_ops.final_state_sparse_coo_tanh(x) - elif x.is_sparse_csr(): - return _C_ops.final_state_sparse_csr_tanh(x) + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_tanh(x) else: raise ValueError( "Currently, sparse.tanh only support the input of SparseCooTensor or SparseCsrTensor" @@ -127,10 +123,8 @@ def sqrt(x, name=None): assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode" - if x.is_sparse_coo(): - return _C_ops.final_state_sparse_coo_sqrt(x) - elif x.is_sparse_csr(): - return _C_ops.final_state_sparse_csr_sqrt(x) + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_sqrt(x) else: raise ValueError( "Currently, sparse.sqrt only support the input of SparseCooTensor or SparseCsrTensor" @@ -167,10 +161,8 @@ def sin(x, name=None): assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode" - if x.is_sparse_coo(): - return _C_ops.final_state_sparse_coo_sin(x) - elif x.is_sparse_csr(): - return _C_ops.final_state_sparse_csr_sin(x) + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_sin(x) else: raise ValueError( "Currently, sparse.sin only support the input of SparseCooTensor or SparseCsrTensor" diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 6dbd272c68a88c91ff1f13f78735aae4512f3c2f..96896b65f404115a6c4cbd6d227af8be06342a9c 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -45,7 +45,8 @@ class BaseAPI(object): 'infer_meta']) self.kernel = self.parse_kernel(api_item_yaml['kernel']) self.support_selected_rows_kernel = False if len(self.kernel[ - 'func']) == 1 else True + 'func']) == 1 or not self.kernel['func'][1].endswith( + '_sr') else True self.data_transform = self.parse_data_transform(api_item_yaml) self.inplace_map, self.view_map = self.parse_inplace_and_view( api_item_yaml) @@ -248,13 +249,15 @@ class BaseAPI(object): # backend : str, the names of param to choose the kernel backend, default is None # layout : str, the names of param to choose the kernel layout, default is None # data_type : str, the names of param to choose the kernel data_type, default is None + # dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)}) kernel = { 'func': [], 'param': None, 'backend': None, 'layout': None, 'data_type': None, - 'use_gpudnn': 'false' + 'use_gpudnn': 'false', + 'dispatch': {} } if 'backend' in kernel_config and len(kernel_config['backend']) > 0: kernel['backend'] = kernel_config['backend'] @@ -268,17 +271,21 @@ class BaseAPI(object): kernel['use_gpudnn'] = kernel_config['use_gpudnn'] if isinstance(kernel['use_gpudnn'], bool): kernel['use_gpudnn'] = str(kernel['use_gpudnn']).lower() - kernel['func'] = [ - kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',') - ] - - if len(kernel['func']) == 2: - assert kernel['func'][0] == self.api, \ - f"{self.api} : Kernel func error: If kernel has two func config, the name of first func should be same with api name({self.api}), \ - but now is {kernel['func'][0]}." - assert kernel['func'][1].endswith('_sr'), \ - f"{self.api} : Kernel func error: If kernel has two func config, the name of second func should be a selected_rows kernel (the func name endwith '_sr'), \ - but now is {kernel['func'][1]}." + kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall( + kernel_config['func']) + + def parse_kernel_in_out_type(in_out_str): + if len(in_out_str) == 0: + return None + tmp_in_out_list = in_out_str[1:-1].split('->') + inputs = [item.strip() for item in tmp_in_out_list[0].split(',')] + outputs = [item.strip() for item in tmp_in_out_list[1].split(',')] + return (inputs, outputs) + + for func_item in kernel_funcs: + kernel['func'].append(func_item[0]) + kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type( + func_item[1]) return kernel diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index ae3e9e6942233360f41323cf495d5798361f8402..5d1dc55f0638de4166a1e7061d7c41978bcc765c 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -1,128 +1,98 @@ - api : conv3d args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) - output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) + output : Tensor(out), Tensor(rulebook) kernel : - func : sparse_conv3d + func : sparse_conv3d{sparse_coo, dense -> sparse_coo, dense} layout : x intermediate : rulebook backward : conv3d_grad -- api : coo_relu - args : (Tensor x) - output : Tensor(out@SparseCooTensor) - kernel : - func : sparse_coo_relu - layout : x - backward : sparse_coo_relu_grad - -- api : coo_sin - args : (Tensor x) - output : Tensor(out@SparseCooTensor) - kernel : - func : sparse_coo_sin - layout : x - backward : sparse_coo_sin_grad - -- api : coo_sqrt - args : (Tensor x) - output : Tensor(out@SparseCooTensor) - kernel : - func : sparse_coo_sqrt - layout : x - backward : sparse_coo_sqrt_grad - -- api : coo_tanh - args : (Tensor x) - output : Tensor(out@SparseCooTensor) - kernel : - func : sparse_coo_tanh - layout : x - backward : sparse_coo_tanh_grad - - api : coo_to_dense args : (Tensor x) - output : Tensor(out@DenseTensor) + output : Tensor(out) invoke : to_dense_impl(x) backward : coo_to_dense_grad -- api : coo_values - args : (Tensor x) - output : Tensor(out@DenseTensor) - kernel : - func : coo_values - layout : x - backward : coo_values_grad - - api : create_sparse_coo_tensor args : (Tensor values, Tensor indices, IntArray dense_shape) - output : Tensor(out@SparseCooTensor) + output : Tensor(out) kernel : - func : sparse_coo_tensor + func : sparse_coo_tensor{dense, dense -> sparse_coo} layout : values data_type : values backward : create_sparse_coo_tensor_grad -- api : csr_relu - args : (Tensor x) - output : Tensor(out@SparseCsrTensor) - kernel : - func : sparse_csr_relu - layout : x +- api : dense_to_coo + args : (Tensor x, int64_t sparse_dim) + output : Tensor(out) + invoke : to_sparse_coo_impl(x, sparse_dim) + backward : dense_to_coo_grad -- api : csr_sin +- api : relu args : (Tensor x) - output : Tensor(out@SparseCsrTensor) + output : Tensor(out) kernel : - func : sparse_csr_sin + func : sparse_coo_relu{sparse_coo -> sparse_coo}, + sparse_csr_relu{sparse_csr -> sparse_csr} layout : x + backward : relu_grad -- api : csr_sqrt +- api : sin args : (Tensor x) - output : Tensor(out@SparseCsrTensor) + output : Tensor(out@SparseCooTensor) kernel : - func : sparse_csr_sqrt + func : sparse_coo_sin {sparse_coo -> sparse_coo}, + sparse_csr_sin {sparse_csr -> sparse_csr} layout : x + backward : sin_grad -- api : csr_tanh +- api : sqrt args : (Tensor x) - output : Tensor(out@SparseCsrTensor) + output : Tensor(out) kernel : - func : sparse_csr_tanh + func : sparse_coo_sqrt{sparse_coo -> sparse_coo}, + sparse_csr_sqrt{sparse_csr -> sparse_csr} layout : x + backward : sqrt_grad -- api : csr_values +- api : tanh args : (Tensor x) - output : Tensor(out@DenseTensor) + output : Tensor(out) kernel : - func : csr_values + func : sparse_coo_tanh{sparse_coo -> sparse_coo}, + sparse_csr_tanh{sparse_csr -> sparse_csr} layout : x - -- api : dense_to_coo - args : (Tensor x, int64_t sparse_dim) - output : Tensor(out@SparseCooTensor) - invoke : to_sparse_coo_impl(x, sparse_dim) - backward : dense_to_coo_grad + backward : tanh_grad - api : to_dense args : (Tensor x) - output : Tensor(out@DenseTensor) + output : Tensor(out) invoke : to_dense_impl(x) - api : to_sparse_coo args : (Tensor x, int64_t sparse_dim) - output : Tensor(out@SparseCooTensor) + output : Tensor(out) invoke : to_sparse_coo_impl(x, sparse_dim) - api : to_sparse_csr args : (Tensor x) - output : Tensor(out@SparseCsrTensor) + output : Tensor(out) invoke : to_sparse_csr_impl(x) +- api : values + args : (Tensor x) + output : Tensor(out) + kernel : + func : coo_values{sparse_coo -> dense}, + csr_values{sparse_csr -> dense} + layout : x + backward : values_grad + - api: maxpool args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) - output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) + output : Tensor(out), Tensor(rulebook) kernel : - func : sparse_maxpool + func : sparse_maxpool{sparse_coo -> sparse_coo, dense} layout : x intermediate : rulebook backward : sparse_maxpool_grad diff --git a/python/paddle/utils/code_gen/sparse_api_gen.py b/python/paddle/utils/code_gen/sparse_api_gen.py index 509858d339f69675a55722818768b612231ed868..bd73032e179dbba43e8da56301dbf4df5bf005a8 100644 --- a/python/paddle/utils/code_gen/sparse_api_gen.py +++ b/python/paddle/utils/code_gen/sparse_api_gen.py @@ -47,6 +47,11 @@ class SparseAPI(ForwardAPI): output_names = [] output_create = "" return_type = self.get_return_type_with_intermediate(inplace_flag) + output_type_map = { + 'dense': 'TensorType::DENSE_TENSOR', + 'sparse_coo': 'TensorType::SPARSE_COO', + 'sparse_csr': 'TensorType::SPARSE_CSR' + } if len(output_type_list) == 1: kernel_output = 'kernel_out' @@ -55,19 +60,18 @@ class SparseAPI(ForwardAPI): 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][0] in self.inplace_map else "" output_create = f""" - {return_type} api_output{inplace_assign}; - auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" + {return_type} api_output{inplace_assign}; + auto* kernel_out = {set_out_func}(&api_output, {output_type_map[output_type_list[0]]});""" elif len(output_type_list) > 1: output_create = f""" - {return_type} api_output;""" + {return_type} api_output;""" if inplace_flag: output_create = f""" - {return_type} api_output{{""" + {return_type} api_output{{""" for out_name in self.outputs['names']: - out_name = out_name.split('@')[0] if out_name in self.inplace_map: output_create = output_create + self.inplace_map[ out_name] + ', ' @@ -79,7 +83,7 @@ class SparseAPI(ForwardAPI): kernel_output = kernel_output + f'kernel_out_{i}, ' output_names.append(f'kernel_out_{i}') output_create = output_create + f""" - auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" + auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {output_type_map[output_type_list[i]]});""" kernel_output = kernel_output[:-2] else: @@ -117,7 +121,7 @@ class SparseAPI(ForwardAPI): ) else: kernel_context_code = kernel_context_code + f""" - kernel_context.EmplaceBackInput({param}.impl().get());""" + kernel_context.EmplaceBackInput({param}.impl().get());""" continue if param in attr_names: @@ -131,43 +135,78 @@ class SparseAPI(ForwardAPI): else: param + str(param) + ", " kernel_context_code = kernel_context_code + f""" - kernel_context.EmplaceBackAttr({param});""" + kernel_context.EmplaceBackAttr({param});""" for out_name in kernel_output_names: kernel_context_code = kernel_context_code + f""" - kernel_context.EmplaceBackOutput({out_name});""" + kernel_context.EmplaceBackOutput({out_name});""" return kernel_context_code - def gen_sparse_kernel_code(self, inplace_flag=False): + def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False): _, kernel_output_names, output_create = self.gene_output( - self.outputs['types'], 'SetSparseKernelOutput', '', inplace_flag) + self.kernel['dispatch'][kernel_name][1], 'SetSparseKernelOutput', + '', inplace_flag) kernel_context_code = self.gen_sparse_kernel_context( kernel_output_names) return_code = "" if len(self.gene_return_code( )) == 0 else " " + self.gene_return_code() return f""" - auto phi_kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); - VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; - VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel; + VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; + auto phi_kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); + VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - auto kernel_context = phi::KernelContext(dev_ctx); + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + auto kernel_context = phi::KernelContext(dev_ctx); {output_create} {kernel_context_code} - phi_kernel(&kernel_context); -{return_code}""" + phi_kernel(&kernel_context); + {return_code}""" + + def get_condition_code(self, kernel_name): + assert self.kernel['dispatch'][kernel_name], \ + f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'conv3d' in sparse_api.yaml." + input_types = self.kernel['dispatch'][kernel_name][0] + sparse_type_map = { + 'sparse_coo': 'DataLayout::SPARSE_COO', + 'sparse_csr': 'DataLayout::SPARSE_CSR' + } + condition_list = [] + for i, in_type in enumerate(input_types): + if in_type == "dense": + condition_list.append( + f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())" + ) + else: + condition_list.append( + f"{self.inputs['names'][i]}.layout() == {sparse_type_map[in_type]}" + ) + return " && ".join(condition_list) + + def gene_dispatch_code(self, kernel_name, inplace_flag=False): + dispatch_code = "" + return f""" + if ({self.get_condition_code(kernel_name)}) {{ +{self.gen_sparse_kernel_code(kernel_name, inplace_flag)} + }} +""" def gene_base_api_code(self, inplace_flag=False): api_func_name = self.get_api_func_name() if inplace_flag and api_func_name[-1] != '_': api_func_name += '_' + kernel_dispatch_code = f"{self.gene_kernel_select()}\n" + for kernel_name in self.kernel['func']: + kernel_dispatch_code += self.gene_dispatch_code(kernel_name, + inplace_flag) + return f""" PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{ -{self.gene_kernel_select()} -{self.gen_sparse_kernel_code(inplace_flag)} +{kernel_dispatch_code} + PADDLE_THROW(phi::errors::Unimplemented( + "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors.")); }} """ diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index d8e8aad8f98b2a7475c9c9e45c119559d0b50678..eb7114cbdd2c952758e2b0af33fec1cedeee2a6a 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -1,68 +1,68 @@ - backward_api : conv3d_grad forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) - output : Tensor(x_grad@SparseCooTensor), Tensor(kernel_grad@DenseTensor) + output : Tensor(x_grad), Tensor(kernel_grad) kernel : - func : sparse_conv3d_grad + func : sparse_conv3d_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense} - backward_api : coo_to_dense_grad - forward : coo_to_dense(Tensor x) -> Tensor(out@DenseTensor) + forward : coo_to_dense(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad@SparseCooTensor) + output : Tensor(x_grad) kernel : - func : sparse_coo_to_dense_grad - -- backward_api : coo_values_grad - forward : coo_values(Tensor x) -> Tensor(out@DenseTensor) - args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad@SparseCooTensor) - kernel : - func : coo_values_grad + func : sparse_coo_to_dense_grad{sparse_coo, dense-> sparse_coo} - backward_api : create_sparse_coo_tensor_grad - forward : create_sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out@SparseCooTensor) + forward : create_sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out) args : (Tensor indices, Tensor out_grad) - output : Tensor(values_grad@DenseTensor) + output : Tensor(values_grad) kernel : - func : sparse_coo_tensor_grad + func : sparse_coo_tensor_grad{dense, sparse_coo -> dense} - backward_api : dense_to_coo_grad - forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out@SparseCooTensor) + forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out) args : (Tensor out_grad) - output : Tensor(x_grad@DenseTensor) + output : Tensor(x_grad) invoke : to_dense_impl(out_grad) -- backward_api : sparse_coo_relu_grad - forward : sparse_coo_relu(Tensor x) -> Tensor(out@SparseCooTensor) +- backward_api : relu_grad + forward : relu(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) - output : Tensor(x_grad@SparseCooTensor) + output : Tensor(x_grad) kernel : - func : sparse_coo_relu_grad + func : sparse_coo_relu_grad {sparse_coo, sparse_coo -> sparse_coo} -- backward_api : sparse_coo_sin_grad - forward : sparse_coo_sin(Tensor x) -> Tensor(out@SparseCooTensor) +- backward_api : sin_grad + forward : sin(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad@SparseCooTensor) + output : Tensor(x_grad) + kernel : + func : sparse_coo_sin_grad {sparse_coo, sparse_coo -> sparse_coo} + +- backward_api : sparse_maxpool_grad + forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook) + args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes) + output : Tensor(x_grad) kernel : - func : sparse_coo_sin_grad + func : sparse_maxpool_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo} -- backward_api : sparse_coo_sqrt_grad - forward : sparse_coo_sqrt(Tensor x) -> Tensor(out@SparseCooTensor) +- backward_api : sqrt_grad + forward : sqrt(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) - output : Tensor(x_grad@SparseCooTensor) + output : Tensor(x_grad) kernel : - func : sparse_coo_sqrt_grad + func : sparse_coo_sqrt_grad {sparse_coo, sparse_coo -> sparse_coo} -- backward_api : sparse_coo_tanh_grad - forward : sparse_coo_tanh(Tensor x) -> Tensor(out@SparseCooTensor) +- backward_api : tanh_grad + forward : tanh(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) - output : Tensor(x_grad@SparseCooTensor) + output : Tensor(x_grad) kernel : - func : sparse_coo_tanh_grad + func : sparse_coo_tanh_grad {sparse_coo, sparse_coo -> sparse_coo} -- backward_api : sparse_maxpool_grad - forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) - args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes) - output : Tensor(x_grad@SparseCooTensor) +- backward_api : values_grad + forward : coo_values(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) kernel : - func : sparse_maxpool_grad + func : coo_values_grad{sparse_coo, dense-> sparse_coo} diff --git a/python/paddle/utils/code_gen/sparse_bw_api_gen.py b/python/paddle/utils/code_gen/sparse_bw_api_gen.py index 53a99d798118e66e9d64eef5b2f283721b131a92..cf59726bbb19577239fbecd1390f7fbb9301272e 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api_gen.py +++ b/python/paddle/utils/code_gen/sparse_bw_api_gen.py @@ -35,7 +35,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): return BackwardAPI.get_return_type(self) def gene_return_code(self): - return "" + return "return;" def gene_api_declaration(self): return SparseAPI.gene_api_declaration(self) @@ -54,6 +54,11 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): kernel_output = "" output_names = [] output_create = "" + output_type_map = { + 'dense': 'TensorType::DENSE_TENSOR', + 'sparse_coo': 'TensorType::SPARSE_COO', + 'sparse_csr': 'TensorType::SPARSE_CSR' + } if len(output_type_list) == 1: kernel_output = 'kernel_out' @@ -62,7 +67,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][0] in self.inplace_map else "" output_create = f""" - auto kernel_out = {set_out_func}({self.outputs['names'][0].split('@')[0]}, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" + auto kernel_out = {set_out_func}({self.outputs['names'][0]}, {output_type_map[output_type_list[0]]});""" elif len(output_type_list) > 1: output_create = "" @@ -73,10 +78,10 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][i] in self.inplace_map: output_create = output_create + f""" - *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};""" + *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};""" output_create = output_create + f""" - auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i].split('@')[0]}, {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" + auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i]}, {output_type_map[output_type_list[i]]});""" kernel_output = kernel_output[:-2] else: