Commit 4718a4bf authored by W wanli

Support GEMM with transA and transB in CUDA

Parent 838c34ee
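For context, this commit lets the CUDA backend honor the Gemm transpose flags instead of rejecting layers that use them. As a rough reference (not text from the patch itself), the operation follows the usual GEMM contract, with alpha, beta, A, B, C as the conventional names:

    Y = \alpha \cdot A' B' + \beta \cdot C, \quad A' = A^{\top} \text{ if transA else } A, \quad B' = B^{\top} \text{ if transB else } B

so with either flag set the corresponding input is transposed before the multiply, and the optional bias is broadcast-added afterwards (handled in the diff below via kernels::biasN).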
@@ -374,6 +374,10 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<SoftmaxLayerInt8> create(const LayerParams& params);
     };

+    /**
+     * `InnerProduct`, `MatMul` and `Gemm` operations are all implemented by Fully Connected Layer.
+     * Parameter `is_matmul` is used to distinguish `MatMul` and `Gemm` from `InnerProduct`.
+     */
     class CV_EXPORTS InnerProductLayer : public Layer
     {
     public:
......
@@ -12,6 +12,8 @@
 #include "../csl/tensor.hpp"
 #include "../csl/tensor_ops.hpp"

+#include "../kernels/scale_shift.hpp"
+
 #include <opencv2/core.hpp>

 #include <utility>
@@ -23,7 +25,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
     public:
         using wrapper_type = GetCUDABackendWrapperType<T>;

-        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp)
+        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp, const Mat& bias, bool _transA, bool _transB)
             : stream(std::move(stream_)), cublasHandle(std::move(handle))
         {
             if (!constInp.empty())
@@ -31,6 +33,15 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                 constTensor = csl::makeTensorHeader<T>(constInp);
                 csl::copyMatToTensor<T>(constInp, constTensor, stream);
             }
+
+            if (!bias.empty())
+            {
+                biasTensor = csl::makeTensorHeader<T>(bias);
+                csl::copyMatToTensor<T>(bias, biasTensor, stream);
+            }
+
+            transA = _transA;
+            transB = _transB;
         }

         void forward(
@@ -69,50 +80,72 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                    CV_Assert(input2.get_axis_size(i) == size);
                }

-            auto m = input1.get_axis_size(-2);
-            auto n = input1.get_axis_size(-1);
-            auto b = input1.size() / m / n;
-            int k;
-            if (constTensor.empty())
-            {
-                k = input2.get_axis_size(-1);
-                CV_Assert(input2.get_axis_size(-2) == n);
-            }
-            else
-            {
-                k = input2.get_axis_size(-2);
-                CV_Assert(input2.get_axis_size(-1) == n);
-            }
-            CV_Assert(output.get_axis_size(-2) == m);
-            CV_Assert(output.get_axis_size(-1) == k);
+            int m1, n1, b1, m2, n2, b2;
+            if (transA)
+            {
+                m1 = input1.get_axis_size(-1);
+                n1 = input1.get_axis_size(-2);
+            }
+            else
+            {
+                m1 = input1.get_axis_size(-2);
+                n1 = input1.get_axis_size(-1);
+            }
+
+            if (transB)
+            {
+                m2 = input2.get_axis_size(-1);
+                n2 = input2.get_axis_size(-2);
+            }
+            else
+            {
+                m2 = input2.get_axis_size(-2);
+                n2 = input2.get_axis_size(-1);
+            }
+
+            b1 = input1.size() / m1 / n1;
+            b2 = input2.size() / m2 / n2;
+            CV_Assert(b1 == b2);
+            CV_Assert(n1 == m2);
+            CV_Assert(output.get_axis_size(-2) == m1);
+            CV_Assert(output.get_axis_size(-1) == n2);

             if (get_effective_rank(output) <= 2)
             {
-                CV_Assert(b == 1);
+                CV_Assert(b2 == 1);
                 CV_Assert(get_effective_rank(input1) <= 2);
                 CV_Assert(get_effective_rank(input2) <= 2);
-                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
+                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, transA, input1, transB, input2);
+                // used for GEMM
+                if (!biasTensor.empty())
+                    kernels::biasN<T>(stream, output, output, 1, biasTensor);
             }
             else
             {
                 CV_Assert(rank >= 3);
-                input1.reshape(b, m, n);
-                if (constTensor.empty())
-                    input2.reshape(b, n, k);
-                else
-                    input2.reshape(b, k, n);
-                output.reshape(b, m, k);
+
+                if (transA)
+                    input1.reshape(b1, n1, m1);
+                else
+                    input1.reshape(b1, m1, n1);
+
+                if (transB)
+                    input2.reshape(b2, n2, m2);
+                else
+                    input2.reshape(b2, m2, n2);
+
+                output.reshape(b1, m1, n2);
                 input1.squeeze_to(3);
                 input2.squeeze_to(3);
                 output.squeeze_to(3);
-                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
+                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, transA, input1, transB, input2);
             }
         }

     private:
         csl::Stream stream;
         csl::cublas::Handle cublasHandle;
-        csl::Tensor<T> constTensor;
+        csl::Tensor<T> constTensor, biasTensor;
+        bool transA, transB;
     };

 }}} /* namespace cv::dnn::cuda4dnn */
......
@@ -115,6 +115,8 @@ public:
                 biasMat = Mat::zeros(1, oriMat.size[oriMat.dims - 2], weightsMat.type());
             else
                 biasMat = Mat::zeros(1, numOutput, weightsMat.type());
+
+            transB = !transB;
         }
     }
@@ -155,7 +157,6 @@ public:
         }
         else
         {
-            CV_Assert(!transA && !transB);
             CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
             CV_CheckEQ(blobs[0].dims, 2, "");
             if(isMatMul)
@@ -183,7 +184,7 @@ public:
         return axis == 1 && !tranAorB;
 #endif
         return backendId == DNN_BACKEND_OPENCV ||
-               (backendId == DNN_BACKEND_CUDA && !tranAorB) ||
+               backendId == DNN_BACKEND_CUDA ||
                (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !tranAorB) ||
                (backendId == DNN_BACKEND_WEBNN && axis == 1 && !tranAorB) ||
                backendId == DNN_BACKEND_CANN;;
@@ -527,7 +528,6 @@ public:
         if (!blobs.empty())
         {
-            CV_Assert(!transA && !transB);
             int inp1Dim = input[0].dims;
             if (isMatMul)
             {
@@ -611,12 +611,12 @@ public:
         const std::vector<Ptr<BackendWrapper>>& outputs
     ) override
     {
+        auto biasMat_ = bias ? biasMat : Mat();
         auto context = reinterpret_cast<csl::CSLContext*>(context_);

         auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
         if (weightsMat.empty() || isMatMul)
         {
-            CV_Assert(!bias);
             int inp2Dim;
             // broadcast is not supported with CUDA
             if(weightsMat.empty())
@@ -627,13 +627,12 @@ public:
                 inp2Dim = oriMat.dims;

             if(input_wrapper->getRank() == inp2Dim)
-                return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat);
+                return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat, biasMat_, transA, transB);
             else
                 return Ptr<BackendNode>();
         }

         auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
-        auto biasMat_ = bias ? biasMat : Mat();
         return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
     }
 #endif
......
@@ -2056,6 +2056,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
     }

     layerParams.set("bias_term", node_proto.input_size() == 3);
+    layerParams.set("is_matmul", true);
     addLayer(layerParams, node_proto);
 }
......
@@ -1745,6 +1745,11 @@ TEST_P(Test_ONNX_layers, Gemm)
     testONNXModels("gemm_first_const");
 }

+TEST_P(Test_ONNX_layers, Gemm_bias)
+{
+    testONNXModels("gemm_vector_bias");
+}
+
 TEST_P(Test_ONNX_layers, Quantized_Convolution)
 {
     // The difference of QOperator and QDQ format:
......