未验证 提交 f7b45b3e 编写于 作者: R RuohengMa 提交者: GitHub

[Paddle C++ API] Remapping input and output tensors after XPU op has fallen back to CPU op (#50625)

* fix accurary diff issue when XPU op batch_norm is added to XPU blacklist

* remap op output tensor to input tensor when the op has fallen back to CPU

* rename function name and fix bug causing by InplaceCounter
上级 846c4c30
...@@ -1179,6 +1179,11 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -1179,6 +1179,11 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
): ):
return None, None, None return None, None, None
def reset_view_after_fallback(
self, out_dtype_list, code_indent='', inplace_flag=False
):
return ''
def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False): def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
kernel_dispatch = self.kernel['dispatch'][kernel_name] kernel_dispatch = self.kernel['dispatch'][kernel_name]
input_tensors, kernel_args, kernel_signature = self.get_kernel_args( input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
...@@ -1227,6 +1232,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -1227,6 +1232,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} }} {code_indent} }}
{code_indent} if (kernel_result.has_fallback_cpu) {{ {code_indent} if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans} {fallback_kernel_output_trans}
{self.reset_view_after_fallback(self.outputs['types'], code_indent, inplace_flag)}
{code_indent} }} {code_indent} }}
{code_indent} {self.gene_return_code()}""" {code_indent} {self.gene_return_code()}"""
......
...@@ -312,6 +312,36 @@ class ForwardAPI(BaseAPI): ...@@ -312,6 +312,36 @@ class ForwardAPI(BaseAPI):
return kernel_output, output_names, output_create return kernel_output, output_names, output_create
def reset_view_after_fallback(
self, out_dtype_list, code_indent='', inplace_flag=False
):
remap_code = ''
if len(out_dtype_list) == 1:
if (
not inplace_flag
and self.view_map is not None
and self.outputs['names'][0] in self.view_map
):
remap_code += f"""
{code_indent} phi::DenseTensor * {self.view_map[self.outputs['names'][0]]}_remap = static_cast<phi::DenseTensor*>({self.view_map[self.outputs['names'][0]]}.impl().get());
{code_indent} {self.view_map[self.outputs['names'][0]]}_remap->ShareBufferWith(*kernel_out);
{code_indent} kernel_out->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][0]]}_remap);
"""
elif len(out_dtype_list) > 1:
for i in range(len(out_dtype_list)):
if (
not inplace_flag
and self.view_map is not None
and self.outputs['names'][i] in self.view_map
):
remap_code += f"""
{code_indent} phi::DenseTensor * {self.view_map[self.outputs['names'][i]]}_remap = static_cast<phi::DenseTensor*>({self.view_map[self.outputs['names'][i]]}.impl().get());
{code_indent} {self.view_map[self.outputs['names'][i]]}_remap->ShareBufferWith(*kernel_out_{i});
{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][i]]}_remap);
"""
return remap_code
def header_include(): def header_include():
return """ return """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册