From c49427d15a3ed519ce8a735eba84cc41e16b5471 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 2 Dec 2020 11:36:43 +0800 Subject: [PATCH] feat(imperative): add inplace add_update option in optimizer GitOrigin-RevId: b8feb493218e4a8e62cf5d8f14b13cfa12461dc0 --- .../python/megengine/functional/inplace.py | 15 ++ imperative/python/megengine/jit/tracing.py | 6 +- imperative/python/megengine/optimizer/adam.py | 61 ++++++-- .../python/megengine/optimizer/optimizer.py | 3 +- imperative/python/megengine/optimizer/sgd.py | 23 ++- imperative/python/megengine/tensor.py | 4 +- imperative/python/src/grad_override.cpp | 4 +- imperative/python/src/tensor.cpp | 11 +- .../test/integration/test_sgd_momentum.py | 18 ++- imperative/python/test/unit/test_tracing.py | 1 - imperative/src/impl/dnn_op_helper.h | 44 +++++- imperative/src/impl/interpreter_impl.cpp | 7 +- imperative/src/impl/interpreter_impl.h | 5 +- imperative/src/impl/ops/cond_take.cpp | 40 +----- imperative/src/impl/ops/elemwise.cpp | 133 ++++++++++++++++++ imperative/src/impl/proxy_graph_detail.cpp | 28 ++-- .../include/megbrain/imperative/interpreter.h | 2 +- .../megbrain/imperative/physical_tensor.h | 4 + .../megbrain/imperative/proxy_graph_detail.h | 4 + src/core/include/megbrain/ir/ops.td | 2 + src/opr/impl/basic_arith.cpp | 7 +- src/opr/impl/internal/identical_fwd.cpp | 16 +++ .../megbrain/opr/internal/identical_fwd.h | 4 +- 23 files changed, 337 insertions(+), 105 deletions(-) create mode 100644 imperative/python/megengine/functional/inplace.py diff --git a/imperative/python/megengine/functional/inplace.py b/imperative/python/megengine/functional/inplace.py new file mode 100644 index 000000000..ddd04a46c --- /dev/null +++ b/imperative/python/megengine/functional/inplace.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from ..core._imperative_rt.core2 import apply +from ..core.ops import builtin +from ..core.ops.builtin import InplaceAdd + + +def _inplace_add_(dest, delta, alpha, beta): + return dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 7f714f545..67b18f3ff 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -502,6 +502,8 @@ class trace: # profile if self._profiling: self._profiler = GraphProfiler(graph) + if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")): + graph.options.var_sanity_check_first_run = False def _compile(self): graph = self._graph = G.Graph() @@ -1073,7 +1075,7 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): return active_trace._apply_op(op, args) -def apply_const_compiled_mode(value, dtype, device, is_const): +def apply_const_compiled_mode(value, dtype, device, is_const, no_cache): if skip_tracing: args = [ RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x @@ -1099,7 +1101,7 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): return list(outputs) -def apply_const_with_tracing(value, dtype, device, is_const): +def apply_const_with_tracing(value, dtype, device, is_const, no_cache): if active_trace._symbolic: outputs = apply_const_symbolic_mode(value, dtype, device) else: diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py index 2d8b9f454..5e13ed77e 100644 --- a/imperative/python/megengine/optimizer/adam.py +++ b/imperative/python/megengine/optimizer/adam.py @@ -6,8 +6,10 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import os from typing import Iterable, Tuple, Union +from ..functional.inplace import _inplace_add_ from ..tensor import Parameter, tensor from .optimizer import Optimizer @@ -58,15 +60,24 @@ class Adam(Optimizer): eps = param_group["eps"] beta0, beta1 = param_group["betas"] + def make_scalar(val): + return tensor([val]) + # since `conver_inputs` is disabled for param updates, # scalar should be explicitly tansforred to tensor - _lr = tensor([lr]) - _weight_decay = tensor([weight_decay]) - _eps = tensor([eps]) - _beta0, _beta1 = tensor([beta0]), tensor([beta1]) - c1 = tensor([1.0]) - c05 = tensor([0.5]) + _lr, _neg_lr = map(make_scalar, (lr, -lr)) + _weight_decay = make_scalar(weight_decay) + _eps = make_scalar(eps) + _beta0, _beta1 = map(make_scalar, (beta0, beta1)) + + c1, c05 = map(make_scalar, (1.0, 0.5)) + + inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) + if inplace_mode: + # reduce device sync + c1_sub_beta0, c1_sub_beta1 = map(make_scalar, (1 - beta0, 1 - beta1)) + for param in param_group["params"]: if param.grad is None: @@ -77,18 +88,38 @@ class Adam(Optimizer): grad += param * _weight_decay states = self._state[param] - step = states["step"] + + step, exp_avg, exp_avg_sq = ( + states["step"], + states["exp_avg"], + states["exp_avg_sq"], + ) + + if inplace_mode: + _inplace_add_(step, c1, alpha=c1, beta=c1) + _inplace_add_(exp_avg, grad, alpha=_beta0, beta=c1_sub_beta0) + _inplace_add_( + exp_avg_sq, grad * grad, alpha=_beta1, beta=c1_sub_beta1, + ) + + delta = (exp_avg / (c1 - _beta0 ** step)) / ( + (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps + ) + _inplace_add_(param, delta, alpha=c1, beta=_neg_lr) + continue + + # step = step + c1 step += c1 - exp_avg = states["exp_avg"] - exp_avg_sq = states["exp_avg_sq"] - exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0) - exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad) + + # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0) + exp_avg *= _beta0 + exp_avg += grad * (c1 - _beta0) + + # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad) + exp_avg_sq *= _beta1 + exp_avg_sq += (c1 - _beta1) * (grad * grad) delta = (exp_avg / (c1 - _beta0 ** step)) / ( (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps ) param -= _lr * delta - - # not inplace change, need to update underlying tensor handler in state - states["exp_avg"]._reset(exp_avg) - states["exp_avg_sq"]._reset(exp_avg_sq) diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index 3db988e64..f5135229e 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -96,6 +96,7 @@ class Optimizer(metaclass=ABCMeta): "optimizer can only optimize Parameters, but one of the params is " + str(type(param)) ) + param._reset(Tensor(param.numpy(), no_cache=True)) for name, default in self._defaults.items(): if default is required and name not in param_group: @@ -121,7 +122,7 @@ class Optimizer(metaclass=ABCMeta): initializer = np.zeros(param.shape, dtype=np.float32) state_dict = self._state.setdefault(param, {}) assert state_name not in state_dict - state = Tensor(initializer) + state = Tensor(initializer, no_cache=True) state_dict[state_name] = state @abstractmethod diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 5a9e9ac30..0d61b13b0 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -6,8 +6,10 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import os from typing import Iterable, Union +from ..functional.inplace import _inplace_add_ from ..tensor import Parameter, tensor from .optimizer import Optimizer @@ -54,10 +56,16 @@ class SGD(Optimizer): # since `conver_inputs` is disabled for param updates, # scalar should be explicitly tansforred to tensor + _lr = tensor([lr]) _weight_decay = tensor([weight_decay]) _momentum = tensor([momentum]) + inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) + if inplace_mode: + _neg_lr = tensor([-lr]) + c1 = tensor([1.0]) + for param in param_group["params"]: if param.grad is None: continue @@ -66,10 +74,21 @@ class SGD(Optimizer): if weight_decay != 0.0: grad += param * _weight_decay + if inplace_mode: + if momentum: + v = self._state[param]["momentum_buffer"] + _inplace_add_(v, grad, alpha=_momentum, beta=c1) + _inplace_add_(param, v, alpha=c1, beta=_neg_lr) + else: + _inplace_add_(param, grad, alpha=c1, beta=_neg_lr) + continue + if momentum: v = self._state[param]["momentum_buffer"] - v = _momentum * v + grad + # v = v * _momentum + grad + v *= _momentum + v += grad + param -= _lr * v - self._state[param]["momentum_buffer"]._reset(v) else: param -= _lr * grad diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index fb263b2fd..c893a2eb4 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin): dmap_callback = None q_dict = {"mode": None, "scale": None, "zero_point": None} - def __new__(cls, data, dtype=None, device=None, is_const=False): + def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False): if device is None: cn = get_default_device() elif isinstance(device, str): @@ -49,7 +49,7 @@ class Tensor(_Tensor, ArrayMethodMixin): if 0 in data.strides: data = data.squeeze().reshape(data.shape) - obj = _Tensor.__new__(cls, data, dtype, cn, is_const) + obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache) return obj @property diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index dbebddf4b..7befb282f 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -38,9 +38,9 @@ std::shared_ptr broadcast_to(Tensor* x, Tensor* s) { std::shared_ptr make_tensor(CompNode cn, Tensor* shape, float v = 0) { HostTensorND scalar{cn, {{1}, dtype::Float32()}}; scalar.ptr()[0] = v; - interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar); + interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false); auto&& t = std::make_shared(handle); - auto&& res = broadcast_to(t.get(), shape); + auto res = broadcast_to(t.get(), shape); return res; } diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 1b5bc4157..1d96f1c10 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -231,13 +231,14 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } } else { py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType + if (nargs != 4 && nargs != 5) { + throw py::type_error("expect 4 or 5 arguments"); + } auto data = tup[0].cast(); DType dtype = tup[1].cast(); CompNode cn = tup[2].cast(); bool is_const = tup[3].cast(); - if (nargs != 4) { - throw py::type_error("expect 3 arguments"); - } + bool no_cache = nargs == 5 ? tup[4].cast() : false; // const op if (is_const && is_tracing) { @@ -259,10 +260,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { interpreter::Interpreter::Handle handle; constexpr auto size_threshhold = TensorShape::MAX_NDIM; if (data.size() > size_threshhold) { - handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); + handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype), no_cache); } else { HostTensorND ret(cn); - handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); + handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache); } m_tensor = std::make_shared(handle); diff --git a/imperative/python/test/integration/test_sgd_momentum.py b/imperative/python/test/integration/test_sgd_momentum.py index cf395fd24..385dac24b 100644 --- a/imperative/python/test/integration/test_sgd_momentum.py +++ b/imperative/python/test/integration/test_sgd_momentum.py @@ -6,6 +6,9 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import itertools +import os + import numpy as np import megengine @@ -58,13 +61,16 @@ def test_sgd_momentum(): np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5) np.testing.assert_almost_equal( - optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34 + optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5 ) def test_sgd_momentum_trace(): - - for symbolic in (True, False): + origin_inplace = os.getenv("MEGENGINE_INPLACE_UPDATE") + symbolic = (True, False) + inplace = (0, 1) + for symbolic, inplace in itertools.product(symbolic, inplace): + os.environ["MEGENGINE_INPLACE_UPDATE"] = str(inplace) @trace(symbolic=symbolic) def train_func(data, *, model=None, optim=None, gm=None): @@ -101,5 +107,9 @@ def test_sgd_momentum_trace(): train_func(data, model=net, optim=optim, gm=gm) np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5) np.testing.assert_almost_equal( - optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34 + optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5 ) + if origin_inplace: + os.environ["MEGENGINE_INPLACE_UPDATE"] = origin_inplace + else: + del os.environ["MEGENGINE_INPLACE_UPDATE"] diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index e0a4fa521..4eaeb3da2 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -325,7 +325,6 @@ def test_raise_on_trace(): @trace def add_abc(a, b, c): - print("Hello") ps = a + b result = ps + c if step_count == bad_step: diff --git a/imperative/src/impl/dnn_op_helper.h b/imperative/src/impl/dnn_op_helper.h index 69d36f651..81222c83b 100644 --- a/imperative/src/impl/dnn_op_helper.h +++ b/imperative/src/impl/dnn_op_helper.h @@ -11,6 +11,7 @@ #include "megbrain/comp_node_env.h" #include "megbrain/comp_node.h" +#include "megbrain/imperative/physical_tensor.h" using namespace megdnn; @@ -29,19 +30,21 @@ struct DnnOprCaller { Workspace workspace; std::unique_ptr op; - DnnOprCaller(CompNode cn): cn(cn) { + DnnOprCaller(CompNode cn): cn(cn), op(create_operator(cn)) {} + + static std::unique_ptr create_operator(CompNode cn) { auto&& handle = MegDNNHandle::get( CompNodeEnv::from_comp_node(cn)).handle(); - op = handle->create_operator(); + return handle->create_operator(); } megdnn::Workspace create_workspace(TensorLayout layout) { dev_tensor = Tensor::make(layout, cn)->dev_tensor(); - workspace = megdnn::Workspace(dev_tensor.raw_ptr(), + workspace = megdnn::Workspace(dev_tensor.raw_ptr(), dev_tensor.storage().size()); return workspace; } - + ~DnnOprCaller() { using DT = CompNode::DeviceType; if (cn.device_type() == DT::CPU && cn != CompNode::default_cpu()) { @@ -52,5 +55,36 @@ struct DnnOprCaller { } }; +template +class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { + using Output = std::array; + + CompNode m_cn; + Output m_out; + +public: + MegDNNDynOutMallocImpl(CompNode cn): m_cn{cn} {} + + megdnn::TensorND alloc_output( + size_t id, DType dtype, const TensorShape &shape, + void *user_data) override { + TensorLayout m_layout(shape, dtype); + m_out[id] = Tensor::make(m_layout, m_cn); + return m_out[id]->dev_tensor().as_megdnn(); + } + + void* alloc_workspace(size_t sz, void *user_data) override { + return m_cn.alloc_device(sz); + } + + void free_workspace(void *ptr, void *user_data) override { + m_cn.free_device(ptr); + } + + TensorPtr at(size_t id) { + return m_out[id]; + } +}; + } // namespace imperative -} // namespace mgb \ No newline at end of file +} // namespace mgb diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index 223ce678b..66de1a0fb 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -28,13 +28,13 @@ Interpreter& Interpreter::inst() { return inst_; } -void* ChannelImpl::put(const HostTensorND& value) { +void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { auto info = alloc(); info->desc.layout = value.layout(); info->desc.comp_node = value.comp_node(); info->desc.value = value.proxy_to_default_cpu(); m_valid_handle.insert(info); - m_worker.add_task(Put{info, value}); + m_worker.add_task(Put{info, value, no_cache}); return info; } @@ -395,7 +395,8 @@ void ChannelImpl::process_one_task(Command& cmd) { using T = std::remove_reference_t; try { if constexpr (std::is_same_v) { - produce_tensor(cmd.dest, Tensor::make(cmd.value)); + auto value = cmd.no_cache ? std::make_shared(cmd.value) : Tensor::make(cmd.value); + produce_tensor(cmd.dest, std::move(value)); } else if constexpr (std::is_same_v) { SmallVector tensor_inputs; tensor_inputs.reserve(cmd.inputs.size()); diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h index 5a8c974a4..5a73466a4 100644 --- a/imperative/src/impl/interpreter_impl.h +++ b/imperative/src/impl/interpreter_impl.h @@ -45,7 +45,7 @@ struct TensorInfo { HostTensorND h_value; size_t locked = 0; size_t recompute_times = 0; - + struct ComputePath { std::shared_ptr op; SmallVector inputs; @@ -57,6 +57,7 @@ struct TensorInfo { struct Put { TensorInfo* dest; HostTensorND value; + bool no_cache = false; }; struct ApplyOp { std::shared_ptr op; @@ -92,7 +93,7 @@ struct ChannelImpl : Interpreter::Channel { ChannelImpl() : m_worker(this) {} ~ChannelImpl() override; - Handle put(const HostTensorND& value) override; + Handle put(const HostTensorND& value, bool no_cache) override; Handle put(const DeviceTensorND& value) override; void del(Handle) override; diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp index 70f55ce81..4044078ce 100644 --- a/imperative/src/impl/ops/cond_take.cpp +++ b/imperative/src/impl/ops/cond_take.cpp @@ -20,44 +20,6 @@ namespace mgb::imperative { namespace { -class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { - using Output = std::array; - - CompNode m_cn; - Output m_out; - - public: - MegDNNDynOutMallocImpl(CompNode cn): m_cn{cn} {} - - megdnn::TensorND alloc_output( - size_t id, DType dtype, const TensorShape &shape, - void *user_data) override; - - void* alloc_workspace(size_t sz, void *user_data) override; - void free_workspace(void *ptr, void *user_data) override; - TensorPtr at(size_t id); -}; - -megdnn::TensorND MegDNNDynOutMallocImpl::alloc_output( - size_t id, DType dtype, const TensorShape &shape, - void * /*user_data*/) { - TensorLayout m_layout(shape, dtype); - m_out[id] = Tensor::make(m_layout, m_cn); - return m_out[id]->dev_tensor().as_megdnn(); -} - -void* MegDNNDynOutMallocImpl::alloc_workspace(size_t sz, void * /*user_data*/) { - return m_cn.alloc_device(sz); -} - -void MegDNNDynOutMallocImpl::free_workspace(void *ptr, void * /*user_data*/) { - m_cn.free_device(ptr); -} - -TensorPtr MegDNNDynOutMallocImpl::at(size_t id) { - return m_out[id]; -} - cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { @@ -94,7 +56,7 @@ SmallVector apply_on_physical_tensor( dtype::Byte()); auto dnn_workspace = dnn_op.create_workspace(m_layout); - MegDNNDynOutMallocImpl policy{inp->comp_node()}; + MegDNNDynOutMallocImpl<2> policy{inp->comp_node()}; dnn_op.op->exec(inp->dev_tensor().as_megdnn(), msk->dev_tensor().as_megdnn(), diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 4811150f9..b33b5cd14 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -11,8 +11,11 @@ #include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/basic_arith.h" +#include "megbrain/imperative/opr_utility.h" +#include "megbrain/opr/utility.h" #include "../op_trait.h" +#include "../dnn_op_helper.h" namespace mgb { namespace imperative { @@ -84,12 +87,142 @@ SmallVector apply_on_physical_tensor( return {Tensor::make(out)}; } +MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT) //{ +public: + struct Param{ + using Mode = megdnn::Elemwise::Param::Mode; + Mode mode; + size_t inplace_index; + }; + using Mode = Param::Mode; + ForceInplaceElemwise(const VarNodeArray& inputs, Param param, + OperatorNodeConfig config = {}) + : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), m_param{param} { + for (auto* input: inputs) { + add_input({input}); + } + add_output(None)-> + set_fwd_in2out_writable_force(input(param.inplace_index)). + add_flag(VarNode::Flag::NO_MEM_RECLAIM); + } + static SymbolVar make(const VarNodeArray& inputs, Param param) { + return SymbolVar{inputs[0]}.insert_single_output_opr( + inputs, param); + } + static cg::OperatorNodeBase* shallow_copy( + const serialization::OprShallowCopyContext &ctx, + const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, + const OperatorNodeConfig &config); +protected: + NodeProp* do_make_node_prop() const override { + auto ret = Super::do_make_node_prop(); + ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); + return ret; + } + void create_megdnn_opr() override { + auto opr = DnnOprCaller::create_operator(comp_node()); + opr->param().mode = m_param.mode; + set_megdnn_opr(std::move(opr)); + } + void scn_do_execute() override { + auto to_dnnnd = [&](auto* var){ return var->dev_tensor().as_megdnn(); }; + megdnn::TensorNDArray inputs_dnnnd; + for (auto* input: input()) { + inputs_dnnnd.push_back(to_dnnnd(input)); + } + mgb_assert(input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC), + "ForceInplaceElemwise cannot be applied in internal tensor"); + auto* out_dest = output(0); + auto* opr = static_cast(megdnn_opr()); + opr->exec(std::move(inputs_dnnnd), + to_dnnnd(out_dest)); + } + void init_output_static_infer_desc() override { + using namespace cg::static_infer; + + owner_graph()->static_infer_manager().register_shape_infer( + output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index))); + } +private: + Param m_param; + void record_execute_deps(ExecDependencyArray& deps) override { + record_megdnn_opr(deps); + } +}; + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise); + +cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy( + const serialization::OprShallowCopyContext &ctx, + const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, + const OperatorNodeConfig &config) { + auto &&opr = opr_.cast_final_safe(); + auto* graph = ctx.owner_graph(opr, inputs); + return graph->insert_opr(std::make_unique(inputs, opr.m_param, config)); +} + +MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy); + +cg::OperatorNodeBase* apply_inplace_add_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto dest = inputs[0], delta = inputs[1], + alpha = inputs[2], beta = inputs[3]; + auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4; + return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1}).node()->owner_opr(); +} + +SmallVector apply_inplace_add_on_physical_tensor( + const OpDef& def, + const SmallVector& inputs){ + auto dest = inputs[0], delta = inputs[1], + alpha = inputs[2], beta = inputs[3]; + auto tensor_to_scalar = [](const TensorPtr& tensor) -> float { + return *tensor->get_value().ptr(); + }; + DnnOprCaller caller{dest->comp_node()}; + caller.op->param() = { tensor_to_scalar(alpha), tensor_to_scalar(beta) }; + caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn()); + return { std::make_shared(dest->blob(), dest->offset(), dest->layout()) }; +} + +std::tuple, bool> infer_inplace_add_output_attrs_fallible( + const OpDef& def, + const SmallVector& inputs) { + mgb_assert(inputs.size() == 4, "invalid input number for inplace_add"); + CompNode cn; + for (auto&& input: inputs) { + if (!cn.valid()) { + cn = input.comp_node; + } else { + mgb_assert(input.comp_node == cn, "inputs should be in same comp_node"); + } + } + auto dest = inputs[0], delta = inputs[1], + alpha = inputs[2], beta = inputs[3]; + bool succeed = dest.layout.ndim != 0; + if (succeed) { + mgb_assert(delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout), "dest and delta must have same shape"); + mgb_assert(alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}), "alpha should be scalar"); + mgb_assert(beta.layout.ndim == 0 || beta.layout.eq_shape({1}), "beta should be scalar"); + } + mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32"); + mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32"); + return {{dest}, succeed}; +} + OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) .apply_on_physical_tensor(apply_on_physical_tensor) .fallback(); + +OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate) + .apply_on_var_node(apply_inplace_add_on_var_node) + .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor) + .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible) + .fallback(); } // anonymous namespace } // namespace imperative diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp index b97facb16..6d7518db0 100644 --- a/imperative/src/impl/proxy_graph_detail.cpp +++ b/imperative/src/impl/proxy_graph_detail.cpp @@ -32,14 +32,22 @@ SmallVector to_raw_ptr_array( return ret; } +SmallVector +infer_output_attrs(const OpDef& def, + const SmallVector& inputs) { + auto&& graph = ProxyGraph::get_default_graph(); + return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); +} +} // anonymous namespace + void exec(const OpDef& def, - const SmallVector& inputs_, - const SmallVector& outputs_) { + const SmallVector& inputs, + const SmallVector& outputs) { auto&& graph = ProxyGraph::get_default_graph(); - auto inputs = to_raw_ptr_array(inputs_), - outputs = to_raw_ptr_array(outputs_); + auto raw_inputs = to_raw_ptr_array(inputs), + raw_outputs = to_raw_ptr_array(outputs); CompNode::UnorderedSet used_cns; - for (auto&& out: outputs) { + for (auto&& out: raw_outputs) { auto cn = out->comp_node(); if (used_cns.insert(cn).second) { for (auto&& in: inputs) { @@ -50,7 +58,7 @@ void exec(const OpDef& def, } } } - graph->invoke_op(def, inputs, outputs); + graph->invoke_op(def, raw_inputs, raw_outputs); for (auto&& cn: used_cns) { for (auto&& in: inputs) { if (in->comp_node() != cn) { @@ -60,14 +68,6 @@ void exec(const OpDef& def, } } -SmallVector -infer_output_attrs(const OpDef& def, - const SmallVector& inputs) { - auto&& graph = ProxyGraph::get_default_graph(); - return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); -} -} // anonymous namespace - SmallVector apply_on_physical_tensor(const OpDef& def, const SmallVector& inputs) { diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h index ea9676146..4dc3d787f 100644 --- a/imperative/src/include/megbrain/imperative/interpreter.h +++ b/imperative/src/include/megbrain/imperative/interpreter.h @@ -21,7 +21,7 @@ struct Interpreter { struct Channel { virtual ~Channel() = default; - virtual Handle put(const HostTensorND& value) = 0; + virtual Handle put(const HostTensorND& value, bool no_cache) = 0; virtual Handle put(const DeviceTensorND& value) = 0; virtual void del(Handle) = 0; diff --git a/imperative/src/include/megbrain/imperative/physical_tensor.h b/imperative/src/include/megbrain/imperative/physical_tensor.h index 1b8e18297..86b712929 100644 --- a/imperative/src/include/megbrain/imperative/physical_tensor.h +++ b/imperative/src/include/megbrain/imperative/physical_tensor.h @@ -101,6 +101,10 @@ public: return m_layout; } + size_t offset() const { + return m_offset; + } + DeviceTensorND dev_tensor(); static TensorPtr make_scalar(DTypeScalar value, CompNode cn); diff --git a/imperative/src/include/megbrain/imperative/proxy_graph_detail.h b/imperative/src/include/megbrain/imperative/proxy_graph_detail.h index 2729f11fb..cecf00dfe 100644 --- a/imperative/src/include/megbrain/imperative/proxy_graph_detail.h +++ b/imperative/src/include/megbrain/imperative/proxy_graph_detail.h @@ -24,6 +24,10 @@ apply_on_physical_tensor(const OpDef& def, std::tuple, bool> infer_output_attrs_fallible(const OpDef& def, const SmallVector& inputs); +void exec(const OpDef& def, + const SmallVector& inputs, + const SmallVector& outputs); + BackwardGraphResult make_backward_graph(const OpDef& def, const SmallVector& inputs, diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 141166c25..7c3834de1 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -239,4 +239,6 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara ); } +def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; + #endif // MGB_OPS diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index e176f58db..997564057 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -886,12 +886,9 @@ AddUpdate::AddUpdate(VarNode *dest, VarNode *delta, m_param{param} { auto dest_opr = dest->owner_opr(); - mgb_throw_if(!(dest_opr->same_type() || - dest_opr->same_type()), + mgb_throw_if(dest_opr->same_type(), GraphError, - "AddUpdate must be applied on SharedDeviceTensor; " - "got %s{%s} actually", - dest_opr->cname(), dest_opr->dyn_typeinfo()->name); + "AddUpdate cannot be applied on ImmutableTensor; "); add_input({dest, delta}); /* diff --git a/src/opr/impl/internal/identical_fwd.cpp b/src/opr/impl/internal/identical_fwd.cpp index 9d5dca911..8bbf576df 100644 --- a/src/opr/impl/internal/identical_fwd.cpp +++ b/src/opr/impl/internal/identical_fwd.cpp @@ -80,6 +80,22 @@ public: MGB_TYPEINFO_OBJ_IMPL(ForwardInputToOutput::MutableSrc); +void ForwardInputToOutput::mixin_init_rt_force_dynamic_mem_alloc_imply_chain( + OperatorNodeBase &opr) { + VarNode *valid_out = nullptr; + for (auto i: opr.output()) { + if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { + mgb_assert(!valid_out); + valid_out = i; + } + } + mgb_assert(valid_out); + + // There may be many inputs such as in opr::VirtualDep, but we only forward first one + opr.input(0)->add_rt_force_dynamic_mem_alloc_imply_chain(valid_out); + valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0)); +} + void ForwardInputToOutput::mixin_mem_plan_fwd_in2out_readonly( OperatorNodeBase& opr) { m_mem_fwd_success = opr.output(0)->set_fwd_in2out_readonly( diff --git a/src/opr/include/megbrain/opr/internal/identical_fwd.h b/src/opr/include/megbrain/opr/internal/identical_fwd.h index df99d0bcf..57c900f3d 100644 --- a/src/opr/include/megbrain/opr/internal/identical_fwd.h +++ b/src/opr/include/megbrain/opr/internal/identical_fwd.h @@ -67,6 +67,7 @@ class ForwardInputToOutput: public cg::OperatorNodeMixinBase { virtual void mixin_scn_do_execute(OperatorNodeBase &opr); + void mixin_init_rt_force_dynamic_mem_alloc_imply_chain(OperatorNodeBase &opr); void mixin_mem_plan_fwd_in2out_readonly(OperatorNodeBase &opr); void mixin_init_output_static_infer_desc(OperatorNodeBase &opr); virtual cg::static_infer::ValueInferDesc mixin_get_static_infer_desc(OperatorNodeBase &opr); @@ -173,8 +174,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ForwardInputToOutput, protected: using Super::Super; void init_rt_force_dynamic_mem_alloc_imply_chain() override { - mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( - *this); + this->mixin_init_rt_force_dynamic_mem_alloc_imply_chain(*this); } void mem_plan_fwd_in2out_readonly() override { -- GitLab