From c2b7c1f13b998a23dc9e1c8f18324b9114052782 Mon Sep 17 00:00:00 2001
From: Yuantao Feng
Date: Sat, 11 Feb 2023 02:03:29 +0800
Subject: [PATCH] Merge pull request #23219 from fengyuentau:add_gelu

Add GELU layer for vision transformers

* add gelu and gelu approximation

* drop setKernelParams
---
 .../dnn/include/opencv2/dnn/all_layers.hpp    |  12 ++
 modules/dnn/src/init.cpp                      |   2 +
 modules/dnn/src/layers/elementwise_layers.cpp |  67 +++++++
 .../dnn/src/onnx/onnx_graph_simplifier.cpp    | 179 ++++++++++++++++++
 modules/dnn/src/onnx/onnx_importer.cpp        |   3 +-
 modules/dnn/src/opencl/activations.cl         |  24 +++
 modules/dnn/test/test_onnx_importer.cpp       |   6 +
 7 files changed, 292 insertions(+), 1 deletion(-)

diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 405f761060..49be0674f4 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -806,6 +806,18 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<SeluLayer> create(const LayerParams &params);
     };
 
+    class CV_EXPORTS GeluLayer : public ActivationLayer
+    {
+    public:
+        static Ptr<GeluLayer> create(const LayerParams &params);
+    };
+
+    class CV_EXPORTS GeluApproximationLayer : public ActivationLayer
+    {
+    public:
+        static Ptr<GeluApproximationLayer> create(const LayerParams &params);
+    };
+
     class CV_EXPORTS ThresholdedReluLayer : public ActivationLayer
     {
     public:
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index 72eca9ed4e..a62312375c 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -145,6 +145,8 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(HardSigmoid,    HardSigmoidLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Selu,           SeluLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ThresholdedRelu,ThresholdedReluLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Gelu,           GeluLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(GeluApproximation, GeluApproximationLayer);
     CV_DNN_REGISTER_LAYER_CLASS(BatchNorm,      BatchNormLayer);
     CV_DNN_REGISTER_LAYER_CLASS(MaxUnpool,      MaxUnpoolLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Dropout,        BlankLayer);
diff --git a/modules/dnn/src/layers/elementwise_layers.cpp b/modules/dnn/src/layers/elementwise_layers.cpp
index a4b71ddddf..9819073bc6 100644
--- a/modules/dnn/src/layers/elementwise_layers.cpp
+++ b/modules/dnn/src/layers/elementwise_layers.cpp
@@ -837,6 +837,57 @@ private:
     static const char* const ocl_kernel_name;
 };
 
+struct GeluFunctor : public BaseDefaultFunctor<GeluFunctor>
+{
+    typedef GeluLayer Layer;
+
+    explicit GeluFunctor() {}
+
+    bool supportBackend(int backendId, int)
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    inline float calculate(float x) const
+    {
+        return 0.5f * x * (1.0f + erf(x * M_SQRT1_2));
+    }
+
+    int64 getFLOPSPerElement() const { return 100; }
+};
+
+template<>
+const char* const BaseDefaultFunctor<GeluFunctor>::ocl_kernel_name = "GeluForward";
+
+namespace GeluApproximationConstants
+{
+static constexpr float sqrt_2_pi = 0.7978845834732056f;
+static constexpr float coef_sqrt_2_pi = 0.044714998453855515f * sqrt_2_pi;
+}
+
+struct GeluApproximationFunctor : public BaseDefaultFunctor<GeluApproximationFunctor>
+{
+    typedef GeluApproximationLayer Layer;
+
+    explicit GeluApproximationFunctor() {}
+
+    bool supportBackend(int backendId, int)
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    inline float calculate(float x) const
+    {
+        return 0.5f * x * (1.f + tanh(x * (GeluApproximationConstants::sqrt_2_pi +
+                                           GeluApproximationConstants::coef_sqrt_2_pi * x * x)));
+    }
+
+    int64 getFLOPSPerElement() const { return 100; }
+};
+
+template<>
+const char* const BaseDefaultFunctor<GeluApproximationFunctor>::ocl_kernel_name = "GeluApproximationForward";
+
 struct TanHFunctor : public BaseDefaultFunctor<TanHFunctor>
 {
     typedef TanHLayer Layer;
@@ -2694,6 +2745,22 @@ Ptr<ReLU6Layer> ReLU6Layer::create(const LayerParams& params)
     return l;
 }
 
+Ptr<GeluLayer> GeluLayer::create(const LayerParams& params)
+{
+    Ptr<GeluLayer> l(new ElementWiseLayer<GeluFunctor>(GeluFunctor()));
+    l->setParamsFrom(params);
+
+    return l;
+}
+
+Ptr<GeluApproximationLayer> GeluApproximationLayer::create(const LayerParams& params)
+{
+    Ptr<GeluApproximationLayer> l(new
ElementWiseLayer<GeluApproximationFunctor>(GeluApproximationFunctor()));
+    l->setParamsFrom(params);
+
+    return l;
+}
+
 Ptr<TanHLayer> TanHLayer::create(const LayerParams& params)
 {
     Ptr<TanHLayer> l(new ElementWiseLayer<TanHFunctor>());
diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
index c977a4761d..e9559a4c19 100644
--- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -132,6 +132,183 @@ private:
     opencv_onnx::GraphProto& net;
 };
 
+/*  Fusion for Gelu.
+
+    Graph before fusion:
+               +---------------------------------------------+
+               |                                             |
+    [Input] -> Div[B=sqrt(2)] -> Erf -> Add[B=1] -> Mul -> Mul[B=0.5] -> [Output]
+
+    Graph after fusion:
+    [Input] -> Gelu -> [Output]
+
+*/
+class GeluSubGraph : public Subgraph
+{
+public:
+    GeluSubGraph()
+    {
+        int input = addNodeToMatch("");
+        int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
+        int erf = addNodeToMatch("Erf", div);
+        int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
+        int mul = addNodeToMatch("Mul", input, add);
+        addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
+
+        setFusedNode("Gelu", input);
+    }
+
+    static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
+    {
+        // if node.getType() is Constant, Constant nodes are placed between other nodes
+        if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
+            return false;
+        // if Initializer, there is no Constant node between other nodes
+        return true;
+    }
+
+    static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
+    {
+        if (withInitializer)
+        {
+            auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
+            int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
+            Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
+            return *const_mat.ptr<float>();
+        } else {
+            const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
+            int constant_id = getInputNodeId(net, node, input_id);
+            Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
+            opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
+            opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
+            Mat constant_mat = getMatFromTensor(constant_proto);
+            return *constant_mat.ptr<float>();
+        }
+    }
+
+    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
+                       std::vector<int>& matchedNodesIds,
+                       std::vector<int>& targetNodesIds) CV_OVERRIDE
+    {
+        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
+        {
+            bool withInitializer = isWithInitializer(matchedNodesIds);
+
+            // Check Div[B=sqrt(2)]
+            float divisor = extractConstant(net, matchedNodesIds[0], 1, withInitializer);
+            if (divisor - M_SQRT2 >= 1e-6)
+                return false;
+
+            // Check Add[B=1]
+            float add_const = extractConstant(net, matchedNodesIds[2], 1, withInitializer);
+            if (add_const - 1.f >= 1e-6)
+                return false;
+
+            // Check Mul[B=0.5]
+            float mul_const = extractConstant(net, matchedNodesIds[4], 1, withInitializer);
+            if (mul_const - 0.5f >= 1e-6)
+                return false;
+
+            return true;
+        }
+        return false;
+    }
+};
+
+/*  Fusion for GeluApproximation.
+
+    Graph before fusion:
+               +--------+------+----------------+------------------------------------+
+               |        |      |                |                                    |
+    [Input] -> Mul -> Mul -> Mul[ ] -> Add -> Mul[ ] -> Tanh -> Add[A=1] -> Mul -> Mul(A=0.5) -> [Output]
+                              /                  \
+         A=0.044714998453855515                   A=sqrt(2/pie)
+
+    Graph after fusion:
+    [Input] -> GeluApproximation -> [Output]
+
+*/
+class GeluApproximationSubGraph : public Subgraph
+{
+public:
+    GeluApproximationSubGraph()
+    {
+        int input = addNodeToMatch("");
+        int mul0 = addNodeToMatch("Mul", input, input);
+        int mul1 = addNodeToMatch("Mul", input, mul0);
+        int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
+        int add0 = addNodeToMatch("Add", input, mul2);
+        int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
+        int tanh = addNodeToMatch("Tanh", mul3);
+        int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
+        int mul4 = addNodeToMatch("Mul", input, add1);
+        addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
+
+        setFusedNode("GeluApproximation", input);
+    }
+
+    static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
+    {
+        // if node.getType() is Constant, Constant nodes are placed between other nodes
+        if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
+            return false;
+        // if Initializer, there is no Constant node between other nodes
+        return true;
+    }
+
+    static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
+    {
+        if (withInitializer)
+        {
+            auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
+            int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
+            Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
+            return *const_mat.ptr<float>();
+        } else {
+            const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
+            int constant_id = getInputNodeId(net, node, input_id);
+            Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
+            opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
+            opencv_onnx::TensorProto constant_proto =
constant_node->attribute(0).t();
+            Mat constant_mat = getMatFromTensor(constant_proto);
+            return *constant_mat.ptr<float>();
+        }
+    }
+
+    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
+                       std::vector<int>& matchedNodesIds,
+                       std::vector<int>& targetNodesIds) CV_OVERRIDE
+    {
+        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
+        {
+            bool withInitializer = isWithInitializer(matchedNodesIds);
+
+            // Check Mul[A=0.044714998453855515]
+            float coef = extractConstant(net, matchedNodesIds[2], 0, withInitializer);
+            if (coef - 0.044714998453855515 >= 1e-6)
+                return false;
+
+            // Check Mul[A=sqrt(2/pie)]
+            float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0, withInitializer);
+            if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
+                return false;
+
+            // Check Add[A=1]
+            float add_const = extractConstant(net, matchedNodesIds[6], 0, withInitializer);
+            if (add_const - 1.f >= 1e-6)
+                return false;
+
+            // Check Mul[A=0.5]
+            float mul_const = extractConstant(net, matchedNodesIds[8], 0, withInitializer);
+            if (mul_const - 0.5f >= 1e-6)
+                return false;
+
+            return true;
+        }
+        return false;
+    }
+};
+
 class LayerNormSubGraph : public Subgraph
 {
 public:
@@ -904,6 +1081,8 @@ public:
 void simplifySubgraphs(opencv_onnx::GraphProto& net)
 {
     std::vector<Ptr<Subgraph> > subgraphs;
+    subgraphs.push_back(makePtr<GeluSubGraph>());
+    subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
     subgraphs.push_back(makePtr<GatherCastSubgraph>());
     subgraphs.push_back(makePtr<MulCastSubgraph>());
     subgraphs.push_back(makePtr<UpsampleSubgraph>());
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 48a75f728e..307a05ef4b 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -4051,7 +4051,8 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
                                           "Cosh", "Dropout", "Erf", "Exp", "Floor", "HardSigmoid", "HardSwish",
                                           "Identity", "Log", "Round", "Reciprocal", "Selu", "Sign", "Sigmoid", "Sin", "Sinh", "Softmax",
-
                                          "Softplus", "Softsign", "Shrink", "Sqrt", "Tan", "ThresholdedRelu"};
+                                          "Softplus", "Softsign", "Shrink", "Sqrt", "Tan", "ThresholdedRelu", "Gelu",
+                                          "GeluApproximation"};
     for (const auto& name : simpleLayers)
     {
         dispatch[name] = &ONNXImporter::parseSimpleLayers;
diff --git a/modules/dnn/src/opencl/activations.cl b/modules/dnn/src/opencl/activations.cl
index 0624f48e19..317d2c1e62 100644
--- a/modules/dnn/src/opencl/activations.cl
+++ b/modules/dnn/src/opencl/activations.cl
@@ -307,6 +307,30 @@ __kernel void ThresholdedReluForward(const int n, __global T* in, __global T* out
     out[index] = (in[index] > alpha ? in[index] : 0.f);
 }
 
+__kernel void GeluForward(const int n, __global T* in, __global T* out)
+{
+    int index = get_global_id(0);
+    if (index < n)
+    {
+        T x = in[index];
+        out[index] = (T)0.5f * x * ( (T)1.f + erf(x * M_SQRT1_2) );
+    }
+}
+
+__kernel void GeluApproximationForward(const int n, __global T* in, __global T* out)
+{
+    // see GeluApproximationConstants from modules/dnn/src/layers/elementwise_layers.cpp
+    const T sqrt_2_pi = 0.7978845834732056f;
+    const T coef_sqrt_2_pi = 0.044714998453855515f * sqrt_2_pi;
+
+    int index = get_global_id(0);
+    if(index < n)
+    {
+        T x = in[index];
+        out[index] = (T)0.5f * x * ( (T)1.f + tanh(x * (sqrt_2_pi + coef_sqrt_2_pi * x * x)) );
+    }
+}
+
 __kernel void ShrinkForward(const int n, __global T* in, __global T* out,
                             const KERNEL_ARG_DTYPE bias,
                             const KERNEL_ARG_DTYPE lambd)
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index 12bbb31372..6698174521 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -2456,6 +2456,12 @@ TEST_P(Test_ONNX_layers, LayerNormExpanded)
     testONNXModels("layer_norm_expanded_with_initializers");
 }
 
+TEST_P(Test_ONNX_layers, Gelu)
+{
+    testONNXModels("gelu");
+    testONNXModels("gelu_approximation");
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
 
 }} // namespace
-- 
GitLab