diff --git a/imperative/python/megengine/experimental/__init__.py b/imperative/python/megengine/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imperative/python/megengine/experimental/autograd.py b/imperative/python/megengine/experimental/autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c8b5d2548f3fdfa1e90528c05d631078cbe651b
--- /dev/null
+++ b/imperative/python/megengine/experimental/autograd.py
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ..core._imperative_rt.core2 import (
+    set_allow_higher_order_directive as _set_allow_higher_order_directive,
+)
+
+__all__ = [
+    "enable_higher_order_directive",
+    "disable_higher_order_directive",
+]
+
+
+def enable_higher_order_directive():
+    _set_allow_higher_order_directive(True)
+
+
+def disable_higher_order_directive():
+    _set_allow_higher_order_directive(False)
diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index e47f733d8bcaa09f7a7de0e2baad6631465c47f6..898919ccc48525cf52e885d2065021712a0ae998 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -271,7 +271,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         pool.free(ptr);
     }
 
-    std::shared_ptr<GradFn> make() {
+    static std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
     }
 
@@ -316,14 +316,18 @@ public:
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     // copy inputs first, or trace will make InputNodes for each usage
+    ApplyContext ctx_dup = ctx;
     SmallVector<std::shared_ptr<Tensor>> inputs_copy;
     SmallVector<Tensor*> inputs_copy_weak;
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        Tensor* input = ctx.args[i];
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
         inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
+        if (input->m_flags & Flags::GRAD) {
+            inputs_copy.back()->m_flags |= Flags::GRAD;
+        }
     }
-    ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
 
     auto outputs = apply(ctx_dup);
@@ -332,7 +336,6 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     if (!backward_graph) {
         return outputs;
     }
-
     ret_grad_fn.emplace(std::move(backward_graph), ctx_dup, outputs);
 
     return outputs;
@@ -389,6 +392,12 @@ apply_result_t apply_grad(ApplyContext& ctx) {
 
     if (grad_keys.empty()) {
         return apply(ctx);
+    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
+        PyErr_SetString(
+                PyExc_NotImplementedError,
+                "second order directive not enabled, please call "
+                "'megengine.experimental.enable_higher_order_directive'");
+        throw pyext17::py_err_set();
     }
 
     GradFnHelper grad_fn_holder;
diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h
index 17f28dbbb5e28b118ec30c7e7dc5a67fc98ce688..c5068e207e93241e2614e4c619fccf05e816e80e 100644
--- a/imperative/python/src/grad.h
+++ b/imperative/python/src/grad.h
@@ -36,6 +36,7 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool is_blocked() const {
         return priority < sm_min_priority;
     }
+    inline static bool allow_higher_order_directive = false;
 private:
     static int sm_min_priority;
 };
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index bec4cb77d642be507aa441510926e56be5ec288a..c4db334e7b0ed1678829b3289614ead36f4e4e27 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -990,6 +990,9 @@ void init_tensor(py::module m) {
 
     m.def("set_tracing", &set_tracing);
     m.def("unset_tracing", &unset_tracing);
+    m.def("set_allow_higher_order_directive", [](bool value){
+        GradKey::allow_higher_order_directive = value;
+    });
 }
 
 #undef MGE_PY_INTERFACE
diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py
index f012c400bca33cbc4f3d7aaadf601b0daae0a799..ed598cedbe124b1a87d7721f41cdbe61bc979b1a 100644
--- a/imperative/python/test/conftest.py
+++ b/imperative/python/test/conftest.py
@@ -1,3 +1,10 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import os
 import platform
 import sys
@@ -9,6 +16,10 @@ import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.experimental.autograd import (
+    disable_higher_order_directive,
+    enable_higher_order_directive,
+)
 from megengine.jit import trace as _trace
 from megengine.module import Linear, Module
 
@@ -34,3 +45,13 @@ def skip_distributed(request):
             platform.system()
         )
     )
+
+
+@pytest.fixture(autouse=True)
+def resolve_require_higher_order_directive(request):
+    marker = request.node.get_closest_marker("require_higher_order_directive")
+    if marker:
+        enable_higher_order_directive()
+    yield
+    if marker:
+        disable_higher_order_directive()
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 2f82f6c0b9b48ec066f3be3a0e63533a08ca6790..ddb6ad5bd1859eef54f22ff1de986eeb15cb6a4e 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -281,6 +281,7 @@ def test_broadcast_grad(trace_mode):
     worker()
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad_with_manager():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -299,6 +300,7 @@
     )
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -315,6 +317,7 @@
     x.grad = None
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group_visibility():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -330,6 +333,7 @@
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_visibility_by_order():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 9765ba88d94d895613149630e084c41d6740eef9..fa90ab40e62167d55445011897c7cc1b8e3ca855 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -108,6 +108,7 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
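
Usage sketch (not part of the patch): the snippet below shows how the new experimental switch is meant to be exercised, mirroring the structure of `test_2nd_grad_with_manager` above. It assumes the standard `megengine.autodiff.GradManager` and `megengine.functional` APIs that these tests already rely on; the expected gradient values follow from d(cos x)/dx = -sin(x) and d(-sin x)/dx = -cos(x).

```python
import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.experimental.autograd import (
    disable_higher_order_directive,
    enable_higher_order_directive,
)

# Without this call, nesting two recording GradManagers over the same
# tensor now raises NotImplementedError (see the new check in apply_grad).
enable_higher_order_directive()

x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)

gm = GradManager().attach([x])
gm2 = GradManager().attach([x])

with gm:
    with gm2:
        y = F.cos(x)
        gm2.backward(y)  # first order: x.grad == -sin(x)
    np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
    gm.backward(x.grad)  # second order, accumulated into x.grad

# x.grad now holds -sin(x) + d(-sin x)/dx = -sin(x) - cos(x)
np.testing.assert_almost_equal(
    x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
)

disable_higher_order_directive()  # restore the default
```

In the test suite the same enable/disable pair is applied automatically by the autouse `resolve_require_higher_order_directive` fixture whenever a test carries the `require_higher_order_directive` marker, so individual tests do not have to manage the global flag themselves.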