Commit 7927e98f authored by Megvii Engine Team

perf(mge): speed up PixelShuffle

GitOrigin-RevId: 942e7557450bf71fb58962112fbe45c19183f6ca
Parent 43e5f41c
......@@ -104,24 +104,27 @@ from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheO
from .version import __version__
logger = get_logger(__name__)
ngpus = get_device_count("gpu")
supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
for idx in range(ngpus):
prop = get_cuda_device_property(idx)
cur_sm = str(prop.major * 10 + prop.minor)
if not cur_sm in supported_sm_versions:
logger.warning(
"{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
prop.name,
cur_sm,
"capabilities" if len(supported_sm_versions) > 1 else "capability",
" ".join(["sm_" + v for v in supported_sm_versions]),
prop.name,
def _check_sm_version():
cur_logger = get_logger(__name__)
ngpus = get_device_count("gpu")
supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
for idx in range(ngpus):
prop = get_cuda_device_property(idx)
cur_sm = str(prop.major * 10 + prop.minor)
if cur_sm not in supported_sm_versions:
cur_logger.warning(
"{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
prop.name,
cur_sm,
"capabilities" if len(supported_sm_versions) > 1 else "capability",
" ".join(["sm_" + v for v in supported_sm_versions]),
prop.name,
)
)
)
_check_sm_version()
_set_fork_exec_path_for_timed_func(
sys.executable,
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
......
......@@ -16,6 +16,7 @@ from ..core._imperative_rt.core2 import (
adaptive_pool2d_cpp,
apply,
dtype_promotion,
pixel_shuffle_cpp,
)
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
......@@ -1849,16 +1850,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
return layerPixelShuffle
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
"""
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
or more batch dimensions.
:param inp: input tensor.
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
"""
def layerPixelShuffle_traceable(inp, upscale_factor):
assert upscale_factor > 0, "upscale_factor should be larger than 0"
assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than or equal to 3"
assert (
......@@ -1899,6 +1891,19 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
return outvar
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
"""
Rearranges elements in a tensor of shape `(..., C * r^2, H, W)` to a tensor of
shape `(..., C, H * r, W * r)`, where `r` is the upscale factor and `...` denotes
zero or more batch dimensions.
:param inp: input tensor.
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
"""
return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)
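For reference, a minimal usage sketch of the resulting public API (arbitrary shapes; `F.pixel_shuffle` is exercised the same way in the test added later in this commit):

import numpy as np
import megengine as mge
import megengine.functional as F

# (N, C*r^2, H, W) with r = 2: 12 channels fold to 3, spatial 4x5 grows to 8x10
x = mge.Tensor(np.random.rand(1, 12, 4, 5).astype("float32"))
y = F.pixel_shuffle(x, 2)
print(y.shape)  # (1, 3, 8, 10)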
from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip
from .metric import * # isort:skip
......
......@@ -349,6 +349,28 @@ std::optional<ValueRefList> removeAxis_grad_rule(
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> pixelShuffle_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>();
mgb_assert(inputs.size() == 1);
bool flag = inputs_require_grad[0];
auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(1);
if (grad && flag_) {
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
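Because pixel_shuffle is a pure permutation of elements, its gradient is simply the inverse permutation, which is what `PixelShuffleBackward` applies to the incoming gradient. A numpy sketch of the pair (hypothetical helper names, mirroring the reshape/dimshuffle scheme used by the kernels later in this commit):

import numpy as np

def pixel_shuffle_np(x, r):
    # forward: (N, C*r^2, H, W) -> (N, C, H*r, W*r)
    n, c, h, w = x.shape
    t = x.reshape(n, c // (r * r), r, r, h, w)
    t = t.transpose(0, 1, 4, 2, 5, 3)  # {N, C, H, r, W, r}
    return t.reshape(n, c // (r * r), h * r, w * r)

def pixel_shuffle_backward_np(g, r):
    # inverse permutation: (N, C, H*r, W*r) -> (N, C*r^2, H, W)
    n, c, hr, wr = g.shape
    t = g.reshape(n, c, hr // r, r, wr // r, r)
    t = t.transpose(0, 1, 3, 5, 2, 4)  # {N, C, r, r, H, W}
    return t.reshape(n, c * r * r, hr // r, wr // r)

x = np.random.rand(2, 12, 4, 5).astype("float32")
assert np.array_equal(pixel_shuffle_backward_np(pixel_shuffle_np(x, 2), 2), x)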
std::optional<ValueRefList> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
......@@ -382,6 +404,8 @@ struct Init {
RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(
PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
}
} _;
......
......@@ -438,6 +438,7 @@ WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
WRAP_FUNC_PY35(pixel_shuffle_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -595,6 +596,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......
......@@ -1378,7 +1378,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
} else {
auto&& inp_ndim = get_ndim_safe(inp_hdl);
ndim += inp_ndim.first;
unknown_ndim &= ~inp_ndim.second;
unknown_ndim &= !inp_ndim.second;  // logical-not, not bitwise: ~ on a bool promotes to int first
}
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < 0) {
......@@ -1446,6 +1446,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
}
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object obj = _expand_args(args);
py::list lis;
......@@ -1562,6 +1563,19 @@ py::object _batched_matmul_cpp(
}
}
py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
if (enable_fastpath(inp) && PyLong_Check(val.ptr())) {
std::shared_ptr<OpDef> op = PixelShuffle::make(val.cast<int32_t>());
py::object Op = py::cast(op);
PyObject* p[2] = {Op.ptr(), inp.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
} else {
// fall back to the traceable subgraph implementation
return func(inp, val);
}
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(args[0]).release().ptr();
......@@ -1632,6 +1646,13 @@ PyObject* adaptive_pool2d_cpp(PyObject* self, PyObject* const* args, size_t narg
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _pixel_shuffle_cpp(args[0], args[1], args[2]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _Const(args[0], args[1], args[2], args[3]).release().ptr();
......
......@@ -40,4 +40,6 @@ PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs
PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
......@@ -462,3 +462,19 @@ def test_dot():
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())
def test_pixel_shuffle():
x = np.random.rand(2, 3, 16, 3, 4).astype("float32")
x = mge.Tensor(x)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
def f(x):
p = F.pixel_shuffle(x, 2)
return p * p
y = f(x)
grad(y, F.ones_like(y))
# pixel_shuffle only permutes elements, so with ones as the upstream gradient,
# x.grad = unshuffle(2 * p) = 2 * x
np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())
......@@ -255,6 +255,7 @@ def test_conv_bias_int4():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
get_cuda_compute_capability(0) < 61,
reason="does not support int8 when gpu compute capability less than 6.1",
......
......@@ -290,6 +290,7 @@ def test_deformable_ps_roi_pooling():
check_pygraph_dump(fwd, [inp, rois, trans], [result])
@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
get_cuda_compute_capability(0) < 61,
reason="does not support int8 when gpu compute capability less than 6.1",
......
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
using namespace megdnn;
namespace mgb::imperative {
namespace pixel_shuffle {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<PixelShuffle>();
auto&& src = inputs[0];
auto&& layout = src->layout();
mgb_assert(
layout.ndim >= 3,
"the input dimension of pixel_shuffle should be larger than or equal to 3");
size_t idx = layout.ndim - 3;
mgb_assert(
layout[idx] % (op.factor * op.factor) == 0,
"the -3 dimension should be divided by (upscale_factor ** 2)");
TensorLayout tlayout;
TensorShape tshp; // {N, C, r, r, H, W}
TensorShape vshp; // {..., C, Hr, Wr}
tshp.ndim = 6;
vshp.ndim = layout.ndim;
tshp[0] = 1;
for (size_t i = 0; i < idx; ++i) {
tshp[0] *= layout[i];
vshp[i] = layout[i];
}
tshp[1] = layout[idx] / (op.factor * op.factor);
tshp[2] = tshp[3] = op.factor;
tshp[4] = layout[idx + 1];
tshp[5] = layout[idx + 2];
vshp[idx] = tshp[1];
vshp[idx + 1] = layout[idx + 1] * op.factor;
vshp[idx + 2] = layout[idx + 2] * op.factor;
tlayout = layout.reshape(tshp).dimshuffle({0, 1, 4, 2, 5, 3});
TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
out->to_contiguous_inplace(); // relayout
tlayout = out->layout().reshape(vshp);
return {Tensor::make(out->blob(), out->offset(), tlayout)};
}
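To make the index bookkeeping concrete: for a 5-d input the two leading batch dimensions are folded into N, shuffled, and restored (a numpy sketch with factor = 2):

import numpy as np

x = np.random.rand(2, 3, 12, 4, 5).astype("float32")  # (..., C*r^2, H, W), idx = 2
t = x.reshape(6, 3, 2, 2, 4, 5)    # tshp {N, C, r, r, H, W}, N = 2*3
t = t.transpose(0, 1, 4, 2, 5, 3)  # dimshuffle -> {N, C, H, r, W, r}
t = np.ascontiguousarray(t)        # mirrors to_contiguous_inplace
y = t.reshape(2, 3, 3, 8, 10)      # vshp {..., C, H*r, W*r}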
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op = def.cast_final_safe<PixelShuffle>();
mgb_assert(op.factor > 0, "upscale_factor should be larger than 0");
auto&& src = inputs[0];
if (src.layout.ndim == 0) {
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
src.layout.ndim >= 3,
"the input dimension of pixel_shuffle should be larger than or equal to 3");
size_t idx = src.layout.ndim - 3;
mgb_assert(
src.layout[idx] % (op.factor * op.factor) == 0,
"the -3 dimension should be divided by (upscale_factor ** 2)");
TensorShape tshp;
tshp.ndim = src.layout.ndim;
for (size_t i = 0; i < idx; ++i) {
tshp[i] = src.layout[i];
}
tshp[idx] = src.layout[idx] / (op.factor * op.factor);
tshp[idx + 1] = src.layout[idx + 1] * op.factor;
tshp[idx + 2] = src.layout[idx + 2] * op.factor;
return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(PixelShuffle, PixelShuffle)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace pixel_shuffle
namespace pixel_shuffle_backward {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<PixelShuffleBackward>();
auto&& src = inputs[0];
auto&& layout = src->layout();
size_t idx = layout.ndim - 3;
TensorLayout tlayout;
TensorShape tshp; // {N, C, H, r, W, r}
TensorShape vshp; // {..., Cr^2, H, W}
tshp.ndim = 6;
vshp.ndim = layout.ndim;
tshp[0] = 1;
for (size_t i = 0; i < idx; ++i) {
tshp[0] *= layout[i];
vshp[i] = layout[i];
}
tshp[1] = layout[idx];
tshp[3] = tshp[5] = op.factor;
tshp[2] = layout[idx + 1] / op.factor;
tshp[4] = layout[idx + 2] / op.factor;
vshp[idx] = tshp[1] * op.factor * op.factor;
vshp[idx + 1] = tshp[2];
vshp[idx + 2] = tshp[4];
tlayout = layout.reshape(tshp).dimshuffle({0, 1, 3, 5, 2, 4});
TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
out->to_contiguous_inplace(); // relayout
tlayout = out->layout().reshape(vshp);
return {Tensor::make(out->blob(), out->offset(), tlayout)};
}
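The backward kernel is the exact mirror image; continuing the shapes from the forward sketch above:

import numpy as np

g = np.random.rand(2, 3, 3, 8, 10).astype("float32")  # grad of the (..., C, H*r, W*r) output
t = g.reshape(6, 3, 4, 2, 5, 2)    # tshp {N, C, H, r, W, r}
t = t.transpose(0, 1, 3, 5, 2, 4)  # dimshuffle -> {N, C, r, r, H, W}
t = np.ascontiguousarray(t)        # mirrors to_contiguous_inplace
gx = t.reshape(2, 3, 12, 4, 5)     # vshp {..., C*r^2, H, W}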
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op = def.cast_final_safe<PixelShuffleBackward>();
auto&& src = inputs[0];
if (src.layout.ndim == 0) {
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
size_t idx = src.layout.ndim - 3;
TensorShape tshp;
tshp.ndim = src.layout.ndim;
for (size_t i = 0; i < idx; ++i) {
tshp[i] = src.layout[i];
}
tshp[idx] = src.layout[idx] * op.factor * op.factor;
tshp[idx + 1] = src.layout[idx + 1] / op.factor;
tshp[idx + 2] = src.layout[idx + 2] / op.factor;
return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(PixelShuffleBackward, PixelShuffleBackward)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace pixel_shuffle_backward
} // namespace mgb::imperative
......@@ -435,6 +435,18 @@ def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;
def PixelShuffle: MgbHashableOp<"PixelShuffle"> {
let extraArguments = (ins
MgbI32Attr:$factor
);
}
def PixelShuffleBackward: MgbHashableOp<"PixelShuffleBackward"> {
let extraArguments = (ins
MgbI32Attr:$factor
);
}
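Assuming the generated ops are exposed through the builtin op namespace like other definitions in this file (an assumption; the binding layer above constructs the op via `PixelShuffle::make`), the attribute surfaces on the Python side roughly as:

from megengine.core.ops import builtin

# hypothetical usage sketch: MgbI32Attr:$factor becomes a constructor argument
op = builtin.PixelShuffle(factor=2)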
def ExternOpr: MgbHashableOp<"ExternOpr"> {
let extraArguments = (ins
MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,
......