Unverified · Commit a3b08bad authored by ronnywang, committed by GitHub

[ROCM] fix the backward maxpool (#32030)

Parent: 1e52f324
@@ -20,6 +20,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
 #ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/miopen_helper.h"
 #endif
@@ -264,6 +266,34 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
     std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
     const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+#ifdef PADDLE_WITH_HIP
+    if (pooling_type == "max") {
+      using OpKernelMap = paddle::framework::OperatorWithKernel::OpKernelMap;
+      using OpKernelFunc = paddle::framework::OperatorWithKernel::OpKernelFunc;
+      auto &all_op_kernels =
+          paddle::framework::OperatorWithKernel::AllOpKernels();
+      std::string op_type = "pool2d_grad";
+      auto kernels_iter = all_op_kernels.find(op_type);
+      PADDLE_ENFORCE_NE(
+          kernels_iter, all_op_kernels.end(),
+          platform::errors::Unavailable(
+              "There are no kernels which are registered in the %s operator.",
+              op_type));
+      OpKernelMap &kernels = kernels_iter->second;
+      paddle::framework::OpKernelType expected_kernel_key(
+          paddle::framework::ToDataType(typeid(T)), ctx.GetPlace());
+      auto kernel_iter = kernels.find(expected_kernel_key);
+      PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
+                        platform::errors::NotFound(
+                            "Operator (%s) does not have kernel for %s.",
+                            op_type, KernelTypeToString(expected_kernel_key)));
+      std::unique_ptr<OpKernelFunc> kernel_func_(
+          new OpKernelFunc(kernel_iter->second));
+      (*kernel_func_)(ctx);
+      return;
+    }
+#endif
     // update paddings
     auto in_x_dims = input->dims();
     framework::DDim data_dims;
...
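The HIP branch added above bypasses MIOpen for max-pool backward: it looks up the native pool2d_grad kernel in the operator registry and invokes it directly, returning before the MIOpen path runs. A minimal dygraph sketch of the user-facing path this fixes; the shapes and the use of paddle.nn.functional here are illustrative, not part of the commit:

    import paddle
    import paddle.nn.functional as F

    # Max pooling followed by a backward pass; on a ROCm build this now
    # routes through the native pool2d_grad kernel instead of MIOpen.
    x = paddle.uniform((2, 3, 8, 8), dtype="float32", min=-1., max=1.)
    x.stop_gradient = False  # ask autograd for dL/dx
    y = F.max_pool2d(x, kernel_size=2, stride=2)
    y.sum().backward()
    print(x.grad.shape)  # gradient has the same shape as x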
@@ -174,6 +174,11 @@ class Conv2D(layers.Layer):
                  dtype='float32'):
         assert param_attr is not False, "param_attr should not be False here."
         super(Conv2D, self).__init__()
+        if (core.is_compiled_with_cuda() and paddle.fluid.get_flags(
+                "FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]):
+            use_cudnn = False
         self._num_channels = num_channels
         self._groups = groups
         self._stride = utils.convert_to_list(stride, 2, 'stride')
...
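The guard added to the legacy dygraph Conv2D reads FLAGS_conv2d_disable_cudnn and forces use_cudnn off when the flag is set on a CUDA build. A short sketch of toggling and reading that flag, using only the set_flags/get_flags calls already shown in this diff (assuming a CUDA build):

    import paddle
    import paddle.fluid as fluid

    # Disable cuDNN for conv2d; the Conv2D constructor above checks this flag.
    fluid.set_flags({'FLAGS_conv2d_disable_cudnn': True})
    value = paddle.fluid.get_flags(
        "FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]
    print(value)  # True

    # Restore the default so conv2d may use cuDNN again.
    fluid.set_flags({'FLAGS_conv2d_disable_cudnn': False})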
@@ -1470,13 +1470,14 @@ class TestConv2DAPI_Error(unittest.TestCase):
     not (core.is_compiled_with_cuda() or core.is_compiled_with_rocm()),
     "core is not compiled with CUDA or ROCM")
 class TestConv2DEnviron(unittest.TestCase):
-    def run_conv2d_api(self):
+    def run1(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
             inputs = fluid.layers.data(
                 shape=[2, 3, 5, 5],
                 append_batch_size=False,
                 name="inputs",
                 dtype="float32")
-        fluid.layers.conv2d(
+            result = fluid.layers.conv2d(
                 input=inputs,
                 num_filters=4,
                 filter_size=[3, 3],
@@ -1485,20 +1486,43 @@ class TestConv2DEnviron(unittest.TestCase):
                 dilation=[1, 1],
                 groups=1,
                 data_format="NCHW")
-        x_var = paddle.uniform((2, 3, 5, 5), dtype="float32", min=-1., max=1.)
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            fetches = exe.run(fluid.default_main_program(),
+                              feed={"inputs": self.input_np},
+                              fetch_list=[result])
+
+    def run2(self, place):
+        with fluid.dygraph.guard(place):
+            inputs = fluid.dygraph.to_variable(self.input_np)
             conv = paddle.nn.Conv2D(
                 in_channels=3,
                 out_channels=4,
                 kernel_size=(3, 3),
                 data_format="NCHW")
-        y_var = conv(x_var)
+            result = conv(inputs)
+
+    def run3(self, place):
+        with fluid.dygraph.guard(place):
+            inputs = fluid.dygraph.to_variable(self.input_np)
+            conv = paddle.fluid.dygraph.nn.Conv2D(
+                num_channels=3,
+                num_filters=4,
+                filter_size=(3, 3), )
+            result = conv(inputs)
+
+    def run_all(self, place):
+        self.run1(place)
+        self.run2(place)
+        self.run3(place)
+
     def test_environ(self):
+        self.input_np = np.random.random([2, 3, 5, 5]).astype("float32")
+        for place in [paddle.CPUPlace(), paddle.CUDAPlace(0)]:
             fluid.set_flags({'FLAGS_conv2d_disable_cudnn': False})
-        self.run_conv2d_api()
+            self.run_all(place)
             fluid.set_flags({'FLAGS_conv2d_disable_cudnn': True})
-        self.run_conv2d_api()
+            self.run_all(place)
 if __name__ == '__main__':
...