未验证 提交 9a7dc249 编写于 作者: W wanghuancoder 提交者: GitHub

Fix strided-kernel selection bug with IR (#56420)

上级 ed9ec699
...@@ -1225,7 +1225,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -1225,7 +1225,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
return f""" return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; {code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); {code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
{code_indent} const auto& kernel = kernel_result.kernel; {code_indent} const auto& kernel = kernel_result.kernel;
{code_indent} if (FLAGS_low_precision_op_list) {{ {code_indent} if (FLAGS_low_precision_op_list) {{
{code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type); {code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
......
...@@ -216,7 +216,9 @@ KernelFactory::GetLowPrecisionKernelList() { ...@@ -216,7 +216,9 @@ KernelFactory::GetLowPrecisionKernelList() {
} }
KernelResult KernelFactory::SelectKernelOrThrowError( KernelResult KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& const_kernel_key) const { const std::string& kernel_name,
const KernelKey& const_kernel_key,
bool use_strided_kernel) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
...@@ -224,7 +226,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -224,7 +226,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernels_.end(), kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
if (FLAGS_use_stride_kernel) { if (FLAGS_use_stride_kernel && use_strided_kernel) {
auto stride_kernel_iter = iter->second.find( auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU ? paddle::experimental::Backend::GPU
......
...@@ -325,7 +325,8 @@ class KernelFactory { ...@@ -325,7 +325,8 @@ class KernelFactory {
bool HasStructuredKernel(const std::string& op_type) const; bool HasStructuredKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name, KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key,
bool use_strided_kernel = false) const;
bool HasKernel(const std::string& kernel_name, bool HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册