From fbff4bf9f6a936b54326a8e8fcd3c205411dd06e Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Sun, 22 Aug 2021 01:56:51 +0800
Subject: [PATCH] Add fused self attention (#5966)

* add fused_self_attention functor
* add fused_self_attention gradients
* add test case
* fix backward impl bug
* fix bug
* fix bug
* fix backward bug
* fix comments
* code format
* fix comment
* fix ci error
* fix ci error
* add init FusedSelfAttentionInterpState struct

Co-authored-by: Yao Chi
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 .../gradient_funcs/fused_self_attention.cpp    |  69 ++++++++++++
 oneflow/core/functional/functional_api.yaml    |   8 ++
 oneflow/core/functional/impl/nn_functor.cpp    |  47 ++++++++
 .../test/modules/test_fused_self_attention.py  | 104 ++++++++++++++++++
 4 files changed, 228 insertions(+)
 create mode 100644 oneflow/core/autograd/gradient_funcs/fused_self_attention.cpp
 create mode 100644 python/oneflow/test/modules/test_fused_self_attention.py

diff --git a/oneflow/core/autograd/gradient_funcs/fused_self_attention.cpp b/oneflow/core/autograd/gradient_funcs/fused_self_attention.cpp
new file mode 100644
index 0000000000..581e5f5cf9
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/fused_self_attention.cpp
@@ -0,0 +1,69 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct FusedSelfAttentionInterpState : public AutoGradCaptureState {
+  bool input_requires_grad = false;
+  float alpha = 1.0;
+};
+
+class FusedSelfAttention : public OpExprGradFunction<FusedSelfAttentionInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(FusedSelfAttentionInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    ctx->input_requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const FusedSelfAttentionInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
+
+    CHECK_EQ_OR_RETURN(out_grads.size(), 2);
+    in_grads->resize(1);
+    const auto& hidden_states = ctx->SavedTensors().at(0);
+    const std::shared_ptr<Tensor>& fused_self_attention_grad =
+        JUST(functional::FusedSelfAttentionGrad(out_grads.at(0), out_grads.at(1), hidden_states,
+                                                ctx->alpha));
+    in_grads->at(0) = fused_self_attention_grad;
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("fused_self_attention_query_mul_key_and_value", FusedSelfAttention);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 75a6c8750e..cda2615588 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -1069,6 +1069,14 @@
   signature: "Tensor Randperm(Int32 n,*,Device device=None, Generator generator=None)"
   bind_python: True
 
+- name: "fused_self_attention"
+  signature: "TensorTuple FusedSelfAttention(Tensor hidden_states, *, Int64 head_size=8, Float alpha=1.0)"
+  bind_python: True
+
+- name: "fused_self_attention_grad"
+  signature: "Tensor FusedSelfAttentionGrad(Tensor query_mul_key_grad, Tensor value_grad, Tensor hidden_states, *, Float alpha=1.0)"
+  bind_python: False
+
 - name: "consistent_randperm"
   signature: "Tensor ConsistentRandperm(Int32 n,*, Placement placement, SbpList sbp_tuple, Generator generator=None)"
   bind_python: True
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 6e8fb48282..9b18f6ac99 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -673,6 +673,51 @@ class L2NormalizeFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class FusedSelfAttentionFunctor {
+ public:
+  FusedSelfAttentionFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("fused_self_attention_query_mul_key_and_value")
+                         .Input("hidden_states")
+                         .Output("query_mul_key")
+                         .Output("value")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& hidden_states,
+                                const int64_t& head_size, const float& alpha) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int64_t>("head_size", head_size));
+    JUST(attrs.SetAttr<float>("alpha", alpha));
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {hidden_states}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class FusedSelfAttentionGradFunctor {
+ public:
+  FusedSelfAttentionGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("fused_self_attention_query_mul_key_and_value_grad")
+                         .Input("query_mul_key_grad")
+                         .Input("value_grad")
+                         .Input("hidden_states")
+                         .Output("hidden_states_grad")
+                         .Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query_mul_key_grad,
+                           const std::shared_ptr<one::Tensor>& value_grad,
+                           const std::shared_ptr<one::Tensor>& hidden_states,
+                           const float& alpha) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("alpha", alpha));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {query_mul_key_grad, value_grad, hidden_states},
+                                          attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class L2NormalizeGradFunctor {
  public:
   L2NormalizeGradFunctor() {
@@ -764,6 +809,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::Avgpool2DFunctor>("Avgpool2D");
   m.add_functor<impl::Avgpool3DFunctor>("Avgpool3D");
   m.add_functor<impl::OneHotFunctor>("OneHot");
+  m.add_functor<impl::FusedSelfAttentionFunctor>("FusedSelfAttention");
+  m.add_functor<impl::FusedSelfAttentionGradFunctor>("FusedSelfAttentionGrad");
   m.add_functor<impl::L2NormalizeFunctor>("L2Normalize");
   m.add_functor<impl::L2NormalizeGradFunctor>("L2NormalizeGrad");
   m.add_functor<impl::FusedScaleTrilFunctor>("FusedScaleTril");
diff --git a/python/oneflow/test/modules/test_fused_self_attention.py b/python/oneflow/test/modules/test_fused_self_attention.py
new file mode 100644
index 0000000000..ac8b3c487a
--- /dev/null
+++ b/python/oneflow/test/modules/test_fused_self_attention.py
@@ -0,0 +1,104 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+from test_util import GenArgList
+
+import oneflow as flow
+import oneflow.unittest
+
+
+def test_fused_self_attention(test_case, batch_size, seq_len, num_heads, head_size):
+    hidden_size = num_heads * 3 * head_size
+
+    x = np.random.randn(seq_len, batch_size, hidden_size)
+    fused_input = flow.Tensor(x).to("cuda")
+    fused_input.requires_grad = True
+    (fused_qmk, fused_v) = flow.F.fused_self_attention(
+        fused_input, head_size=head_size, alpha=1.0,
+    )
+    fused_atten = flow.matmul(fused_qmk, fused_v)
+    fused_atten_sum = fused_atten.sum()
+
+    origin_input = flow.Tensor(x).to("cuda")
+    origin_input.requires_grad = True
+    reshape_input = flow.reshape(origin_input, (seq_len, batch_size, -1, 3 * head_size))
+
+    origin_q = flow.slice(
+        reshape_input,
+        slice_tup_list=[
+            [None, None, None],
+            [None, None, None],
+            [None, None, None],
+            [0, head_size, 1],
+        ],
+    ).permute(1, 2, 0, 3)
+    origin_k = flow.slice(
+        reshape_input,
+        slice_tup_list=[
+            [None, None, None],
+            [None, None, None],
+            [None, None, None],
+            [head_size, 2 * head_size, 1],
+        ],
+    ).permute(1, 2, 0, 3)
+    origin_v = flow.slice(
+        reshape_input,
+        slice_tup_list=[
+            [None, None, None],
+            [None, None, None],
+            [None, None, None],
+            [2 * head_size, 3 * head_size, 1],
+        ],
+    ).permute(1, 2, 0, 3)
+
+    origin_k = origin_k.transpose(2, 3)
+    origin_qmk = flow.matmul(origin_q, origin_k)
+    origin_atten = flow.matmul(origin_qmk, origin_v)
+    origin_atten_sum = origin_atten.sum()
+
+    total_sum = fused_atten_sum + origin_atten_sum
+    total_sum.backward()
+
+    test_case.assertTrue(
+        np.allclose(fused_atten.numpy(), origin_atten.numpy(), atol=1e-4, rtol=1e-4)
+    )
+    test_case.assertTrue(
+        np.allclose(
+            fused_input.grad.numpy(), origin_input.grad.numpy(), atol=1e-4, rtol=1e-4,
+        )
+    )
+
+
+@flow.unittest.skip_unless_1n1d()
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+class TestFusedSelfAttention(flow.unittest.TestCase):
+    def test_fused_self_attention(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [test_fused_self_attention]
+        arg_dict["batch_size"] = [1, 4, 6, 8]
+        arg_dict["seq_len"] = [5, 10, 12]
+        arg_dict["num_heads"] = [4, 8, 16]
+        arg_dict["head_size"] = [16, 32, 64]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
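Reviewer note: below is a minimal usage sketch of the new flow.F.fused_self_attention API, distilled from the test file added above. It is not part of the patch; the shapes and variable names are illustrative, the op runs on a CUDA device, and the comment about choosing alpha is an assumption about typical attention scaling (the test simply uses alpha=1.0). The fused op replaces the slice/permute/matmul sequence that the non-fused reference path in the test performs.

import numpy as np
import oneflow as flow

seq_len, batch_size, num_heads, head_size = 10, 4, 8, 64

# Packed QKV projection output: (seq_len, batch_size, num_heads * 3 * head_size).
hidden_states = flow.Tensor(
    np.random.randn(seq_len, batch_size, num_heads * 3 * head_size)
).to("cuda")
hidden_states.requires_grad = True

# One fused kernel returns the scaled query-key product and the value tensor.
# alpha scales query*key^T; 1.0 here, commonly 1/sqrt(head_size) in attention layers.
qmk, v = flow.F.fused_self_attention(hidden_states, head_size=head_size, alpha=1.0)

# Attention output and a backward pass through the registered gradient functor.
atten = flow.matmul(qmk, v)
atten.sum().backward()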