未验证 提交 9a7dc249 编写于 作者: W wanghuancoder 提交者: GitHub

fix stride with ir bug (#56420)

上级 ed9ec699
......@@ -1225,7 +1225,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
{code_indent} const auto& kernel = kernel_result.kernel;
{code_indent} if (FLAGS_low_precision_op_list) {{
{code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
......
......@@ -216,7 +216,9 @@ KernelFactory::GetLowPrecisionKernelList() {
}
KernelResult KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& const_kernel_key) const {
const std::string& kernel_name,
const KernelKey& const_kernel_key,
bool use_strided_kernel) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
......@@ -224,7 +226,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
if (FLAGS_use_stride_kernel) {
if (FLAGS_use_stride_kernel && use_strided_kernel) {
auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU
......
......@@ -325,7 +325,8 @@ class KernelFactory {
bool HasStructuredKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const;
const KernelKey& kernel_key,
bool use_strided_kernel = false) const;
bool HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册