From a01ccc2672d60f1ee64b06988fcbe9eddd466dc3 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 11 Jun 2020 19:18:49 +0800 Subject: [PATCH] fix mul kernel error. test=develop (#3774) --- lite/kernels/cuda/CMakeLists.txt | 2 +- lite/kernels/cuda/mul_compute.h | 58 +++++---------------------- lite/kernels/cuda/mul_compute_test.cc | 1 - 3 files changed, 12 insertions(+), 49 deletions(-) diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 9c2973c5d2..1a58a51c36 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -5,7 +5,7 @@ endif() message(STATUS "compile with lite CUDA kernels") # basic kernels -add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) +add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) diff --git a/lite/kernels/cuda/mul_compute.h b/lite/kernels/cuda/mul_compute.h index 320b562128..aa80919920 100644 --- a/lite/kernels/cuda/mul_compute.h +++ b/lite/kernels/cuda/mul_compute.h @@ -13,10 +13,9 @@ // limitations under the License. #pragma once -#include "lite/backends/cuda/blas.h" -#include "lite/core/context.h" +#include +#include "lite/backends/cuda/math/gemm.h" #include "lite/core/kernel.h" -#include "lite/core/types.h" #include "lite/operators/op_params.h" namespace paddle { @@ -24,56 +23,17 @@ namespace lite { namespace kernels { namespace cuda { -template -void mul_compute(const lite::cuda::Blas& blas, - const T* x, - int x_h, - int x_w, - const T* y, - int y_h, - int y_w, - T* out) { - float alpha = 1.0; - float beta = 0.0; - /* - blas.sgemm(CUBLAS_OP_N, - CUBLAS_OP_N, - x_h, - y_w, - x_w, - &alpha, - x, - x_w, - y, - y_w, - &beta, - out, - x_h); - */ - blas.sgemm(CUBLAS_OP_N, - CUBLAS_OP_N, - y_w, - x_h, - y_h, - &alpha, - y, - y_w, - x, - x_w, - &beta, - out, - y_w); -} - class MulCompute : public KernelLite { public: using param_t = operators::MulParam; + void PrepareForRun() override { + gemm_impl_.reset(new lite::cuda::math::Gemm); + } + void Run() override { CHECK(ctx_) << "running context should be set first"; auto& context = this->ctx_->template As(); - CHECK(context.cublas_fp32()) << "blas should init first"; - auto& blas = *context.cublas_fp32(); auto& param = this->Param(); const auto* x_data = param.x->data(); @@ -94,10 +54,14 @@ class MulCompute : public KernelLite { .production()); CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; - mul_compute(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data); + CHECK(gemm_impl_->init(false, false, x_h, y_w, x_w, &context)); + gemm_impl_->run(1.0f, 0.0f, x_data, y_data, out_data, &context); } virtual ~MulCompute() = default; + + private: + std::unique_ptr> gemm_impl_{nullptr}; }; } // namespace cuda diff --git a/lite/kernels/cuda/mul_compute_test.cc b/lite/kernels/cuda/mul_compute_test.cc index f521a12e2d..60bee07694 100644 --- a/lite/kernels/cuda/mul_compute_test.cc +++ b/lite/kernels/cuda/mul_compute_test.cc @@ -27,7 +27,6 @@ TEST(mul_compute, normal) { MulCompute mul_kernel; std::unique_ptr ctx(new KernelContext); auto& context = ctx->As(); - context.InitOnce(); Tensor x, y, out, x_cpu, y_cpu, out_cpu; int x_h = 2, x_w_y_h = 3, y_w = 4; -- GitLab