diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index ff1d468c73341ad540451a0a0e371c75de57dee3..d32b9b61c2f161dd4c7164647109618acf8867b0 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -194,6 +194,8 @@ def conv1d( ) (output,) = apply(op, inp, weight) if bias is not None: + if amp._enabled: + (bias,) = cast_tensors(bias) output += bias return output @@ -260,6 +262,8 @@ def conv2d( ) (output,) = apply(op, inp, weight) if bias is not None: + if amp._enabled: + (bias,) = cast_tensors(bias) output += bias return output diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index ae3c4aedeb47b59e7609c7c169588ea865b46cb6..3616b9e75d2e45f9ff2958b53a650c9b98519f82 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -896,13 +896,17 @@ def test_conv3d_zero_stride_numpy_array(): out.numpy() -def test_conv1d(): +@pytest.mark.parametrize("bias", [True, False]) +def test_conv1d(bias): inp = tensor(np.ones((2, 2, 4), dtype=np.float32)) weight = tensor(np.ones((3, 2, 2), dtype=np.float32)) - out = F.conv1d(inp, weight, None, 2, 0, 1, 1) + bias = tensor(np.ones((1, 3, 1), dtype=np.float32)) if bias else None + out = F.conv1d(inp, weight, bias, 2, 0, 1, 1) np.testing.assert_equal( out.numpy(), - np.array( + np.array([[[5, 5], [5, 5], [5, 5]], [[5, 5], [5, 5], [5, 5]]], dtype=np.float32) + if bias is not None + else np.array( [[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32 ), ) @@ -928,13 +932,15 @@ def test_batchnorm2d_autocast(): np.testing.assert_allclose(out.numpy(), expected.numpy()) -def test_conv3d(): +@pytest.mark.parametrize("bias", [True, False]) +def test_conv3d(bias): inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32)) weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32)) - out = F.conv3d(inp, weight, None, 2, 0, 1, 1) - np.testing.assert_equal( - out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16 - ) + bias = tensor(np.ones((1, 3, 1, 1, 1), dtype=np.float32)) if bias else None + out = F.conv3d(inp, weight, bias, 2, 0, 1, 1) + target = np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16 + target = target + 1 if bias is not None else target + np.testing.assert_equal(out.numpy(), target) def test_condtake(): diff --git a/imperative/src/impl/ops/convolution.cpp b/imperative/src/impl/ops/convolution.cpp index f15f4d873f6a4ec7198ccbd4f90240e42cc0b0e9..31b991d786392902fa5a71b5552fa801ebe231ae 100644 --- a/imperative/src/impl/ops/convolution.cpp +++ b/imperative/src/impl/ops/convolution.cpp @@ -117,13 +117,13 @@ std::tuple, bool> infer_output_attrs_fallible( desc.comp_node = inputs[0].comp_node; TensorLayout src = inputs[0].layout; + TensorLayout filter = inputs[1].layout; size_t src_ndim = src.ndim; - if (src_ndim == 0) { - desc.layout = src; + if (src_ndim == 0 || filter.ndim == 0) { + desc.layout = TensorLayout{{}, src.dtype}; return {dests, false}; } - TensorLayout filter = inputs[1].layout; desc.layout = do_shape_infer(def, src_ndim, src, filter); return {dests, true}; } @@ -165,23 +165,24 @@ SmallVector apply_on_physical_tensor( param.format = conv.format; // shape infer - TensorLayout shp({0}, inputs[0]->dtype()); - shp.ndim = 0; + TensorLayout empty_shp({0}, inputs[0]->dtype()); + empty_shp.ndim = 0; size_t sz = setup_algo( - {inp_shapes[0], inp_shapes[1], shp, shp, oup_shapes[0]}, dnn_opr.op.get(), - 0, false, false, cn, conv.policy(), false); + {inp_shapes[0], inp_shapes[1], empty_shp, empty_shp, oup_shapes[0]}, + dnn_opr.op.get(), 0, false, false, cn, conv.policy(), false); // alloc memory - DeviceTensorND bias = BlobManager::inst()->alloc_workspace_with_defrag(cn, shp); + DeviceTensorND empty_bias = + BlobManager::inst()->alloc_workspace_with_defrag(cn, empty_shp); TensorLayout w_layout({sz}, dtype::Byte()); auto dnn_wk = dnn_opr.create_workspace(w_layout); // exeucte dnn_opr.op->exec( - inp_tensornds[0], inp_tensornds[1], bias.as_megdnn(), bias.as_megdnn(), - out.as_megdnn(), nullptr, dnn_wk); + inp_tensornds[0], inp_tensornds[1], empty_bias.as_megdnn(), + empty_bias.as_megdnn(), out.as_megdnn(), nullptr, dnn_wk); return {Tensor::make(out)}; } @@ -333,18 +334,15 @@ TensorLayout convbwd_do_shape_infer( std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { - auto&& conv = static_cast(def); - SmallVector dests(1); auto&& desc = dests[0]; desc.comp_node = inputs[0].comp_node; TensorLayout filter = inputs[0].layout; TensorLayout diff = inputs[1].layout; - size_t filter_ndim = filter.ndim; size_t diff_ndim = diff.ndim; - if (diff_ndim == 0) { - desc.layout = diff; + if (diff_ndim == 0 || filter.ndim == 0) { + desc.layout = TensorLayout{{}, diff.dtype}; return {dests, false}; } @@ -506,12 +504,13 @@ std::tuple, bool> infer_output_attrs_fallible( desc.comp_node = inputs[0].comp_node; TensorLayout src = inputs[0].layout; + TensorLayout filter = inputs[1].layout; size_t src_ndim = src.ndim; - if (src_ndim == 0) { + if (src_ndim == 0 || filter.ndim == 0) { + desc.layout = TensorLayout{{}, src.dtype}; return {dests, false}; } - TensorLayout filter = inputs[1].layout; desc.layout = do_shape_infer(def, src_ndim, src, filter); return {dests, true}; } @@ -549,8 +548,11 @@ SmallVector apply_on_physical_tensor( DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout); - TensorLayout w_layout({sz}, dtype::Byte()); - auto dnn_wk = dnn_opr.create_workspace(w_layout); + megdnn::Workspace dnn_wk; + if (sz != 0) { + TensorLayout w_layout({sz}, dtype::Byte()); + dnn_wk = dnn_opr.create_workspace(w_layout); + } // exeucte dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk); @@ -581,7 +583,7 @@ std::tuple, bool> infer_output_attrs_fallible( auto&& diff = inputs[1]; auto& cn = weight.comp_node; - if (weight.layout.ndim == 0) { + if (weight.layout.ndim == 0 || diff.layout.ndim == 0) { return {{{TensorLayout{weight.layout.dtype}, cn, {}}}, false}; } @@ -616,9 +618,8 @@ SmallVector apply_on_physical_tensor( cn, op_def.policy(), false); megdnn::Workspace dnn_wk; if (wk_size != 0) { - auto wk = Blob::make(cn, wk_size); - dnn_wk.raw_ptr = wk->storage().get(); - dnn_wk.size = wk_size; + TensorLayout w_layout({wk_size}, dtype::Byte()); + dnn_wk = caller.create_workspace(w_layout); } dnn_opr->exec(weight, diff, oup.as_megdnn(), dnn_wk);