Unverified commit fbff4bf9, authored by Xiaoyu Zhang, committed by GitHub

Add fused self attention (#5966)

* add fused_self_attention functor

* add fused_self_attention gradients

* add test case

* fix backward impl bug

* fix bug

* fix bug

* fix backward bug

* fix comments

* code format

* fix comment

* fix ci error

* fix ci error

* add init FusedSelfAttentionInterpState struct
Co-authored-by: Yao Chi <later@usopp.net>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent dbdffe95
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedSelfAttentionInterpState : public AutoGradCaptureState {
bool input_requires_grad = false;
float alpha = 1.0;
};
class FusedSelfAttention : public OpExprGradFunction<FusedSelfAttentionInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(FusedSelfAttentionInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FusedSelfAttentionInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2);
in_grads->resize(1);
const auto& hidden_states = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& fused_self_attention_grad =
JUST(functional::FusedSelfAttentionGrad(out_grads.at(0), out_grads.at(1), hidden_states,
ctx->alpha));
in_grads->at(0) = fused_self_attention_grad;
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_self_attention_query_mul_key_and_value", FusedSelfAttention);
} // namespace one
} // namespace oneflow
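For context, the backward registered above fires whenever autograd reaches the fused op. A minimal sketch of driving it from Python (mirroring the unit test added later in this PR; the concrete shapes are taken from that test, and a CUDA device is assumed):

import numpy as np
import oneflow as flow

seq_len, batch_size, num_heads, head_size = 10, 4, 8, 16
x = flow.Tensor(np.random.randn(seq_len, batch_size, num_heads * 3 * head_size)).to("cuda")
x.requires_grad = True

# The forward op returns (query_mul_key, value). Summing the attention result and calling
# backward() sends both output gradients through FusedSelfAttention::Apply, which dispatches
# functional::FusedSelfAttentionGrad together with the saved hidden_states and alpha.
qmk, v = flow.F.fused_self_attention(x, head_size=head_size, alpha=1.0)
flow.matmul(qmk, v).sum().backward()
print(x.grad.numpy().shape)  # same shape as x: (seq_len, batch_size, num_heads * 3 * head_size)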
@@ -1069,6 +1069,14 @@
  signature: "Tensor Randperm(Int32 n,*,Device device=None, Generator generator=None)"
  bind_python: True

- name: "fused_self_attention"
  signature: "TensorTuple FusedSelfAttention(Tensor hidden_states, *, Int64 head_size=8, Float alpha=1.0)"
  bind_python: True

- name: "fused_self_attention_grad"
  signature: "Tensor FusedSelfAttentionGrad(Tensor query_mul_key_grad, Tensor value_grad, Tensor hidden_states, *, Float alpha=1.0)"
  bind_python: False

- name: "consistent_randperm"
  signature: "Tensor ConsistentRandperm(Int32 n,*, Placement placement, SbpList sbp_tuple, Generator generator=None)"
  bind_python: True
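For reference, a NumPy sketch of the semantics the FusedSelfAttention signature exposes. The packed q/k/v layout and the output shapes are inferred from the reference path in this PR's unit test; the exact placement of alpha is an assumption, since the test only exercises alpha=1.0:

import numpy as np

def fused_self_attention_reference(hidden_states, head_size, alpha=1.0):
    # hidden_states: (seq_len, batch, num_heads * 3 * head_size), q/k/v packed along the last dim.
    seq_len, batch, _ = hidden_states.shape
    x = hidden_states.reshape(seq_len, batch, -1, 3 * head_size)
    q = x[..., 0 * head_size:1 * head_size].transpose(1, 2, 0, 3)  # (batch, heads, seq, head_size)
    k = x[..., 1 * head_size:2 * head_size].transpose(1, 2, 0, 3)
    v = x[..., 2 * head_size:3 * head_size].transpose(1, 2, 0, 3)
    # alpha is assumed to scale the query-key product (e.g. 1 / sqrt(head_size) in practice).
    query_mul_key = alpha * np.matmul(q, k.transpose(0, 1, 3, 2))  # (batch, heads, seq, seq)
    return query_mul_key, v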
@@ -673,6 +673,51 @@ class L2NormalizeFunctor {
  std::shared_ptr<OpExpr> op_;
};

class FusedSelfAttentionFunctor {
 public:
  FusedSelfAttentionFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("fused_self_attention_query_mul_key_and_value")
                         .Input("hidden_states")
                         .Output("query_mul_key")
                         .Output("value")
                         .Build());
  }
  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& hidden_states,
                                const int64_t& head_size, const float& alpha) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<int64_t>("head_size", head_size));
    JUST(attrs.SetAttr<float>("alpha", alpha));
    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {hidden_states}, attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

class FusedSelfAttentionGradFunctor {
 public:
  FusedSelfAttentionGradFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("fused_self_attention_query_mul_key_and_value_grad")
                         .Input("query_mul_key_grad")
                         .Input("value_grad")
                         .Input("hidden_states")
                         .Output("hidden_states_grad")
                         .Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query_mul_key_grad,
                           const std::shared_ptr<one::Tensor>& value_grad,
                           const std::shared_ptr<one::Tensor>& hidden_states,
                           const float& alpha) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<float>("alpha", alpha));
    return OpInterpUtil::Dispatch<Tensor>(*op_, {query_mul_key_grad, value_grad, hidden_states},
                                          attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

class L2NormalizeGradFunctor {
 public:
  L2NormalizeGradFunctor() {
@@ -764,6 +809,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::Avgpool2DFunctor>("Avgpool2D");
  m.add_functor<impl::Avgpool3DFunctor>("Avgpool3D");
  m.add_functor<impl::OneHotFunctor>("OneHot");
  m.add_functor<impl::FusedSelfAttentionFunctor>("FusedSelfAttention");
  m.add_functor<impl::FusedSelfAttentionGradFunctor>("FusedSelfAttentionGrad");
  m.add_functor<impl::L2NormalizeFunctor>("L2Normalize");
  m.add_functor<impl::L2NormalizeGradFunctor>("L2NormalizeGrad");
  m.add_functor<impl::FusedScaleTrilFunctor>("FusedScaleTril");
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
def test_fused_self_attention(test_case, batch_size, seq_len, num_heads, head_size):
    hidden_size = num_heads * 3 * head_size
    x = np.random.randn(seq_len, batch_size, hidden_size)

    # Fused path: one op produces query * key^T and value directly.
    fused_input = flow.Tensor(x).to("cuda")
    fused_input.requires_grad = True
    (fused_qmk, fused_v) = flow.F.fused_self_attention(
        fused_input, head_size=head_size, alpha=1.0,
    )
    fused_atten = flow.matmul(fused_qmk, fused_v)
    fused_atten_sum = fused_atten.sum()

    # Reference path: slice the packed hidden states into q/k/v and compose the same result.
    origin_input = flow.Tensor(x).to("cuda")
    origin_input.requires_grad = True
    reshape_input = flow.reshape(origin_input, (seq_len, batch_size, -1, 3 * head_size))

    origin_q = flow.slice(
        reshape_input,
        slice_tup_list=[
            [None, None, None],
            [None, None, None],
            [None, None, None],
            [0, head_size, 1],
        ],
    ).permute(1, 2, 0, 3)
    origin_k = flow.slice(
        reshape_input,
        slice_tup_list=[
            [None, None, None],
            [None, None, None],
            [None, None, None],
            [head_size, 2 * head_size, 1],
        ],
    ).permute(1, 2, 0, 3)
    origin_v = flow.slice(
        reshape_input,
        slice_tup_list=[
            [None, None, None],
            [None, None, None],
            [None, None, None],
            [2 * head_size, 3 * head_size, 1],
        ],
    ).permute(1, 2, 0, 3)

    origin_k = origin_k.transpose(2, 3)
    origin_qmk = flow.matmul(origin_q, origin_k)
    origin_atten = flow.matmul(origin_qmk, origin_v)
    origin_atten_sum = origin_atten.sum()

    # Backward through both paths at once, then compare outputs and input gradients.
    total_sum = fused_atten_sum + origin_atten_sum
    total_sum.backward()
    test_case.assertTrue(
        np.allclose(fused_atten.numpy(), origin_atten.numpy(), atol=1e-4, rtol=1e-4)
    )
    test_case.assertTrue(
        np.allclose(
            fused_input.grad.numpy(), origin_input.grad.numpy(), atol=1e-4, rtol=1e-4,
        )
    )


@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestFusedSelfAttention(flow.unittest.TestCase):
    def test_fused_self_attention(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [test_fused_self_attention]
        arg_dict["batch_size"] = [1, 4, 6, 8]
        arg_dict["seq_len"] = [5, 10, 12]
        arg_dict["num_heads"] = [4, 8, 16]
        arg_dict["head_size"] = [16, 32, 64]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])


if __name__ == "__main__":
    unittest.main()