未验证 提交 a01ccc26 编写于 作者: W Wilber 提交者: GitHub

fix mul kernel error. test=develop (#3774)

上级 6299a90a
...@@ -5,7 +5,7 @@ endif() ...@@ -5,7 +5,7 @@ endif()
message(STATUS "compile with lite CUDA kernels") message(STATUS "compile with lite CUDA kernels")
# basic kernels # basic kernels
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
......
...@@ -13,10 +13,9 @@ ...@@ -13,10 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/backends/cuda/blas.h" #include <memory>
#include "lite/core/context.h" #include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/types.h"
#include "lite/operators/op_params.h" #include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
...@@ -24,56 +23,17 @@ namespace lite { ...@@ -24,56 +23,17 @@ namespace lite {
namespace kernels { namespace kernels {
namespace cuda { namespace cuda {
// Computes the row-major matrix product out = x * y with cuBLAS, where
// x is (x_h, x_w), y is (y_h, y_w) and out is (x_h, y_w).
//
// cuBLAS assumes column-major storage, so a row-major matrix M is seen by
// cuBLAS as M^T. We therefore ask cuBLAS for out^T = y^T * x^T by passing
// the operands swapped (y first, then x) with m = y_w, n = x_h, k = y_h and
// all leading dimensions equal to the row-major row lengths. The
// column-major out^T cuBLAS writes is exactly the row-major out we want.
//
// NOTE(review): only the sgemm call is visible here; Blas::sgemm is
// presumably a thin cublasSgemm wrapper with the standard argument order —
// confirm against lite/backends/cuda/blas.h.
template <typename T>
void mul_compute(const lite::cuda::Blas<float>& blas,
                 const T* x,
                 int x_h,
                 int x_w,
                 const T* y,
                 int y_h,
                 int y_w,
                 T* out) {
  // C = alpha * A * B + beta * C — plain product, no accumulation.
  float alpha = 1.0;
  float beta = 0.0;
  blas.sgemm(CUBLAS_OP_N,
             CUBLAS_OP_N,
             y_w,        // m: rows of op(A) in column-major = columns of out
             x_h,        // n: cols of op(B) in column-major = rows of out
             y_h,        // k: shared dimension (caller checks x_w == y_h)
             &alpha,
             y,          // A: y^T as seen by cuBLAS
             y_w,        // lda: row-major row length of y
             x,          // B: x^T as seen by cuBLAS
             x_w,        // ldb: row-major row length of x
             &beta,
             out,
             y_w);       // ldc: row-major row length of out
}
class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public: public:
using param_t = operators::MulParam; using param_t = operators::MulParam;
void PrepareForRun() override {
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
}
void Run() override { void Run() override {
CHECK(ctx_) << "running context should be set first"; CHECK(ctx_) << "running context should be set first";
auto& context = this->ctx_->template As<CUDAContext>(); auto& context = this->ctx_->template As<CUDAContext>();
CHECK(context.cublas_fp32()) << "blas should init first";
auto& blas = *context.cublas_fp32();
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
const auto* x_data = param.x->data<float>(); const auto* x_data = param.x->data<float>();
...@@ -94,10 +54,14 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { ...@@ -94,10 +54,14 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
.production()); .production());
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
mul_compute<float>(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data); CHECK(gemm_impl_->init(false, false, x_h, y_w, x_w, &context));
gemm_impl_->run(1.0f, 0.0f, x_data, y_data, out_data, &context);
} }
virtual ~MulCompute() = default; virtual ~MulCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_{nullptr};
}; };
} // namespace cuda } // namespace cuda
......
...@@ -27,7 +27,6 @@ TEST(mul_compute, normal) { ...@@ -27,7 +27,6 @@ TEST(mul_compute, normal) {
MulCompute mul_kernel; MulCompute mul_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>(); auto& context = ctx->As<CUDAContext>();
context.InitOnce();
Tensor x, y, out, x_cpu, y_cpu, out_cpu; Tensor x, y, out, x_cpu, y_cpu, out_cpu;
int x_h = 2, x_w_y_h = 3, y_w = 4; int x_h = 2, x_w_y_h = 3, y_w = 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册