提交 4718a4bf 编写于 作者: W wanli

make GEMM can be supported with transA and transB in CUDA

上级 838c34ee
......@@ -374,6 +374,10 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<SoftmaxLayerInt8> create(const LayerParams& params);
};
/**
* `InnerProduct`, `MatMul` and `Gemm` operations are all implemented by Fully Connected Layer.
* Parameter `is_matmul` is used to distinguish `MatMul` and `Gemm` from `InnerProduct`.
*/
class CV_EXPORTS InnerProductLayer : public Layer
{
public:
......
......@@ -12,6 +12,8 @@
#include "../csl/tensor.hpp"
#include "../csl/tensor_ops.hpp"
#include "../kernels/scale_shift.hpp"
#include <opencv2/core.hpp>
#include <utility>
......@@ -23,7 +25,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp)
MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp, const Mat& bias, bool _transA, bool _transB)
: stream(std::move(stream_)), cublasHandle(std::move(handle))
{
if (!constInp.empty())
......@@ -31,6 +33,15 @@ namespace cv { namespace dnn { namespace cuda4dnn {
constTensor = csl::makeTensorHeader<T>(constInp);
csl::copyMatToTensor<T>(constInp, constTensor, stream);
}
if (!bias.empty())
{
biasTensor = csl::makeTensorHeader<T>(bias);
csl::copyMatToTensor<T>(bias, biasTensor, stream);
}
transA = _transA;
transB = _transB;
}
void forward(
......@@ -69,50 +80,72 @@ namespace cv { namespace dnn { namespace cuda4dnn {
CV_Assert(input2.get_axis_size(i) == size);
}
auto m = input1.get_axis_size(-2);
auto n = input1.get_axis_size(-1);
auto b = input1.size() / m / n;
int k;
if (constTensor.empty())
int m1, n1, b1, m2, n2, b2;
if (transA)
{
m1 = input1.get_axis_size(-1);
n1 = input1.get_axis_size(-2);
}
else
{
m1 = input1.get_axis_size(-2);
n1 = input1.get_axis_size(-1);
}
if (transB)
{
k = input2.get_axis_size(-1);
CV_Assert(input2.get_axis_size(-2) == n);
m2 = input2.get_axis_size(-1);
n2 = input2.get_axis_size(-2);
}
else
{
k = input2.get_axis_size(-2);
CV_Assert(input2.get_axis_size(-1) == n);
m2 = input2.get_axis_size(-2);
n2 = input2.get_axis_size(-1);
}
CV_Assert(output.get_axis_size(-2) == m);
CV_Assert(output.get_axis_size(-1) == k);
b1 = input1.size() / m1 / n1;
b2 = input2.size() / m2 / n2;
CV_Assert(b1 == b2);
CV_Assert(n1 == m2);
CV_Assert(output.get_axis_size(-2) == m1);
CV_Assert(output.get_axis_size(-1) == n2);
if (get_effective_rank(output) <= 2)
{
CV_Assert(b == 1);
CV_Assert(b2 == 1);
CV_Assert(get_effective_rank(input1) <= 2);
CV_Assert(get_effective_rank(input2) <= 2);
csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, transA, input1, transB, input2);
// used for GEMM
if (!biasTensor.empty())
kernels::biasN<T>(stream, output, output, 1, biasTensor);
}
else
{
CV_Assert(rank >= 3);
input1.reshape(b, m, n);
if (constTensor.empty())
input2.reshape(b, n, k);
if (transA)
input1.reshape(b1, n1, m1);
else
input1.reshape(b1, m1, n1);
if (transB)
input2.reshape(b2, n2, m2);
else
input2.reshape(b, k, n);
output.reshape(b, m, k);
input2.reshape(b2, m2, n2);
output.reshape(b1, m1, n2);
input1.squeeze_to(3);
input2.squeeze_to(3);
output.squeeze_to(3);
csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, transA, input1, transB, input2);
}
}
private:
csl::Stream stream;
csl::cublas::Handle cublasHandle;
csl::Tensor<T> constTensor;
csl::Tensor<T> constTensor, biasTensor;
bool transA, transB;
};
}}} /* namespace cv::dnn::cuda4dnn */
......
......@@ -115,6 +115,8 @@ public:
biasMat = Mat::zeros(1, oriMat.size[oriMat.dims - 2], weightsMat.type());
else
biasMat = Mat::zeros(1, numOutput, weightsMat.type());
transB = !transB;
}
}
......@@ -155,7 +157,6 @@ public:
}
else
{
CV_Assert(!transA && !transB);
CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
CV_CheckEQ(blobs[0].dims, 2, "");
if(isMatMul)
......@@ -183,7 +184,7 @@ public:
return axis == 1 && !tranAorB;
#endif
return backendId == DNN_BACKEND_OPENCV ||
(backendId == DNN_BACKEND_CUDA && !tranAorB) ||
backendId == DNN_BACKEND_CUDA ||
(backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !tranAorB) ||
(backendId == DNN_BACKEND_WEBNN && axis == 1 && !tranAorB) ||
backendId == DNN_BACKEND_CANN;;
......@@ -527,7 +528,6 @@ public:
if (!blobs.empty())
{
CV_Assert(!transA && !transB);
int inp1Dim = input[0].dims;
if (isMatMul)
{
......@@ -611,12 +611,12 @@ public:
const std::vector<Ptr<BackendWrapper>>& outputs
) override
{
auto biasMat_ = bias ? biasMat : Mat();
auto context = reinterpret_cast<csl::CSLContext*>(context_);
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
if (weightsMat.empty() || isMatMul)
{
CV_Assert(!bias);
int inp2Dim;
// broadcast is not supported with CUDA
if(weightsMat.empty())
......@@ -627,13 +627,12 @@ public:
inp2Dim = oriMat.dims;
if(input_wrapper->getRank() == inp2Dim)
return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat);
return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat, biasMat_, transA, transB);
else
return Ptr<BackendNode>();
}
auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
auto biasMat_ = bias ? biasMat : Mat();
return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
}
#endif
......
......@@ -2056,6 +2056,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
}
layerParams.set("bias_term", node_proto.input_size() == 3);
layerParams.set("is_matmul", true);
addLayer(layerParams, node_proto);
}
......
......@@ -1745,6 +1745,11 @@ TEST_P(Test_ONNX_layers, Gemm)
testONNXModels("gemm_first_const");
}
TEST_P(Test_ONNX_layers, Gemm_bias)
{
testONNXModels("gemm_vector_bias");
}
TEST_P(Test_ONNX_layers, Quantized_Convolution)
{
// The difference of QOperator and QDQ format:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册