From f7b45b3e363334e352f279780556cb38414fd1f7 Mon Sep 17 00:00:00 2001 From: RuohengMa <120699764+RuohengMa@users.noreply.github.com> Date: Thu, 23 Feb 2023 16:13:34 +0800 Subject: [PATCH] [Paddle C++ API] Remapping input and output tensors after XPU op has fallen back to CPU op (#50625) * fix accuracy diff issue when XPU op batch_norm is added to XPU blacklist * remap op output tensor to input tensor when the op has fallen back to CPU * rename function name and fix bug caused by InplaceCounter --- paddle/phi/api/yaml/generator/api_base.py | 6 +++++ paddle/phi/api/yaml/generator/api_gen.py | 30 +++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 82f51d0575..48bb10b7d5 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -1179,6 +1179,11 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ): return None, None, None + def reset_view_after_fallback( + self, out_dtype_list, code_indent='', inplace_flag=False + ): + return '' + def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False): kernel_dispatch = self.kernel['dispatch'][kernel_name] input_tensors, kernel_args, kernel_signature = self.get_kernel_args( @@ -1227,6 +1232,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d {code_indent} }} {code_indent} if (kernel_result.has_fallback_cpu) {{ {fallback_kernel_output_trans} +{self.reset_view_after_fallback(self.outputs['types'], code_indent, inplace_flag)} {code_indent} }} {code_indent} {self.gene_return_code()}""" diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 36ba67fb43..a48e1b83ab 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -312,6 +312,36 @@ class ForwardAPI(BaseAPI): return kernel_output, output_names, 
output_create + def reset_view_after_fallback( + self, out_dtype_list, code_indent='', inplace_flag=False + ): + remap_code = '' + + if len(out_dtype_list) == 1: + if ( + not inplace_flag + and self.view_map is not None + and self.outputs['names'][0] in self.view_map + ): + remap_code += f""" +{code_indent} phi::DenseTensor * {self.view_map[self.outputs['names'][0]]}_remap = static_cast<phi::DenseTensor*>({self.view_map[self.outputs['names'][0]]}.impl().get()); +{code_indent} {self.view_map[self.outputs['names'][0]]}_remap->ShareBufferWith(*kernel_out); +{code_indent} kernel_out->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][0]]}_remap); +""" + elif len(out_dtype_list) > 1: + for i in range(len(out_dtype_list)): + if ( + not inplace_flag + and self.view_map is not None + and self.outputs['names'][i] in self.view_map + ): + remap_code += f""" +{code_indent} phi::DenseTensor * {self.view_map[self.outputs['names'][i]]}_remap = static_cast<phi::DenseTensor*>({self.view_map[self.outputs['names'][i]]}.impl().get()); +{code_indent} {self.view_map[self.outputs['names'][i]]}_remap->ShareBufferWith(*kernel_out_{i}); +{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][i]]}_remap); +""" + return remap_code + def header_include(): return """ -- GitLab