MegEngine 天元 / MegEngine

Commit 9ce1f0f5
Authored Jan 14, 2022 by Megvii Engine Team
refactor(dispatch): implement grad
GitOrigin-RevId: d8367f9587093919c4dcb40361c7f91a9589f6c7
Parent: c609c031
Showing 2 changed files with 954 additions and 0 deletions (+954, -0)
imperative/src/impl/transformations/grad.cpp (+543, -0)
imperative/src/include/megbrain/imperative/transformations/grad.h (+411, -0)
imperative/src/impl/transformations/grad.cpp (new file, mode 100644)
/**
* \file imperative/src/impl/transformations/grad.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/graph_cache.h"
#include <range/v3/all.hpp>
namespace mgb {
namespace imperative {

static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
        Span<bool> inputs_require_grad) {
    // hash
    using OptimizedBackwardGraphCache = OpMethResultCache<
            std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local auto cache = std::make_unique<OptimizedBackwardGraphCache>();
    OptimizedBackwardGraphCache::key_t cache_key{op};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
    input_descs.resize(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
        input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
    }

    auto iter = cache->find(cache_key);
    if (iter != cache->end()) {
        return iter->second;
    }

    // slow path
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    cache->emplace(cache_key, ret);
    return ret;
}
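// save_for_backward is a boolean mask over [inputs | outputs | output grads]
// (its size is asserted below to be inputs.size() + 2 * outputs.size()).
// output_mask_offset marks the start of the output section and
// grad_mask_offset the start of the output-grad section; the closure keeps
// the precomp results plus every input/output selected by the mask.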
BackwardGraphWithClosure::BackwardGraphWithClosure(
        std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
        : backward_graph(backward_graph),
          output_mask_offset(inputs.size()),
          grad_mask_offset(inputs.size() + outputs.size()) {
    auto& save_for_backward = backward_graph->save_for_backward;
    mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
    size_t count = std::count_if(
            save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
    if (!backward_graph->precomp.empty()) {
        SmallVector<ValueRef> inputs_and_outputs;
        for (auto&& input : inputs) {
            inputs_and_outputs.push_back(input);
        }
        for (auto&& output : outputs) {
            inputs_and_outputs.push_back(output);
        }
        auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
        closure.reserve(precomp.size() + count);
        std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
    } else {
        closure.reserve(count);
    }
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (save_for_backward[i]) {
            closure.push_back(inputs[i]);
        }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
        if (save_for_backward[inputs.size() + i]) {
            closure.push_back(outputs[i]);
        }
    }
}

void BackwardGraphWithClosure::operator()(
        std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
    ValueRef args[closure.size() + grads.size()];
    size_t nargs = 0;
    for (auto&& value : closure) {
        args[nargs++] = value;
    }
    bool null_grad = false;
    for (size_t i = 0; i < grads.size(); ++i) {
        if (backward_graph->save_for_backward[grad_mask_offset + i]) {
            if (grads[i]) {
                mgb_assert(!null_grad, "null_grad");
                args[nargs++] = grads[i];
            } else {
                null_grad = true;
            }
        }
    }
    if (null_grad) {
        return;
    }
    auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
    auto&& iter = igrads.begin();
    for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
        if (p) {
            receiver(i, std::move(*iter));
            ++iter;
        }
    }
}

void CustomBackward::operator()(
        std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
    size_t nargs = grads.size();
    ValueRef args[nargs];
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = grads[i];
    }
    auto ret = m_backward({args, nargs});
    for (size_t i = 0; i < ret.size(); ++i) {
        if (auto&& t = ret[i]) {
            receiver(i, std::move(t));
        }
    }
}

std::string GradSlot::to_string() const {
    bool has_callback = bool(callback);
    return ssprintf(
            "GradSlot{grad=%s, has_callback=%d}", m_grad.to_string().c_str(),
            (int)has_callback);
}

std::string GradFn::to_string() const {
    return ssprintf("GradFn{dests=%s}", imperative::to_string(m_dests).c_str());
}

std::string GradSlotPtr::to_string() const {
    if (!m_fn) {
        return "<empty>";
    }
    return (*this)->to_string();
}

std::string GradValue::to_string() const {
    return ssprintf(
            "GradValue{key=\"%s\", slot=%s, value=%s}", m_key->name().c_str(),
            m_slot.to_string().c_str(), m_value.to_string().c_str());
}

static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule>&
get_backward_rule_storage() {
    static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule> sl_storage;
    return sl_storage;
}

bool CustomBackward::register_grad_rule(Typeinfo* typeinfo, BackwardRule rule) {
    return get_backward_rule_storage().insert({typeinfo, rule}).second;
}

auto CustomBackward::lookup_grad_rule(Typeinfo* typeinfo) -> BackwardRule {
    auto iter = get_backward_rule_storage().find(typeinfo);
    if (iter == get_backward_rule_storage().end()) {
        return {};
    }
    return iter->second;
}
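// backward() walks the frozen tape in reverse order. For each GradFn it
// collects the grads accumulated on its output slots, invokes the stored
// backward (a BackwardGraphWithClosure or a CustomBackward), and routes each
// produced grad to its destination slot, summing with Elemwise ADD when a
// grad is already present. A slot's callback fires only once, at its last
// remaining producer (m_producer_record.next is null).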
void GradKey::backward() {
    mgb_assert(m_frozen);
    auto& tape = m_frozen_tape;
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto& [grad_fn, op] = tape[k];
        auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {
            auto& dest = grad_fn->m_dests[i];
            if (dest) {
                auto& existing_grad = dest->m_grad;
                if (!existing_grad) {
                    existing_grad = grad;
                } else {
                    existing_grad = imperative::apply(
                            ApplyOp(*Elemwise::make(Elemwise::Mode::ADD)),
                            existing_grad, grad)[0];
                }
            }
        };
        // clang-format off
        std::visit([&, grad_fn = grad_fn, op = op](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                mgb_assert(grad_fn->m_slots.size() > 0);
                std::vector<ValueRef> grads;
                for (auto&& slot : grad_fn->m_slots) {
                    grads.push_back(slot.m_grad);
                }
                backward(grads, grad_receiver);
            }
        }, grad_fn->m_backward);
        // clang-format on
        for (auto&& dest : grad_fn->m_dests) {
            if (!dest) {
                continue;
            }
            if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
                // I'm the last grad producer, invoke callback
                dest->callback(dest->m_grad);
            }
        }
        grad_fn->clear();
    }
    tape.clear();
}
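// attach() either reuses an existing GradValue already bound to this key
// (asserting that no callback was installed before) or wraps the tensor in a
// fresh GradValue backed by a single-slot GradFn, then installs the callback
// on that slot.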
GradValue::ref_t GradKey::attach(
        ValueRef tensor, std::function<void(ValueRef)> callback) {
    auto grad_value = tensor.as_ref<GradValue>();
    if (grad_value && grad_value->has_key(shared_from_this())) {
        mgb_assert(
                !tensor.cast<GradValue>().slot_for(shared_from_this())->callback,
                "callback exists");
    } else {
        GradSlotPtr grad_slot;
        auto& grad_fn = grad_slot.m_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = shared_from_this();
        grad_fn->m_slots.resize(1);
        grad_slot.m_index = 0;
        grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
    }
    grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;
    return grad_value;
}

void GradKey::freeze() {
    mgb_assert(m_frozen_tape.empty() && !m_frozen);
    for (auto&& [grad_fn, op] : m_tape) {
        if (auto valid_grad_fn = grad_fn.lock()) {
            m_frozen_tape.push_back({valid_grad_fn, op});
        }
    }
    m_tape.clear();
    m_frozen = true;
}

std::vector<ValueRef> GradTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> {
        SmallVector<ValueRef> unwrapped_inputs;
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                unwrapped_inputs.push_back(grad_value->m_value);
            } else {
                unwrapped_inputs.push_back(input);
            }
        }
        return unwrapped_inputs;
    };
    if (m_suppressed) {
        return imperative::apply(op, unwrap_inputs(inputs));
    }
    if (auto* op_val = op.as<ApplyOp>()) {
        size_t nr_require_grad = 0;
        SmallVector<bool> require_grads;
        for (auto&& input : inputs) {
            if (is_grad_value(input)) {
                nr_require_grad++;
                require_grads.push_back(true);
            } else {
                require_grads.push_back(false);
            }
        }
        if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
        SmallVector<ValueRef> captured_inputs;
        SmallVector<bool> inputs_require_grad;
        // capture value so that trace could assume input as same
        auto capture_value = [](ValueRef value) {
            // TODO: fastpath copy shouldn't be an OpDef
            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
        };
        for (auto& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                captured_inputs.push_back(capture_value(grad_value->m_value));
                inputs_require_grad.push_back(true);
            } else {
                captured_inputs.push_back(capture_value(input));
                inputs_require_grad.push_back(false);
            }
        }
        decltype(std::declval<GradFn>().m_backward) backward_storage;
        auto outputs = [&] {
            auto backward_rule =
                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
            if (backward_rule) {
                CustomBackward backward;
                auto optional_outputs = backward_rule(
                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
                        {inputs_require_grad.data(), inputs_require_grad.size()},
                        backward);
                if (optional_outputs) {
                    backward_storage = backward;
                    // backward by rule
                    return *optional_outputs;
                }
            }
            auto outputs = imperative::apply(
                    op, {captured_inputs.begin(), captured_inputs.end()});
            auto backward_graph = make_optimized_backward_graph(
                    op.cast<ApplyOp>().op().shared_from_this(),
                    {captured_inputs.begin(), captured_inputs.end()},
                    {outputs.data(), outputs.size()},
                    {inputs_require_grad.data(), inputs_require_grad.size()});
            if (backward_graph) {
                backward_storage = BackwardGraphWithClosure(
                        backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
                        {captured_inputs.begin(), captured_inputs.end()},
                        {outputs.data(), outputs.size()});
                // backward by make_backward_graph
                return outputs;
            } else {
                // no backward
                return outputs;
            }
        }();
        if (std::holds_alternative<std::monostate>(backward_storage)) {
            return outputs;
        }
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(outputs.size());
        grad_fn->m_backward = backward_storage;
        mgb_assert(!outputs.empty());
        grad_fn->m_dests.reserve(inputs.size());
        // clang-format off
        std::visit([&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                for (size_t i = 0; i < inputs.size(); ++i) {
                    if (backward.input_has_grad(i) && require_grads[i]) {
                        auto& input_grad_slot =
                                inputs[i].cast<GradValue>().slot_for(m_key);
                        grad_fn->m_dests.emplace_back(input_grad_slot);
                        grad_fn->m_dests.back().m_producer_record.insert_after(
                                input_grad_slot->m_producer_head);
                    } else {
                        grad_fn->m_dests.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        auto grad_value = GradValue::make(
                                outputs[i], m_key, GradSlotPtr{grad_fn, i});
                        outputs[i] = record_grad(grad_value);
                    }
                }
            }
        }, grad_fn->m_backward);
        // clang-format on
        mgb_assert(!grad_fn->m_slots.empty());
        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
        return outputs;
    } else if (auto* attach_grad = op.as<AttachGrad>()) {
        if (!has_key(attach_grad->key())) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        auto tensor = inputs[0];
        GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
        auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
            auto ret = callback({&grad, 1});
            assert(ret.empty());
        });
        return {record_grad(output)};
    } else if (auto* grad_backward = op.as<GradBackward>()) {
        if (!has_key(grad_backward->key())) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        size_t nr_grads = inputs.size() / 2;
        mgb_assert(nr_grads * 2 == inputs.size());
        auto values = inputs.sub(0, nr_grads);
        auto grads = inputs.sub(nr_grads, nr_grads);
        make_backward_closure(values)(grads);
        return {};
    } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
        if (has_key(is_attached_to->key())) {
            if (auto grad_value = as_grad_value(inputs[0])) {
                // TODO: assert grad_fn
                return {BoolValue::make(true)};
            }
        }
        return {BoolValue::make(false)};
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // TODO: merge SetGrad and ApplyOp
        auto grad_fn = std::make_shared<GradFn>();
        auto& backward =
                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
        size_t nr_inputs = set_grad->nr_inputs();
        mgb_assert(inputs.size() > nr_inputs);
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
        backward.m_input_has_grad = SmallVector(nr_inputs, true);
        backward.m_output_attrs =
                SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
        backward.m_backward = set_grad->grad_fn();
        std::vector<ValueRef> outputs;
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(nr_outputs);
        grad_fn->m_dests.reserve(nr_inputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (auto grad_value = as_grad_value(inputs_[i])) {
                auto& input_grad_slot = grad_value->m_slot;
                grad_fn->m_dests.emplace_back(grad_value->m_slot);
                grad_fn->m_dests.back().m_producer_record.insert_after(
                        input_grad_slot->m_producer_head);
            } else {
                grad_fn->m_dests.emplace_back();
            }
        }
        for (size_t i = 0; i < nr_outputs; ++i) {
            auto& output = outputs_[i];
            auto grad_value = as_grad_value(output);
            if (grad_value) {
                grad_value = GradValue::make(
                        grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
            } else {
                grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
            }
            outputs.push_back(record_grad(grad_value));
        }
        m_key->m_tape.push_back({grad_fn, nullptr});
        return outputs;
    } else if (auto* gbc = op.as<GetBackwardColsure>()) {
        if (gbc->key() != m_key) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        return {FunctionValue::make(make_backward_closure(inputs))};
    } else if (op.is<DetachGrad>()) {
        if (auto grad_value = as_grad_value(inputs[0])) {
            return {grad_value->m_value};
        } else {
            return {inputs[0]};
        }
    } else if (op.is<GetGradKey>()) {
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                return {GradKeyValue::make(grad_value->m_key)};
            }
        }
        return imperative::apply(op, inputs);
    } else if (op.kind() == Operator::IdentityLike) {
        mgb_assert(inputs.size() == 1);
        if (auto grad_value = as_grad_value(inputs[0])) {
            auto output = imperative::apply(op, grad_value->m_value)[0];
            auto grad_output = GradValue::make(
                    output, grad_value->key(), grad_value->slot_for(m_key));
            return {record_grad(grad_output)};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.is<CreateTensor>()) {
        return imperative::apply(op, inputs);
    } else {
        SmallVector<ValueRef> unwrapped_inputs;
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                unwrapped_inputs.push_back(grad_value->m_value);
            } else {
                unwrapped_inputs.push_back(input);
            }
        }
        auto outputs = imperative::apply(
                op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
        mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());
        return outputs;
    }
}

GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
    // reset GradKey
    auto grad_key = m_key;
    std::vector<GradSlotPtr> y_slots;
    for (auto&& y : ys) {
        if (auto grad_value = as_grad_value(y)) {
            y_slots.push_back(grad_value->slot_for(grad_key));
        } else {
            y_slots.emplace_back();
        }
    }
    GenericFunction closure = [grad_key,
                               y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> {
        size_t nr_grads = y_slots.size();
        mgb_assert(dys.size() == nr_grads);
        for (size_t i = 0; i < nr_grads; ++i) {
            if (y_slots[i]) {
                y_slots[i]->m_grad = dys[i];
            }
        }
        grad_key->backward();
        return {};
    };
    grad_key->freeze();
    cleanup();
    return closure;
}

void GradTransformation::on_unregister() noexcept {
    cleanup();
}

void GradTransformation::cleanup() {
    for (auto&& weak_value : m_weak_values) {
        auto grad_value = weak_value.lock();
        if (grad_value) {
            mgb_assert(grad_value->m_key == m_key);
            grad_value.reset(grad_value->m_value);
        }
    }
    m_weak_values.clear();
    m_key = {};
}

void GradTransformation::suppress() {
    m_suppressed++;
}

void GradTransformation::resume() {
    m_suppressed--;
}

}  // namespace imperative
}  // namespace mgb
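Taken together, the flow is: GradKey::attach wraps a tensor, GradTransformation::apply_transformation records a GradFn per op on the key's tape, and make_backward_closure freezes the tape and returns a closure that seeds the output grads and runs backward(). A minimal sketch of driving this API, assuming a registered dispatch chain and pre-existing ValueRefs x, y, and dy (none of which are part of this diff):

    auto key = std::make_shared<GradKey>();
    key->name("demo_key");

    // With a GradTransformation registered for `key`, each ApplyOp that sees a
    // grad-tracked input is recorded on the tape. (How transformations are
    // registered into the dispatch chain is outside this diff.)
    GradTransformation grad_transform(key);

    // Attach x; the callback receives x's accumulated gradient.
    auto gx = key->attach(x, [](ValueRef dx) { /* consume dx */ });

    // ... forward ops on gx produce y through imperative::apply ...

    // Freeze the tape and obtain a closure over y's grad slot.
    GenericFunction closure = grad_transform.make_backward_closure({&y, 1});
    closure({&dy, 1});  // seeds dy into y's slot and runs key->backward()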
imperative/src/include/megbrain/imperative/transformations/grad.h (new file, mode 100644)
/**
* \file imperative/src/include/megbrain/imperative/grad.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <variant>
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/intrusive_list.h"
#include "megbrain/imperative/utils/to_string.h"
namespace mgb::imperative {

struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<ValueRef> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(
            std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
            std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs);

    void operator()(
            std::vector<ValueRef> grads,
            std::function<void(size_t, ValueRef)> receiver);

    bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};

struct CustomBackward;

using GradRuleFn =
        std::function<std::vector<ValueRef>(Span<ValueRef> inputs, CustomBackward&)>;

struct CustomBackward {
    using BackwardFn = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
    using BackwardRule = std::function<std::optional<std::vector<ValueRef>>(
            const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>;
    BackwardFn m_backward;
    SmallVector<bool, 8> m_input_has_grad;
    struct OutputAttr {
        bool requires_grad = true, captured = true;
    };
    SmallVector<OutputAttr> m_output_attrs;

public:
    void operator()(
            std::vector<ValueRef> grads,
            std::function<void(size_t, ValueRef)> receiver);

    bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
    bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
    bool output_captured(size_t i) { return m_output_attrs[i].captured; }

    static bool register_grad_rule(Typeinfo* typeinfo, BackwardRule rule);
    static BackwardRule lookup_grad_rule(Typeinfo* typeinfo);
};
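// Illustrative note (not part of this diff): an op can bypass the generic
// backward-graph path by registering a rule keyed on its Typeinfo, e.g.
//     CustomBackward::register_grad_rule(SomeOp::typeinfo(), rule);
// where `SomeOp` and `rule` are hypothetical; the rule computes the forward
// outputs, fills in the CustomBackward it receives, and may return
// std::nullopt to fall back to the make_backward_graph path.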
class GradSlot;
class GradSlotPtr;
class GradSlotProducerPtr;
class GradFn;
class GradKey;

struct GradProducerRecord : utils::intrusive_list::Node<GradProducerRecord> {
    using Node = utils::intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(head_t& head) : Node(utils::intrusive_list::after_t{}, head) {}
};
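// Every GradSlotProducerPtr aimed at a slot links itself into an intrusive
// list headed at GradSlot::m_producer_head; GradKey::backward() checks this
// list to detect the last remaining producer before firing the slot callback.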
class GradSlot {
private:
    ValueRef m_grad;
    GradProducerRecord::head_t m_producer_head;
    std::function<void(ValueRef)> callback;

public:
    std::string to_string() const;

    friend class GradKey;
    friend class GradSlotProducerPtr;
    friend class GradTransformation;
};

template <>
struct ToStringTrait<GradSlot> {
    std::string operator()(const GradSlot& value) const { return value.to_string(); }
};

class GradFn {
private:
    std::weak_ptr<GradKey> m_key;
    std::vector<GradSlot> m_slots;
    std::vector<GradSlotProducerPtr> m_dests;
    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;

public:
    void clear() {
        m_key.reset();
        m_slots.clear();
        m_dests.clear();
        m_backward.emplace<std::monostate>();
    }

    std::string to_string() const;

    friend class GradSlotPtr;
    friend class GradKey;
    friend class GradTransformation;
};

class GradSlotPtr {
private:
    std::shared_ptr<GradFn> m_fn;
    size_t m_index = 0;

public:
    GradSlotPtr(std::shared_ptr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
    GradSlotPtr() = default;
    GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }
    operator bool() const { return bool(m_fn); }

    std::string to_string() const;

    friend class GradKey;
    friend class GradTransformation;
};

template <>
struct ToStringTrait<GradSlotPtr> {
    std::string operator()(const GradSlotPtr& value) const {
        return value.to_string();
    }
};

class GradSlotProducerPtr : public GradSlotPtr {
private:
    GradProducerRecord m_producer_record;
    bool dirty = false;

public:
    GradSlotProducerPtr(const GradSlotPtr& info)
            : GradSlotPtr(info), m_producer_record(info->m_producer_head) {}
    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradSlotProducerPtr&&) = default;
    ~GradSlotProducerPtr() { dirty = true; }

    friend class GradKey;
    friend class GradTransformation;
};

template <>
struct ToStringTrait<GradSlotProducerPtr> {
    std::string operator()(const GradSlotProducerPtr& value) const {
        return value.to_string();
    }
};

class GradValue final : public ValueImpl<GradValue> {
private:
    ValueRef m_value;
    std::shared_ptr<GradKey> m_key;
    GradSlotPtr m_slot;

public:
    GradValue(ValueRef value, std::shared_ptr<GradKey> key, GradSlotPtr slot = {})
            : m_value(value), m_key(key), m_slot(slot) {}

    std::string to_string() const override;

    bool has_key(std::shared_ptr<GradKey> key) const { return m_key == key; }

    const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const {
        mgb_assert(m_key == key);
        return m_slot;
    }

    std::shared_ptr<GradKey> key() const { return m_key; }

    void clear() override {
        m_slot = {};
        m_value = {};
        m_key = nullptr;
    }

    void on_watch() override { m_value.watch(); }

    void on_unwatch() override { m_value.unwatch(); }

    friend class GradKey;
    friend class GradTransformation;
};

class GradKey : public std::enable_shared_from_this<GradKey> {
private:
    std::string m_name;
    std::vector<std::pair<std::weak_ptr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
    std::vector<std::pair<std::shared_ptr<GradFn>, std::shared_ptr<OpDef>>>
            m_frozen_tape;
    bool m_frozen = false;

public:
    void backward();
    GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
    const std::string& name() const { return m_name; }
    void name(std::string name) { m_name = std::move(name); }
    void freeze();

    friend class GradTransformation;
};

class GradKeyValue final
        : public MixinValueImpl<GradKeyValue, std::shared_ptr<GradKey>> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override {
        return ssprintf("GradKey{%s}", (*this)->name().c_str());
    }
};

class GradTransformation final : public Transformation {
private:
    std::shared_ptr<GradKey> m_key;
    std::vector<GradValue::weak_ref_t> m_weak_values;
    size_t m_suppressed = 0;

public:
    GradTransformation(std::shared_ptr<GradKey> key) : m_key(key) {}

    auto record_grad(GradValue::ref_t tensor) {
        m_weak_values.push_back(tensor);
        return tensor;
    }

    bool is_grad_value(ValueRef value) {
        if (auto* grad_value = value.as<GradValue>()) {
            if (grad_value->has_key(m_key)) {
                return true;
            }
        }
        return false;
    }

    /**
     * \brief test whether value is related to this GradTransformation
     *
     * there may be multiple grad transformations, so simply using
     * value.is<GradValue>() is unsafe
     *
     * \param value
     * \return GradValue::ref_t
     */
    GradValue::ref_t as_grad_value(ValueRef value) {
        if (auto grad_value = value.as_ref<GradValue>()) {
            if (grad_value->has_key(m_key)) {
                return grad_value;
            }
        }
        return {};
    }

    bool has_key(std::shared_ptr<GradKey> key) {
        if (key == m_key) {
            return true;
        }
        return false;
    }

    std::vector<ValueRef> apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;

    ValueRef unwrap(ValueRef value) override {
        if (auto grad_val = as_grad_value(value)) {
            return grad_val->m_value;
        }
        return value;
    }

    std::string name() const override { return "GradTransformation"; }

    GenericFunction make_backward_closure(Span<ValueRef> ys);

    void on_unregister() noexcept override;

    void cleanup();
    void suppress();
    void resume();
};

class DetachGrad : public OperatorImpl<DetachGrad, Operator::IdentityLike> {
private:
    // TODO: identified by GradKey

public:
    std::string to_string() const override { return "DetachValue"; }

    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
        return {inputs.as_array<1>()[0]};
    }
};

class AttachGrad : public OperatorImpl<AttachGrad> {
private:
    std::shared_ptr<GradKey> m_key;

public:
    AttachGrad(std::shared_ptr<GradKey> key) : m_key(key) {}

    std::shared_ptr<GradKey> key() { return m_key; }

    std::string to_string() const override {
        return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str());
    }
};

class GradBackward : public OperatorImpl<GradBackward, Operator::GetAttrLike> {
private:
    std::shared_ptr<GradKey> m_key;

public:
    GradBackward(std::shared_ptr<GradKey> key) : m_key(key) {}

    std::shared_ptr<GradKey> key() { return m_key; }

    std::string to_string() const override {
        return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str());
    }
};

class IsAttachedTo : public OperatorImpl<IsAttachedTo, Operator::GetAttrLike> {
private:
    std::shared_ptr<GradKey> m_key;

public:
    IsAttachedTo(std::shared_ptr<GradKey> key) : m_key(key) {}

    std::shared_ptr<GradKey> key() { return m_key; }

    std::string to_string() const override {
        return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str());
    }

    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
        return {BoolValue::make(false)};
    }
};

class SetGrad : public OperatorImpl<SetGrad> {
private:
    std::shared_ptr<GradKey> m_key;
    GenericFunction m_grad_fn;
    size_t m_nr_inputs;

public:
    SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs)
            : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}

    GenericFunction grad_fn() { return m_grad_fn; }

    size_t nr_inputs() { return m_nr_inputs; }

    std::string to_string() const override {
        return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());
    }
};

class GetGradKey : public OperatorImpl<GetGradKey, Operator::GetAttrLike> {
public:
    GetGradKey() = default;

    std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); }

    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
        return {ValueRef()};
    }
};

class GetBackwardColsure
        : public OperatorImpl<GetBackwardColsure, Operator::GetAttrLike> {
private:
    std::shared_ptr<GradKey> m_key;

public:
    GetBackwardColsure(std::shared_ptr<GradKey> key) : m_key(key) {}

    std::shared_ptr<GradKey> key() { return m_key; }

    std::string to_string() const override {
        return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str());
    }
};

}  // namespace mgb::imperative