Unverified commit 89530384, authored by Leo Chen, committed by GitHub

Fix transpose in conv cudnn kernel when addto enabled (#28295)

Parent 6cebd714
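For context, cuDNN's convolution routines blend their result into the destination buffer as `output = alpha * conv(input, filter) + beta * output`, so `beta = 1` means "add to" whatever the buffer already holds, while `beta = 0` overwrites it. A minimal NumPy sketch of that contract (`blend_output` is an illustrative stand-in for the cuDNN call, not Paddle or cuDNN API):

```python
import numpy as np

def blend_output(conv_result, out, alpha=1.0, beta=0.0):
    # cuDNN-style blending: out = alpha * conv_result + beta * out.
    # beta = 0 overwrites the buffer; beta = 1 accumulates into it ("addto").
    return alpha * conv_result + beta * out

out = np.ones((1, 3, 4, 4), dtype=np.float32)               # existing contents
conv_result = np.full((1, 3, 4, 4), 2.0, dtype=np.float32)

print(blend_output(conv_result, out, beta=0.0)[0, 0, 0, 0])  # 2.0: overwritten
print(blend_output(conv_result, out, beta=1.0)[0, 0, 0, 0])  # 3.0: accumulated
```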
@@ -293,8 +293,12 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f;
-    ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
-    VLOG(4) << "Conv: use_addto = " << ctx.Attr<bool>("use_addto");
+    ScalingParamType<T> beta = 0.0f;
+
+    // NOTE(zhiqiu): inplace addto is not supported in double grad yet.
+    // ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
+    // VLOG(4) << "Conv: use_addto = " << ctx.Attr<bool>("use_addto");
+
     for (int i = 0; i < groups; i++) {
       workspace_handle.RunFunc(
           [&](void* workspace_ptr) {
@@ -387,6 +391,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       if (input_grad) {
         ResizeToChannelFirst<platform::CUDADeviceContext, T>(
             ctx, input_grad, &transformed_input_grad_channel);
+        // NOTE(zhiqiu): If the inplace_addto strategy is enabled, we need to
+        // copy the data of input_grad to transformed_input_grad_channel.
+        if (ctx.Attr<bool>("use_addto")) {
+          TransToChannelFirst<platform::CUDADeviceContext, T>(
+              ctx, input_grad, &transformed_input_grad_channel);
+        }
       }
     } else {
       transformed_input_channel.ShareDataWith(*input);
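When the input layout is channel-last, the gradient kernel computes into a channel-first scratch tensor (`transformed_input_grad_channel`) and transposes the result back afterwards. With `use_addto`, cuDNN accumulates into that scratch tensor (`beta = 1`), so it must first be seeded with the current `input_grad` values; that is what the added `TransToChannelFirst` call does. A NumPy sketch of the data flow, assuming an NHWC gradient (the helper names are illustrative, not the Paddle API):

```python
import numpy as np

def to_channel_first(x):  # NHWC -> NCHW
    return np.transpose(x, (0, 3, 1, 2))

def to_channel_last(x):  # NCHW -> NHWC
    return np.transpose(x, (0, 2, 3, 1))

def write_input_grad(input_grad_nhwc, new_grad_nchw, use_addto):
    if use_addto:
        # The fix: seed the channel-first scratch buffer with the existing
        # gradient, so the beta = 1 accumulation adds onto valid data.
        scratch = to_channel_first(input_grad_nhwc) + new_grad_nchw
    else:
        scratch = new_grad_nchw  # beta = 0: plain overwrite
    return to_channel_last(scratch)  # transpose back to the input layout

prev = np.ones((1, 4, 4, 3), dtype=np.float32)      # NHWC gradient so far
new = np.full((1, 3, 4, 4), 2.0, dtype=np.float32)  # new NCHW conv gradient
assert np.all(write_input_grad(prev, new, use_addto=True) == 3.0)
```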
@@ -30,22 +30,21 @@ class ConvBNLayer(fluid.Layer):
                  filter_size,
                  stride=1,
                  groups=1,
-                 act=None,
-                 use_cudnn=False):
+                 data_format="NCHW"):
         super(ConvBNLayer, self).__init__()

-        self._conv = fluid.dygraph.Conv2D(
-            num_channels=num_channels,
-            num_filters=num_filters,
-            filter_size=filter_size,
+        self._conv = paddle.nn.Conv2D(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
-            act=None,
             bias_attr=False,
-            use_cudnn=use_cudnn)
+            data_format=data_format)

-        self._batch_norm = fluid.dygraph.BatchNorm(num_filters, act=act)
+        self._batch_norm = paddle.nn.BatchNorm(
+            num_filters, data_layout=data_format)

     def forward(self, inputs):
         y = self._conv(inputs)
@@ -53,19 +52,20 @@ class ConvBNLayer(fluid.Layer):
         return y


-def create_program():
+def create_program(data_format="NCHW"):
     main = fluid.Program()
     startup = fluid.Program()
     with fluid.program_guard(main, startup):
         x = fluid.data(name='img', shape=[-1, 3, 224, 224])
         x.stop_gradient = False
+        if data_format == "NHWC":
+            x = paddle.transpose(x, [0, 2, 3, 1])
         x = fluid.layers.prelu(x, mode="channel")
         conv = ConvBNLayer(
             num_channels=3,
             num_filters=3,
             filter_size=1,
-            act='relu',
-            use_cudnn=True)
+            data_format=data_format)
         y = conv(x) + x
         loss = fluid.layers.reduce_sum(y)
@@ -77,7 +77,7 @@ def create_program():


 class TestInplaceAddto(unittest.TestCase):
-    def test_result(self):
+    def check_result(self, data_format="NCHW"):
         def run_program(enable_addto):
             np.random.seed(10)
             paddle.seed(10)
@@ -85,7 +85,7 @@ class TestInplaceAddto(unittest.TestCase):
             if fluid.core.is_compiled_with_cuda():
                 fluid.set_flags({"FLAGS_cudnn_deterministic": True})
                 fluid.set_flags({"FLAGS_max_inplace_grad_add": 2})
-            loss, main, startup, w = create_program()
+            loss, main, startup, w = create_program(data_format=data_format)
             place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
             ) else fluid.CPUPlace()
             exe = fluid.Executor(place)
@@ -98,7 +98,7 @@ class TestInplaceAddto(unittest.TestCase):
             exe.run(startup)
             img = np.random.uniform(-128, 128,
                                     [8, 3, 224, 224]).astype(np.float32)
-            for i in range(2):
+            for i in range(10):
                 res = exe.run(compiled,
                               feed={'img': img},
                               fetch_list=[loss.name, w.name])
@@ -106,9 +106,16 @@ class TestInplaceAddto(unittest.TestCase):

         res1, w1 = run_program(True)
         res2, w2 = run_program(False)
-        print(res1, res2)
+
         self.assertTrue(np.array_equal(res1, res2))

+    def test_nchw(self):
+        self.check_result()
+
+    def test_nhwc(self):
+        self.check_result("NHWC")
+

 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
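The test runs the same program with the addto strategy on and off and asserts bit-identical results under `FLAGS_cudnn_deterministic`. The strategy itself is toggled in a truncated part of `run_program`; a hedged sketch of what that typically looks like with the fluid static-graph API of this era (the exact lines are not shown in this diff):

```python
import paddle.fluid as fluid

# Inside run_program(enable_addto), after building the program:
build_strategy = fluid.BuildStrategy()
build_strategy.enable_addto = enable_addto  # inplace gradient accumulation
compiled = fluid.CompiledProgram(main).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy, places=[place])
```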