Commit a892e5d0 authored by Megvii Engine Team

perf(mge): add more specialized grad rules

GitOrigin-RevId: f88809a6d76bf2d565ae6607378cb41660203f84
Parent e9e5f442
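The patch replaces the generic `make_backward_graph` path with handwritten backward rules for a handful of common ops (Elemwise ADD, Reshape, Subtensor, IndexingMultiAxisVec, Reduce SUM, AddAxis/RemoveAxis). As a rough illustration of what this touches, and assuming the public `GradManager` API is available in this build, the sketch below exercises a Reshape followed by a SUM reduction, whose backward would now go through the specialized rules; it is an example added for context, not code from the patch.

```python
# Illustrative sketch only (not part of the patch); assumes megengine with
# the GradManager autodiff API and functional.reshape / functional.sum.
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor(np.random.rand(2, 3, 4).astype("float32"))
gm = GradManager().attach([x])
with gm:
    y = F.sum(F.reshape(x, (6, 4)))  # forward: Reshape, then Reduce(SUM)
    gm.backward(y)                   # backward hits the specialized grad rules
print(x.grad.shape)                  # gradient keeps the input shape: (2, 3, 4)
```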
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import itertools

import numpy as np

from .._imperative_rt import TensorAttr, imperative
from .._imperative_rt.core2 import apply
from ..ops.builtin import (
    Broadcast,
    Elemwise,
    GetVarShape,
    IndexingMultiAxisVec,
    IndexingSetMultiAxisVec,
    OpDef,
    Reduce,
    Reshape,
    SetSubtensor,
    Subtensor,
)
from ..ops.special import Const


def default_grad_fn(op, inputs, outputs, input_requires_grad):
    def get_tensor_attr(x):
        attr = TensorAttr()
        attr.dtype = x.dtype
        attr.comp_node = x.device.to_c()
        return attr

    output_has_grads = [True,] * len(outputs)
    result = imperative.make_backward_graph(
        op, list(map(get_tensor_attr, inputs)), input_requires_grad, output_has_grads
    )
    if result is None:
        nr_inputs = len(inputs)
        nr_outputs = len(outputs)

        def backward(*args):
            return nr_inputs * [
                None,
            ]

        return backward, nr_outputs * [False,]

    backward_graph, save_for_backward_mask, input_has_grad = result
    input_output_mask = save_for_backward_mask[: len(inputs + outputs)]
    output_grad_mask = save_for_backward_mask[len(inputs + outputs) :]
    save_for_backward = tuple(
        val for val, mask in zip(inputs + outputs, input_output_mask) if mask
    )
    del inputs
    del outputs

    def backward(*args):
        output_grads = tuple(val for val, mask in zip(args, output_grad_mask) if mask)
        assert None not in output_grads
        ret = iter(apply(backward_graph, *(save_for_backward + output_grads)))
        return tuple(next(ret) if mask else None for mask in input_has_grad)

    return backward, output_grad_mask


def get_shape(x):
    (s,) = apply(GetVarShape(), x._data)
    return Tensor(s)


# override for Elemwise.add
def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2
    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
    ]

    def reduce_to(x, s):
        (y,) = apply(Reduce(), x, s)
        return y

    def backward(dy):
        return tuple(
            reduce_to(dy, s) if i else None
            for i, s in zip(input_requires_grad, input_shapes)
        )

    return backward, [True]


# override for Reshape
def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 2
    input_shapes = [
        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
    ]

    def reshape_to(dy, s):
        (dx,) = apply(Reshape(), dy, s)
        return dx

    def backward(dy):
        return tuple(
            reshape_to(dy, s) if i else None
            for i, s in zip(input_requires_grad, input_shapes)
        )

    return backward, [True]


# override for Subtensor
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = SetSubtensor(op.items)
    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
        (grad,) = apply(Broadcast(), _z, input_shape)
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]


# override for IndexingMultiAxisVec
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
    grad_op = IndexingSetMultiAxisVec(op.items)
    input_shape = get_shape(inputs[0])
    params = inputs[1:]

    def make_grad(grad_op, dy):
        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
        (grad,) = apply(Broadcast(), _z, input_shape)
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx

    def backward(dy):
        return tuple(
            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
        )

    return backward, [True]


# override for Reduce.sum
def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
    assert len(inputs) == len(input_requires_grad) == 1
    input_shape = get_shape(inputs[0])

    def broadcast_to(dy, s):
        (dx,) = apply(Broadcast(), dy, s)
        return dx

    def backward(dy):
        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)

    return backward, [True]
@@ -19,7 +19,6 @@ import megengine as mge
from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from . import builtin_op_utils
""" Some notes:
1. Initialize the optimizer:
......
@@ -25,6 +25,25 @@ std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    static auto op = Reshape::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    static auto op = Broadcast::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
    scalar.ptr<float>()[0] = v;
    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar);
    auto&& t = std::make_shared<Tensor>(handle);
    auto&& res = broadcast_to(t.get(), shape);
    return res;
}

apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
@@ -52,10 +71,138 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
    throw GradRuleFallback();
}

apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (input_requires_grad(ctx, i)) {
            input_shapes[i] = get_shape(ctx.args[i]);
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(2);
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}

apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            Tensor* grad = grads[0];
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            Tensor* grad = grads[0];
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode == Reduce::Mode::SUM) {
        mgb_assert(ctx.nargs == 1);
        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
        if (input_requires_grad(ctx, 0)) {
            input_shapes[0] = get_shape(ctx.args[0]);
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(1);
            if (shapes[0]) {
                ret[0] = broadcast_to(grad, shapes[0].get());
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

template<typename T, typename U>
apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<T>();
    mgb_assert(ctx.nargs == 1);
    auto&& grad_op = U::make(op.axis);
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        ret[0] = python::apply(grad_op_, grad)[0];
        return ret;
    });
    return apply(ctx);
}

struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
        reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>);
        reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>);
    }
} _;
......
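Design note on the dispatch above: each specialized rule is registered in `grad_rule_registry()` keyed by the op's typeinfo, and a rule may throw `GradRuleFallback` to hand the op back to the generic backward-graph path (the `default_grad_fn` route on the Python side). The following is a minimal sketch of that registry-plus-fallback pattern in plain Python; all names are hypothetical stand-ins, not MegEngine's actual API.

```python
# Minimal, hypothetical sketch of the registry-plus-fallback dispatch pattern
# used by the C++ grad rules; names mirror the structure, not the real API.
class GradRuleFallback(Exception):
    """Raised by a specialized rule that cannot handle this op configuration."""


grad_rule_registry = {}


def elemwise_grad_rule(op, inputs):
    # Only ADD is specialized; every other mode falls back to the generic path.
    if op["mode"] != "ADD":
        raise GradRuleFallback()
    return lambda dy: (dy, dy)  # d(a + b)/da = d(a + b)/db = dy


def generic_backward(op, inputs):
    # Stand-in for the make_backward_graph / default_grad_fn fallback.
    return lambda dy: tuple(None for _ in inputs)


def make_backward(op, inputs):
    rule = grad_rule_registry.get(op["type"])
    if rule is not None:
        try:
            return rule(op, inputs)
        except GradRuleFallback:
            pass
    return generic_backward(op, inputs)


grad_rule_registry["Elemwise"] = elemwise_grad_rule
backward = make_backward({"type": "Elemwise", "mode": "ADD"}, inputs=("a", "b"))
assert backward("dy") == ("dy", "dy")
```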