diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py
index 860f6f0e90514c089d55f920fe261f878974be43..f8eca4b28b7eb25412ffbc3141d428b732686509 100644
--- a/imperative/python/megengine/__init__.py
+++ b/imperative/python/megengine/__init__.py
@@ -104,24 +104,27 @@ from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheO
 from .version import __version__
 
-logger = get_logger(__name__)
-ngpus = get_device_count("gpu")
-supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
-for idx in range(ngpus):
-    prop = get_cuda_device_property(idx)
-    cur_sm = str(prop.major * 10 + prop.minor)
-    if not cur_sm in supported_sm_versions:
-        logger.warning(
-            "{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
-                prop.name,
-                cur_sm,
-                "capabilities" if len(supported_sm_versions) > 1 else "capability",
-                " ".join(["sm_" + v for v in supported_sm_versions]),
-                prop.name,
+def _check_sm_version():
+    cur_logger = get_logger(__name__)
+    ngpus = get_device_count("gpu")
+    supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
+    for idx in range(ngpus):
+        prop = get_cuda_device_property(idx)
+        cur_sm = str(prop.major * 10 + prop.minor)
+        if not cur_sm in supported_sm_versions:
+            cur_logger.warning(
+                "{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
+                    prop.name,
+                    cur_sm,
+                    "capabilities" if len(supported_sm_versions) > 1 else "capability",
+                    " ".join(["sm_" + v for v in supported_sm_versions]),
+                    prop.name,
+                )
             )
-        )
+_check_sm_version()
+
 _set_fork_exec_path_for_timed_func(
     sys.executable,
     os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 8594bf564f5eaa82b46244ddb6c6c4f95c6fa5c7..0137b6d7d51c431eaa40980d87ce4c9297915452 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -16,6 +16,7 @@ from ..core._imperative_rt.core2 import (
     adaptive_pool2d_cpp,
     apply,
     dtype_promotion,
+    pixel_shuffle_cpp,
 )
 from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
 from ..core.ops import builtin
@@ -1849,16 +1850,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
     return layerPixelShuffle
 
 
-def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
-    """
-    Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
-    shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
-    or more batch dimensions.
-
-    :param inp: input tensor.
-    :param upscale_factor: upscale factor of pixel_shuffle.
-    :return: output tensor.
-    """
+def layerPixelShuffle_traceable(inp, upscale_factor):
     assert upscale_factor > 0, "upscale_factor should larger than 0"
     assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
     assert (
@@ -1899,6 +1891,19 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
     return outvar
 
 
+def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
+    """
+    Rearranges elements in a tensor of shape `(..., C * r^2, H, W)` to a tensor of
+    shape `(..., C, H * r, W * r)`, where `r` is an upscale factor, where `...` is
+    zero or more batch dimensions.
+
+    :param inp: input tensor.
+    :param upscale_factor: upscale factor of pixel_shuffle.
+    :return: output tensor.
+    """
+    return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)
+
+
 from .quantized import conv_bias_activation  # isort:skip
 from .loss import *  # isort:skip
 from .metric import *  # isort:skip
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index d59437d489b33a181c9ee7a1e51e9385b2f3e3b6..51982c1aa21b5a45283e387e24ec7cdd97f928b9 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -349,6 +349,28 @@ std::optional<ValueRefList> removeAxis_grad_rule(
     return imperative::apply(op, inputs);
 }
 
+std::optional<ValueRefList> pixelShuffle_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>();
+    mgb_assert(inputs.size() == 1);
+    bool flag = inputs_require_grad[0];
+    auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(1);
+        if (grad && flag_) {
+            ret[0] = imperative::apply(*grad_op_, grad)[0];
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(op, inputs);
+}
+
 std::optional<ValueRefList> fastpathcopy_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
@@ -382,6 +404,8 @@ struct Init {
                 RemoveAxis::typeinfo(), removeAxis_grad_rule);
         CustomBackward::register_grad_rule(
                 FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
+        CustomBackward::register_grad_rule(
+                PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
     }
 } _;
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 834a7af5f322d6a78a9ca9b9c5b70641fda84950..e6f5e13ca102439708725bad15be190e1d292510 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -438,6 +438,7 @@ WRAP_FUNC_PY35(batched_matmul_cpp);
 WRAP_FUNC_PY35(convert_single_value_cpp);
 WRAP_FUNC_PY35(convert_inputs_cpp);
 WRAP_FUNC_PY35(astensor1d_cpp);
+WRAP_FUNC_PY35(pixel_shuffle_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -595,6 +596,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
             MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
             MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
+            MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 70375d574a05ceae0c4a9a1544a136f23c623034..7171566dcc8b33ab828253a7fe5a08bb93b6fae6 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -1378,7 +1378,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
     } else {
         auto&& inp_ndim = get_ndim_safe(inp_hdl);
         ndim += inp_ndim.first;
-        unknown_ndim &= ~inp_ndim.second;
+        unknown_ndim &= !inp_ndim.second;
     }
     for (size_t i = 0; i < axis.size(); ++i) {
         if (axis[i] < 0) {
@@ -1446,6 +1446,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
     py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
     return ret[0];
 }
+
 py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
     py::object obj = _expand_args(args);
     py::list lis;
@@ -1562,6 +1563,19 @@ py::object _batched_matmul_cpp(
     }
 }
 
+py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
+    if (enable_fastpath(inp) && PyLong_Check(val.ptr())) {
+        std::shared_ptr<OpDef> op = PixelShuffle::make(val.cast<int32_t>());
+        py::object Op = py::cast(op);
+        PyObject* p[2] = {Op.ptr(), inp.ptr()};
+        py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
+        return ret[0];
+    } else {
+        // fallback to traceable subgraph implement
+        return func(inp, val);
+    }
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(args[0]).release().ptr();
@@ -1632,6 +1646,13 @@ PyObject* adaptive_pool2d_cpp(PyObject* self, PyObject* const* args, size_t narg
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
+PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _pixel_shuffle_cpp(args[0], args[1], args[2]).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _Const(args[0], args[1], args[2], args[3]).release().ptr();
diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h
index 6ceaa6bd6d284c9b8227cb15698f6b183b909f37..60a20164a7ede58b7160811330122a11ef54ab36 100644
--- a/imperative/python/src/tensor_utils.h
+++ b/imperative/python/src/tensor_utils.h
@@ -40,4 +40,6 @@ PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs
 
 PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
+PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 } // namespace mgb::imperative::python
\ No newline at end of file
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 6d51fe44056efb3472665343db207d1661de7fab..63a8304a46fd8d120de8e98db2da9c9400cbbb3f 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -462,3 +462,19 @@ def test_dot():
         grad(y, F.ones_like(y))
 
     np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())
+
+
+def test_pixel_shuffle():
+
+    x = np.random.rand(2, 3, 16, 3, 4).astype("float32")
+    x = mge.Tensor(x)
+    with Grad() as grad:
+        grad.wrt(x, callback=save_to(x))
+
+        def f(x):
+            p = F.pixel_shuffle(x, 2)
+            return p * p
+
+        y = f(x)
+        grad(y, F.ones_like(y))
+    np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())
diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py
index 1003764fcec5dbe726f6502efcaf93066de43bf5..53929bb1299d896ae2722d15f3ed1e3e82640e99 100644
--- a/imperative/python/test/unit/quantization/test_op.py
+++ b/imperative/python/test/unit/quantization/test_op.py
@@ -255,6 +255,7 @@ def test_conv_bias_int4():
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
 
 
+@pytest.mark.require_ngpu(1)
 @pytest.mark.skipif(
     get_cuda_compute_capability(0) < 61,
     reason="does not support int8 when gpu compute capability less than 6.1",
diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py
index 348e917f361412642265a2da1cadd4dcea091716..7ec1b7a5ed7a4d2394b2f3d3dd4b00e93624f88d 100644
--- a/imperative/python/test/unit/utils/test_network_node.py
+++ b/imperative/python/test/unit/utils/test_network_node.py
@@ -290,6 +290,7 @@ def test_deformable_ps_roi_pooling():
     check_pygraph_dump(fwd, [inp, rois, trans], [result])
 
 
+@pytest.mark.require_ngpu(1)
 @pytest.mark.skipif(
     get_cuda_compute_capability(0) < 61,
     reason="does not support int8 when gpu compute capability less than 6.1",
diff --git a/imperative/src/impl/ops/pixel_shuffle.cpp b/imperative/src/impl/ops/pixel_shuffle.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ff055325c02574f0d7499fd561347982a9fee79e
--- /dev/null
+++ b/imperative/src/impl/ops/pixel_shuffle.cpp
@@ -0,0 +1,157 @@
+#include "../op_trait.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+using namespace megdnn;
+
+namespace mgb::imperative {
+
+namespace pixel_shuffle {
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = def.cast_final_safe<PixelShuffle>();
+    auto&& src = inputs[0];
+    auto&& layout = src->layout();
+    mgb_assert(
+            layout.ndim >= 3,
+            "the input dimension of pixel_shuffle should be larger than or equal to 3");
+    size_t idx = layout.ndim - 3;
+    mgb_assert(
+            layout[idx] % (op.factor * op.factor) == 0,
+            "the -3 dimension should be divided by (upscale_factor ** 2)");
+    TensorLayout tlayout;
+    TensorShape tshp;  // {N, C, r, r, H, W}
+    TensorShape vshp;  // {..., C, Hr, Wr}
+    tshp.ndim = 6;
+    vshp.ndim = layout.ndim;
+    tshp[0] = 1;
+    for (size_t i = 0; i < idx; ++i) {
+        tshp[0] *= layout[i];
+        vshp[i] = layout[i];
+    }
+    tshp[1] = layout[idx] / (op.factor * op.factor);
+    tshp[2] = tshp[3] = op.factor;
+    tshp[4] = layout[idx + 1];
+    tshp[5] = layout[idx + 2];
+    vshp[idx] = tshp[1];
+    vshp[idx + 1] = layout[idx + 1] * op.factor;
+    vshp[idx + 2] = layout[idx + 2] * op.factor;
+    tlayout = layout.reshape(tshp).dimshuffle({0, 1, 4, 2, 5, 3});
+    TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
+    out->to_contiguous_inplace();  // relayout
+    tlayout = out->layout().reshape(vshp);
+    return {Tensor::make(out->blob(), out->offset(), tlayout)};
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = def.cast_final_safe<PixelShuffle>();
+    mgb_assert(op.factor > 0, "upscale_factor should be larger than 0");
+    auto&& src = inputs[0];
+    if (src.layout.ndim == 0) {
+        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
+    }
+    mgb_assert(
+            src.layout.ndim >= 3,
+            "the input dimension of pixel_shuffle should be larger than or equal to 3");
+    size_t idx = src.layout.ndim - 3;
+    mgb_assert(
+            src.layout[idx] % (op.factor * op.factor) == 0,
+            "the -3 dimension should be divided by (upscale_factor ** 2)");
+    TensorShape tshp;
+    tshp.ndim = src.layout.ndim;
+    for (size_t i = 0; i < idx; ++i) {
+        tshp[i] = src.layout[i];
+    }
+    tshp[idx] = src.layout[idx] / (op.factor * op.factor);
+    tshp[idx + 1] = src.layout[idx + 1] * op.factor;
+    tshp[idx + 2] = src.layout[idx + 2] * op.factor;
+    return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
+}
+
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    layout_checker[0] = [](const TensorLayout& layout) {
+        return layout.is_contiguous();
+    };
+    return layout_checker;
+}
+
+OP_TRAIT_REG(PixelShuffle, PixelShuffle)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+}  // namespace pixel_shuffle
+
+namespace pixel_shuffle_backward {
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = def.cast_final_safe<PixelShuffleBackward>();
+    auto&& src = inputs[0];
+    auto&& layout = src->layout();
+    size_t idx = layout.ndim - 3;
+    TensorLayout tlayout;
+    TensorShape tshp;  // {N, C, H, r, W, r}
+    TensorShape vshp;  // {..., Cr^2, H, W}
+    tshp.ndim = 6;
+    vshp.ndim = layout.ndim;
+    tshp[0] = 1;
+    for (size_t i = 0; i < idx; ++i) {
+        tshp[0] *= layout[i];
+        vshp[i] = layout[i];
+    }
+    tshp[1] = layout[idx];
+    tshp[3] = tshp[5] = op.factor;
+    tshp[2] = layout[idx + 1] / op.factor;
+    tshp[4] = layout[idx + 2] / op.factor;
+    vshp[idx] = tshp[1] * op.factor * op.factor;
+    vshp[idx + 1] = tshp[2];
+    vshp[idx + 2] = tshp[4];
+    tlayout = layout.reshape(tshp).dimshuffle({0, 1, 3, 5, 2, 4});
+    TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
+    out->to_contiguous_inplace();  // relayout
+    tlayout = out->layout().reshape(vshp);
+    return {Tensor::make(out->blob(), out->offset(), tlayout)};
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = def.cast_final_safe<PixelShuffleBackward>();
+    auto&& src = inputs[0];
+    if (src.layout.ndim == 0) {
+        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
+    }
+    size_t idx = src.layout.ndim - 3;
+    TensorShape tshp;
+    tshp.ndim = src.layout.ndim;
+    for (size_t i = 0; i < idx; ++i) {
+        tshp[i] = src.layout[i];
+    }
+    tshp[idx] = src.layout[idx] * op.factor * op.factor;
+    tshp[idx + 1] = src.layout[idx + 1] / op.factor;
+    tshp[idx + 2] = src.layout[idx + 2] / op.factor;
+    return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
+}
+
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    layout_checker[0] = [](const TensorLayout& layout) {
+        return layout.is_contiguous();
+    };
+    return layout_checker;
+}
+
+OP_TRAIT_REG(PixelShuffleBackward, PixelShuffleBackward)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+}  // namespace pixel_shuffle_backward
+
+}  // namespace mgb::imperative
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 79e97951e0786497aaa6153a4e916c7f704cc20e..13148873b9a0274d66c0808cc2b756cd3acd6924 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -435,6 +435,18 @@ def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;
 
 def FastpathCopy: MgbHashableOp<"FastpathCopy">;
 
+def PixelShuffle: MgbHashableOp<"PixelShuffle"> {
+  let extraArguments = (ins
+    MgbI32Attr:$factor
+  );
+}
+
+def PixelShuffleBackward: MgbHashableOp<"PixelShuffleBackward"> {
+  let extraArguments = (ins
+    MgbI32Attr:$factor
+  );
+}
+
 def ExternOpr: MgbHashableOp<"ExternOpr"> {
   let extraArguments = (ins
     MgbArrayAttr<MgbArrayAttr<MgbSizeTAttr>>:$output_shapes,
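Note (not part of the patch): a minimal usage sketch of the new F.pixel_shuffle Python API added above, assuming a MegEngine build that contains this change. The input shape and upscale factor mirror the new test_pixel_shuffle test, and the output shape follows the (..., C * r^2, H, W) -> (..., C, H * r, W * r) rule from the docstring; the printed shape below is what that rule implies, not an observed result.

    import numpy as np

    import megengine as mge
    import megengine.functional as F

    # same input shape and factor as test_pixel_shuffle: channel dim 16, r = 2
    x = mge.Tensor(np.random.rand(2, 3, 16, 3, 4).astype("float32"))
    y = F.pixel_shuffle(x, 2)  # 16 / (2 * 2) = 4 channels, H 3 -> 6, W 4 -> 8
    print(y.shape)  # expected: (2, 3, 4, 6, 8)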