未验证 提交 e217e965 编写于 作者: D Double_V 提交者: GitHub

fix pool bug (#27366)

上级 d9366194
......@@ -713,7 +713,7 @@ def max_pool2d(x,
'data_format', data_format)
return output
op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "max_pool2d"
op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "pool2d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
......@@ -839,7 +839,7 @@ def max_pool3d(x,
'data_format', data_format)
return output
op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "max_pool3d"
op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "pool3d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册