Commit 30a2d0ec authored by juncaipeng, committed by GitHub

Add cuda match_matrix_tensor op and test (#2434)

* add cuda match_matrix_tensor op and test, test=develop
Parent 2148bf49
@@ -12,6 +12,7 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
set (
math_cuda
@@ -21,6 +22,7 @@ set (
cuda_type_trans
cuda_transpose
cuda_elementwise
cuda_gemm
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/gemm.h"
#include <iostream>
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <>
bool Gemm<float, float>::init(const bool trans_a,
bool trans_b,
const int m,
const int n,
const int k,
Context<TARGET(kCUDA)> *ctx) {
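// Create the cuBLAS handle lazily on first use and bind it to the context's
// execution stream so GEMM calls are ordered with the rest of the kernel's work.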
if (cu_handle_ == nullptr) {
this->exe_stream_ = ctx->exec_stream();
CUBLAS_CALL(cublasCreate(&cu_handle_));
CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));
}
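// Leading dimensions assume row-major inputs: A is m x k (lda = k unless
// transposed, in which case it is stored as k x m and lda = m), B is k x n
// (ldb = n unless transposed), and C is m x n (ldc = n).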
lda_ = (!trans_a) ? k : m;
ldb_ = (!trans_b) ? n : k;
ldc_ = n;
m_ = m;
n_ = n;
k_ = k;
cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
return true;
}
template <>
bool Gemm<float, float>::init(const bool trans_a,
bool trans_b,
const int m,
const int n,
const int k,
const int lda,
const int ldb,
const int ldc,
Context<TARGET(kCUDA)> *ctx) {
if (cu_handle_ == nullptr) {
this->exe_stream_ = ctx->exec_stream();
CUBLAS_CALL(cublasCreate(&cu_handle_));
CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));
}
m_ = m;
n_ = n;
k_ = k;
lda_ = lda;
ldb_ = ldb;
ldc_ = ldc;
cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
return true;
}
template <>
bool Gemm<float, float>::run(const float alpha,
const float beta,
const float *a,
const float *b,
float *c,
Context<TARGET(kCUDA)> *ctx) {
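// cuBLAS works on column-major matrices, so compute C^T = B^T * A^T by
// passing the operands in swapped order; the buffer c then holds the
// row-major C = alpha * op(A) * op(B) + beta * C.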
CUBLAS_CALL(cublasSgemm(cu_handle_,
cu_trans_b_,
cu_trans_a_,
n_,
m_,
k_,
&alpha,
b,
ldb_,
a,
lda_,
&beta,
c,
ldc_));
return true;
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename PtypeIn, typename PtypeOut>
class Gemm {
public:
Gemm() : cu_handle_(nullptr) {}
~Gemm() {}
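// init() configures a row-major GEMM C = alpha * op(A) * op(B) + beta * C
// with op(A) of shape m x k and op(B) of shape k x n; the second overload
// takes explicit leading dimensions. run() then launches the multiply on the
// execution stream captured from the context at init time.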
bool init(const bool trans_a,
const bool trans_b,
const int m,
const int n,
const int k,
Context<TARGET(kCUDA)>* ctx);
bool init(const bool trans_a,
const bool trans_b,
const int m,
const int n,
const int k,
const int lda,
const int ldb,
const int ldc,
Context<TARGET(kCUDA)>* ctx);
bool run(const PtypeOut alpha,
const PtypeOut beta,
const PtypeIn* a,
const PtypeIn* b,
PtypeOut* c,
Context<TARGET(kCUDA)>* ctx);
private:
cudaStream_t exe_stream_;
cublasHandle_t cu_handle_;
cublasOperation_t cu_trans_a_;
cublasOperation_t cu_trans_b_;
int m_{-1};
int n_{-1};
int k_{-1};
int lda_{-1};
int ldb_{-1};
int ldc_{-1};
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
@@ -26,6 +26,7 @@ add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depaddin
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
@@ -44,6 +45,7 @@ nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc D
nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
if(LITE_BUILD_EXTRA)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
endif()
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/match_matrix_tensor_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
void MatchMatrixTensorCompute::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
}
void MatchMatrixTensorCompute::Run() {
CHECK(ctx_) << "running context should be set first";
auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>();
auto* x = param.x;
auto* w = param.w;
auto* y = param.y;
auto* out = param.out;
auto* tmp = param.tmp;
int dim_t = param.dim_t;
int dim_in = x->dims()[1];
const auto& offset_l = x->lod()[0];
const auto& offset_r = y->lod()[0];
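// offset_l / offset_r are the LoD offsets of the left (X) and right (Y)
// sequences; for each pair the output holds a dim_t * len_l * len_r block,
// and top_offset records where each batch's block starts.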
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
int len_l = offset_l[b + 1] - offset_l[b];
int len_r = offset_r[b + 1] - offset_r[b];
top_size += dim_t * len_l * len_r;
top_offset.push_back(top_size);
}
auto* bottom_l_data = x->data<float>();
auto* bottom_r_data = y->data<float>();
auto* t_data = w->data<float>();
auto* out_data = out->mutable_data<float>(TARGET(kCUDA));
auto* bottom_l_trans_data = tmp->mutable_data<float>(TARGET(kCUDA));
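// Step 1: project X ([total_len_l, dim_in]) by W ([dim_in, dim_t * dim_in])
// into tmp, producing dim_t transformed copies of every left-hand row.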
gemm_impl_->init(
false, false, x->dims()[0], dim_t * dim_in, dim_in, &context);
gemm_impl_->run(
1.0f, 0.0f, bottom_l_data, t_data, bottom_l_trans_data, &context);
for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
for (int t = 0; t < dim_t; t++) {
int len_l = offset_l[b + 1] - offset_l[b];
int len_r = offset_r[b + 1] - offset_r[b];
auto* top_data = out_data + top_offset[b] + t * len_l * len_r;
const auto* l_t_data =
bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
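// Step 2: for batch b and channel t, multiply the projected left sequence
// ([len_l, dim_in] with row stride dim_t * dim_in) by the transposed right
// sequence ([len_r, dim_in]) to produce the [len_l, len_r] match matrix.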
gemm_impl_->init(false,
true,
len_l,
len_r,
dim_in,
dim_t * dim_in,
dim_in,
len_r,
&context);
gemm_impl_->run(1.0f, 0.0f, l_t_data, r_data, top_data, &context);
}
}
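// Expose the per-batch block offsets as the output's LoD.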
LoD out_lod;
out_lod.push_back(top_offset);
out->set_lod(out_lod);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(match_matrix_tensor,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::MatchMatrixTensorCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("W",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Tmp",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class MatchMatrixTensorCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::MatchMatrixTensorParam;
void PrepareForRun() override;
void Run() override;
virtual ~MatchMatrixTensorCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/match_matrix_tensor_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
TEST(match_matrix_tensor, normal) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
MatchMatrixTensorCompute kernel;
operators::MatchMatrixTensorParam param;
// prepare input and output tensors on the GPU, including shape and LoD
int ix = 5, iy = 4, h = 2, dim_t = 2;
Tensor x, w, y, out, tmp;
x.Resize({ix, h});
w.Resize({h, dim_t, h});
y.Resize({iy, h});
out.Resize({18, 1});
tmp.Resize({20, 1});
LoD x_lod{};
x_lod.push_back({0, 2, 5});
x.set_lod(x_lod);
LoD y_lod{};
y_lod.push_back({0, 3, 4});
y.set_lod(y_lod);
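// X holds two sequences of lengths {2, 3} and Y holds {3, 1}, so the output
// has dim_t * (2 * 3 + 3 * 1) = 18 elements and tmp has ix * dim_t * h = 20.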
// initialize the input tensors on the CPU
Tensor x_cpu, w_cpu, y_cpu, out_cpu, tmp_cpu;
x_cpu.Resize({ix, h});
w_cpu.Resize({h, dim_t, h});
y_cpu.Resize({iy, h});
out_cpu.Resize({18, 1});
tmp_cpu.Resize({20, 1});
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* w_cpu_data = w_cpu.mutable_data<float>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = static_cast<float>(i);
}
for (int i = 0; i < w_cpu.numel(); ++i) {
w_cpu_data[i] = static_cast<float>(i);
}
for (int i = 0; i < y_cpu.numel(); ++i) {
y_cpu_data[i] = static_cast<float>(i);
}
// assign the CPU tensor data to the GPU tensors
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
w.Assign<float, lite::DDim, TARGET(kCUDA)>(w_cpu_data, w_cpu.dims());
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
param.x = &x;
param.w = &w;
param.y = &y;
param.dim_t = dim_t;
param.out = &out;
param.tmp = &tmp;
kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
kernel.SetContext(std::move(ctx));
kernel.Launch();
cudaDeviceSynchronize();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
std::vector<float> ref_results = {5,
23,
41,
17,
75,
133,
7,
33,
59,
27,
125,
223,
323,
455,
587,
557,
793,
1029};
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle