Unverified commit a01ccc26, authored by Wilber, committed by GitHub

fix mul kernel error. test=develop (#3774)
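The mul CUDA kernel previously computed its product through a hand-written `mul_compute` helper that issued `cublasSgemm` against the context's cuBLAS handle; the operand and leading-dimension bookkeeping in that helper was error-prone (an earlier incorrect variant survives as a commented-out block in the deleted code). This commit deletes the helper and routes the kernel through the `lite::cuda::math::Gemm` wrapper, which is created once in `PrepareForRun()` and configured per run via `init()`/`run()`.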

Parent 6299a90a
lite/kernels/cuda/CMakeLists.txt
@@ -5,7 +5,7 @@ endif()
message(STATUS "compile with lite CUDA kernels")
# basic kernels
-add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
+add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
......
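Note the DEPS change for `mul_compute_cuda`: the rewritten kernel no longer reaches into the context for a cuBLAS handle, so the generic `context` dependency is replaced with `${math_cuda}`, the CUDA math library that provides `lite::cuda::math::Gemm`.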
lite/kernels/cuda/mul_compute.h
@@ -13,10 +13,9 @@
// limitations under the License.
#pragma once
#include "lite/backends/cuda/blas.h"
#include "lite/core/context.h"
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
#include "lite/core/types.h"
#include "lite/operators/op_params.h"
namespace paddle {
@@ -24,56 +23,17 @@ namespace lite {
namespace kernels {
namespace cuda {
-template <typename T>
-void mul_compute(const lite::cuda::Blas<float>& blas,
-                 const T* x,
-                 int x_h,
-                 int x_w,
-                 const T* y,
-                 int y_h,
-                 int y_w,
-                 T* out) {
-  float alpha = 1.0;
-  float beta = 0.0;
-  /*
-  blas.sgemm(CUBLAS_OP_N,
-             CUBLAS_OP_N,
-             x_h,
-             y_w,
-             x_w,
-             &alpha,
-             x,
-             x_w,
-             y,
-             y_w,
-             &beta,
-             out,
-             x_h);
-  */
-  blas.sgemm(CUBLAS_OP_N,
-             CUBLAS_OP_N,
-             y_w,
-             x_h,
-             y_h,
-             &alpha,
-             y,
-             y_w,
-             x,
-             x_w,
-             &beta,
-             out,
-             y_w);
-}
class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
+  void PrepareForRun() override {
+    gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
+  }
void Run() override {
CHECK(ctx_) << "running context should be set first";
auto& context = this->ctx_->template As<CUDAContext>();
-    CHECK(context.cublas_fp32()) << "blas should init first";
-    auto& blas = *context.cublas_fp32();
auto& param = this->Param<param_t>();
const auto* x_data = param.x->data<float>();
@@ -94,10 +54,14 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
.production());
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
-    mul_compute<float>(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data);
+    CHECK(gemm_impl_->init(false, false, x_h, y_w, x_w, &context));
+    gemm_impl_->run(1.0f, 0.0f, x_data, y_data, out_data, &context);
}
virtual ~MulCompute() = default;
+ private:
+  std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_{nullptr};
};
} // namespace cuda
......
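A note on the deleted helper, for anyone tracing the original bug: cuBLAS works in column-major order, so a row-major product `out = x * y` has to be issued as the column-major product of the swapped operands. The commented-out call (dimensions `x_h, y_w, x_w`, `ldc = x_h`) is the row-major-style invocation that violates this mapping; the active call below it (dimensions `y_w, x_h, y_h`, `ldc = y_w`) is the corrected, swapped form that this commit now deletes in favor of the wrapper. The sketch below shows the same swapped mapping as a self-contained program, independent of lite; the matrix sizes, buffer names, and `main` are illustrative only, not part of the commit (compile with nvcc and link `-lcublas`).

```cpp
// Minimal standalone sketch (not from the commit) of the row-major GEMM
// mapping used by the removed helper: out = x * y for row-major matrices,
// issued through column-major cublasSgemm with the operands swapped.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 4;  // x is MxK, y is KxN, out is MxN
  std::vector<float> x(M * K, 1.f), y(K * N, 2.f), out(M * N, 0.f);

  float *dx, *dy, *dout;
  cudaMalloc(&dx, x.size() * sizeof(float));
  cudaMalloc(&dy, y.size() * sizeof(float));
  cudaMalloc(&dout, out.size() * sizeof(float));
  cudaMemcpy(dx, x.data(), x.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, y.data(), y.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.f, beta = 0.f;
  // cuBLAS is column-major, so the row-major x*y is computed as the
  // column-major y*x: swap the operands and pass (m, n, k) = (N, M, K)
  // with leading dimensions N (= y_w), K (= x_w), N (= y_w).
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
              N, M, K,
              &alpha,
              dy, N,    // "A" = y, lda = y_w
              dx, K,    // "B" = x, ldb = x_w
              &beta,
              dout, N);  // ldc = y_w, NOT x_h as in the commented-out call

  cudaMemcpy(out.data(), dout, out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("out[0] = %f (expect %d)\n", out[0], 2 * K);  // sum of 1*2 over K
  cublasDestroy(handle);
  cudaFree(dx);
  cudaFree(dy);
  cudaFree(dout);
  return 0;
}
```

The new code path passes the same shape information as `gemm_impl_->init(false, false, x_h, y_w, x_w, &context)`, i.e. `(trans_a, trans_b, M, N, K)` in row-major terms, with the column-major swap presumably handled inside the `Gemm` wrapper.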
lite/kernels/cuda/mul_compute_test.cc
@@ -27,7 +27,6 @@ TEST(mul_compute, normal) {
MulCompute mul_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
-  context.InitOnce();
Tensor x, y, out, x_cpu, y_cpu, out_cpu;
int x_h = 2, x_w_y_h = 3, y_w = 4;
......
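The dropped `context.InitOnce();` call lines up with the kernel change: `Run()` no longer checks or dereferences `context.cublas_fp32()`, so the test presumably no longer needs to eagerly initialize the context's cuBLAS handle before launching the kernel; the Gemm implementation is created in `PrepareForRun()` instead.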