From 194660462cf5d11f9cabd8e8b170abaf038646dd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 21 Jan 2021 17:52:35 +0800 Subject: [PATCH] feat(mge/funcitonal): add cvt_color opr python interface GitOrigin-RevId: 29e069fb2334a20545021f37a7133f7bfdbafd0b --- .../python/megengine/functional/__init__.py | 1 + .../python/megengine/functional/img_proc.py | 50 +++++++++++++++++++ .../test/unit/functional/test_functional.py | 11 ++++ imperative/src/impl/ops/img_proc.cpp | 33 ++++++++++++ src/core/include/megbrain/ir/ops.td | 2 + 5 files changed, 97 insertions(+) create mode 100644 imperative/python/megengine/functional/img_proc.py create mode 100644 imperative/src/impl/ops/img_proc.cpp diff --git a/imperative/python/megengine/functional/__init__.py b/imperative/python/megengine/functional/__init__.py index 3b1ac5a2..976e96c1 100644 --- a/imperative/python/megengine/functional/__init__.py +++ b/imperative/python/megengine/functional/__init__.py @@ -8,6 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # pylint: disable=redefined-builtin from .elemwise import * +from .img_proc import * from .math import * from .nn import * from .tensor import * diff --git a/imperative/python/megengine/functional/img_proc.py b/imperative/python/megengine/functional/img_proc.py new file mode 100644 index 00000000..5222b47d --- /dev/null +++ b/imperative/python/megengine/functional/img_proc.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from ..core._imperative_rt.core2 import apply +from ..core.ops import builtin +from ..tensor import Tensor + +__all__ = [ + "cvt_color", +] + + +def cvt_color(inp: Tensor, mode: str = ""): + r""" + Convert images from one format to another + + :param inp: input images. + :param mode: format mode. + :return: convert result. + + Examples: + + .. testcode:: + + import numpy as np + import megengine as mge + import megengine.functional as F + + x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32)) + y = F.img_proc.cvt_color(x, mode="RGB2GRAY") + print(y.numpy()) + + Outputs: + + .. testoutput:: + + [[[[0.86555195]]]] + + """ + assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color" + mode = getattr(builtin.CvtColor.Mode, mode) + assert isinstance(mode, builtin.CvtColor.Mode) + op = builtin.CvtColor(mode=mode) + (out,) = apply(op, inp) + return out diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 0def0e4c..503e5d34 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -704,3 +704,14 @@ def test_argmxx_on_inf(): assert all(run_argmax() >= 0) assert all(run_argmin() >= 0) + + +def test_cvt_color(): + def rgb2gray(rgb): + return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) + + inp = np.random.randn(3, 3, 3, 3).astype(np.float32) + out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32) + x = tensor(inp) + y = F.img_proc.cvt_color(x, mode="RGB2GRAY") + np.testing.assert_allclose(y.numpy(), out, atol=1e-5) diff --git a/imperative/src/impl/ops/img_proc.cpp b/imperative/src/impl/ops/img_proc.cpp new file mode 100644 index 00000000..38497f7d --- /dev/null +++ b/imperative/src/impl/ops/img_proc.cpp @@ -0,0 +1,33 @@ +/** + * \file imperative/src/impl/ops/img_proc.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/imgproc.h" + +#include "../op_trait.h" + +namespace mgb { +namespace imperative { + +namespace { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + return opr::CvtColor::make(inputs[0], op.param()); +} +OP_TRAIT_REG(CvtColor, CvtColor) + .apply_on_var_node(apply_on_var_node) + .fallback(); +} +} +} \ No newline at end of file diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 0f87960d..dd0001b7 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -254,4 +254,6 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> { ); } +def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; + #endif // MGB_OPS -- GitLab