Commit 46f03b9f authored by: Megvii Engine Team

perf(mge/imperative): optimize gpu memory with standalone grad_fn

GitOrigin-RevId: 44952b235ebe213e168c050a3b20636efee53d5d
Parent: bd73dabb
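
The point of the change, as a hedged reading of the diff below: the default backward path has to save forward inputs/outputs for the generated backward graph, while a standalone grad_fn for ops such as Reshape, Subtensor, IndexingMultiAxisVec, Broadcast and Reduce.sum only needs the input shape, so the forward tensors can be released earlier and GPU memory drops. A minimal, self-contained sketch of that difference (hypothetical helper names, not MegEngine API):

import numpy as np

# Hypothetical illustration of the memory trade-off, not MegEngine code.

def generic_backward(x, w):
    # Default path: the closure keeps the full forward tensors x and w alive
    # until backward runs (the gradient of y = x * w needs both of them).
    def backward(dy):
        return dy * w, dy * x
    return backward

def reshape_backward(x_shape):
    # Standalone grad_fn path: only the (tiny) shape is captured; the forward
    # input tensor itself can be freed as soon as the forward pass is done.
    def backward(dy):
        return dy.reshape(x_shape)
    return backward

x = np.random.rand(1024, 1024).astype("float32")
bwd = reshape_backward(x.shape)   # keeps two integers, not ~4 MiB of activations
del x                             # forward tensor can be freed before backward
dy = np.ones((1024 * 1024,), dtype="float32")
dx = bwd(dy)                      # gradient of y = x.reshape(-1)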
@@ -12,9 +12,26 @@ import itertools
import numpy as np
from .._imperative_rt import TensorAttr, imperative
from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape
from ..ops.builtin import (
    Broadcast,
    Elemwise,
    GetVarShape,
    IndexingMultiAxisVec,
    IndexingSetMultiAxisVec,
    OpDef,
    OprAttr,
    Reduce,
    Reshape,
    SetSubtensor,
    Subtensor,
)
from ..ops.special import Const
from ..tensor.core import apply
from ..tensor.function import Function
from ..tensor.tensor_wrapper import TensorWrapper

_elemwise_add_param = Elemwise(mode="add").to_c().param
_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]


@functools.singledispatch
@@ -22,19 +39,17 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
    assert 0


_elemwise_add_param = Elemwise(mode="add").to_c().param


@builtin_op_get_backward_fn.register(OpDef)
def _(op: OpDef, inputs, outputs, input_requires_grad):
    if (
        isinstance(op, OprAttr)
        and op.type == "Elemwise"
        and op.param == _elemwise_add_param
    ):
        grad_fn = elemwise_grad_fn
    elif isinstance(op, OprAttr) and op.type == Reshape.name:
        grad_fn = reshape_grad_fn
    if isinstance(op, OprAttr):
        grad_fn = _oprAttr_grad_fn.get(op.type, None)
        if grad_fn is None:
            if op.type == Elemwise.name and op.param == _elemwise_add_param:
                grad_fn = elemwise_add_grad_fn
            elif op.type == Reduce.name and op.param[0] == _reduce_sum_param:
                grad_fn = reduce_sum_grad_fn
            else:
                grad_fn = default_grad_fn
    else:
        grad_fn = default_grad_fn
    return grad_fn(op, inputs, outputs, input_requires_grad)
@@ -73,6 +88,7 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
    save_for_backward = tuple(
        val for val, mask in zip(inputs + outputs, intput_output_mask) if mask
    )
    del inputs
    del outputs
@@ -85,13 +101,14 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
    return backward, output_grad_mask


# override for elemwise
def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2

    def get_shape(x):
        (s,) = apply(GetVarShape(), x)
        return s


def get_shape(x):
    (s,) = apply(GetVarShape(), x)
    return s


# override for Elemwise.add
def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2
    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
@@ -110,13 +127,10 @@ def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
    return backward, [True]


# override for Reshape
def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2

    def get_shape(x):
        (s,) = apply(GetVarShape(), x)
        return s

    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
    ]
@@ -132,3 +146,78 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
        )

    return backward, [True]


# override for Subtensor
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = OprAttr()
    grad_op.type = SetSubtensor.name
    grad_op.param = op.param

    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        grad = (
            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
            .broadcast(TensorWrapper(input_shape))
            .__wrapped__
        )
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]
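
Behaviourally, this backward scatters dy into a zero tensor of the input shape via SetSubtensor; the IndexingMultiAxisVec grad_fn below follows the same pattern with fancy indexing. A NumPy equivalent of what the Subtensor backward computes (illustrative only, not the MegEngine implementation):

import numpy as np

def subtensor_backward_numpy(input_shape, slices, dy):
    # dL/dx is zero everywhere except the sliced region, which receives dy.
    dx = np.zeros(input_shape, dtype=dy.dtype)
    dx[slices] = dy
    return dx

# Mirrors test_subtensor in the test diff: y = x[1:-1, :2] on a (3, 3) input.
dy = np.ones((1, 2), dtype=np.float32)
print(subtensor_backward_numpy((3, 3), (slice(1, -1), slice(None, 2)), dy))
# [[0. 0. 0.]
#  [1. 1. 0.]
#  [0. 0. 0.]]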


# override for IndexingMultiAxisVec
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = OprAttr()
    grad_op.type = IndexingSetMultiAxisVec.name
    grad_op.param = op.param

    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        grad = (
            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
            .broadcast(TensorWrapper(input_shape))
            .__wrapped__
        )
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]


# override for Reduce.sum
def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 1
    input_shape = get_shape(inputs[0])

    def broadcast_to(dy, s):
        (dx,) = apply(Broadcast(), dy, s)
        return dx

    def backward(dy):
        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)

    return backward, [True]
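
The sum gradient is just dy replicated over the reduced elements, which the Broadcast op restores here; a NumPy equivalent (illustrative only, assuming dy is broadcast-compatible with the input shape):

import numpy as np

def reduce_sum_backward_numpy(input_shape, dy):
    # Every input element enters the sum with weight 1, so dL/dx is dy
    # broadcast back to the input shape.
    return np.broadcast_to(dy, input_shape).copy()

# Mirrors test_Reduce_sum in the test diff: y = x.sum(axis=0) on a (3, 3) input.
print(reduce_sum_backward_numpy((3, 3), np.ones((3,), dtype=np.float32)))
# [[1. 1. 1.]
#  [1. 1. 1.]
#  [1. 1. 1.]]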


_oprAttr_grad_fn = {
    Reshape.name: reshape_grad_fn,
    Subtensor.name: subtensor_grad_fn,
    IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
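    # Broadcast's backward is a sum-reduction of dy back to the input shape,
    # which elemwise_add_grad_fn already computes, so it is reused here.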
    Broadcast.name: elemwise_add_grad_fn,
}
@@ -14,6 +14,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, imperative
from megengine.core._imperative_rt.imperative import sync
from megengine.core.autodiff.grad import Grad
@@ -229,3 +230,86 @@ def test_elemwise_relu_backward_fn():
    result = imperative.make_backward_graph(op, [attr], [True], [True])
    backward_graph, save_for_backward_mask, input_has_grad = result
    assert save_for_backward_mask == [False, True, True], save_for_backward_mask


def test_reshape():
    x_np = np.random.rand(2, 5).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = x.reshape(5, 2)
    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())


def test_subtensor():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = x[1:-1, :2]
    grad(y, F.ones_like(y))
    np.testing.assert_equal(
        np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
    )


def test_IndexingMultiAxisVec():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = x[[0, 2], [0, 2]]
    grad(y, F.ones_like(y))
    np.testing.assert_equal(
        np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
    )


def test_AxisAddRemove():
    x_np = np.random.rand(1, 5).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = F.remove_axis(F.add_axis(x, 2), 0)
    grad(y, F.ones_like(y))
    np.testing.assert_equal(
        np.array([[1, 1, 1, 1, 1]], dtype=np.float32), x.grad.numpy()
    )


def test_Broadcast():
    x_np = np.random.rand(3, 3, 1).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = F.broadcast(x, (3, 3, 10))
    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())


def test_Reduce_sum():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = x.sum(axis=0)
    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())


def test_Reduce_mean():
    x_np = np.random.rand(3, 3).astype("float32")
    x = TensorWrapper(x_np)
    grad = Grad().wrt(x, callback=save_to(x))
    y = x.mean(axis=0)
    grad(y, F.ones_like(y))
    np.testing.assert_equal(np.ones((3, 3), dtype=np.float32) / 3, x.grad.numpy())