未验证 提交 6fc5b7a5 编写于 作者: J JZ-LIANG 提交者: GitHub

revert bug deps (#54901)

上级 ac94b135
......@@ -1278,17 +1278,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
}}
"""
def gen_dist_tensor_code(self):
# define the DistTensorSpec vector for input and output tensors
api_code = " \n std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"
# get DistTensorSpec for each input tensor
for tensor_name in self.inputs['names']:
api_code += f" input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec({tensor_name}));\n"
api_code += "\n"
return api_code
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
......@@ -1297,8 +1286,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""
# if api_func_name == 'matmul':
# api_code += self.gen_dist_tensor_code()
if len(self.kernel['func']) > 1:
kernel_dispatch_code = ''
......
......@@ -379,8 +379,6 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册