Unverified · Commit 823f0dae authored by W Wilber, committed by GitHub

Optimize CUDA kernels and remove the io_copy that was inserted by default due to the missing fetch_cuda kernel (#2920)

Optimize CUDA kernels and remove the io_copy that was inserted by default due to the missing fetch_cuda kernel
Parent 0679feed
......@@ -39,13 +39,26 @@ class TargetWrapper<TARGET(kCUDA)> {
   static void CreateStream(stream_t* stream) {}
   static void DestroyStream(const stream_t& stream) {}
-  static void CreateEvent(event_t* event) {}
-  static void DestroyEvent(const event_t& event) {}
+  static void CreateEvent(event_t* event) { cudaEventCreate(event); }
+  static void CreateEventWithFlags(
+      event_t* event, unsigned int flags = cudaEventDisableTiming) {
+    cudaEventCreateWithFlags(event, flags);
+  }
+  static void DestroyEvent(const event_t& event) { cudaEventDestroy(event); }
-  static void RecordEvent(const event_t& event) {}
+  static void RecordEvent(const event_t& event, const stream_t& stream) {
+    cudaEventRecord(event, stream);
+  }
   static void SyncEvent(const event_t& event) {}
-  static void StreamSync(const stream_t& stream) {}
+  static void StreamSync(const stream_t& stream) {
+    cudaStreamSynchronize(stream);
+  }
+  static void StreamSync(const stream_t& stream, const event_t& event) {
+    cudaStreamWaitEvent(stream, event, 0);
+  }
+  static void DeviceSync() { cudaDeviceSynchronize(); }
   static void* Malloc(size_t size);
   static void Free(void* ptr);
......
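Note: the wrappers above now map one-to-one onto CUDA runtime calls instead of being empty stubs. A minimal sketch of how they compose for cross-stream ordering (the two-stream scenario and all variable names are illustrative assumptions, not part of this commit; `TargetWrapperCuda` is the existing alias for `TargetWrapper<TARGET(kCUDA)>`):

```cpp
// Sketch: make `consumer` wait for work queued on `producer` without
// blocking the host thread.
cudaStream_t producer, consumer;
cudaStreamCreate(&producer);
cudaStreamCreate(&consumer);
cudaEvent_t done;
TargetWrapperCuda::CreateEventWithFlags(&done);  // cudaEventDisableTiming by default
// ... enqueue kernels on `producer` ...
TargetWrapperCuda::RecordEvent(done, producer);  // snapshot producer's progress
TargetWrapperCuda::StreamSync(consumer, done);   // consumer waits on the event
// ... enqueue dependent kernels on `consumer` ...
TargetWrapperCuda::DeviceSync();                 // host-side barrier over the device
```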
......@@ -149,6 +149,9 @@ void RuntimeProgram::Run() {
 #endif  // LITE_WITH_PRECISION_PROFILE
 #endif  // LITE_WITH_PROFILE
   }
+#ifdef LITE_WITH_CUDA
+  TargetWrapperCuda::DeviceSync();
+#endif
 #ifdef LITE_WITH_PROFILE
   LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 0);
 #endif  // LITE_WITH_PROFILE
......
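The `DeviceSync()` added here pairs with the new CUDA fetch kernel below: fetch now issues an asynchronous device-to-host copy on the exec stream, so `Run()` must reach a synchronization point before callers read fetched tensors. A sketch of the hazard being guarded against (the buffer names are assumptions):

```cpp
// MemcpyAsync only enqueues the copy on `stream` and returns immediately.
TargetWrapperCuda::MemcpyAsync(
    host_buf, dev_buf, bytes, IoDirection::DtoH, stream);
// Reading host_buf at this point would race with the in-flight copy ...
TargetWrapperCuda::DeviceSync();  // ... but after a device-wide sync it is safe.
```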
......@@ -20,6 +20,7 @@ add_kernel(elementwise_compute_cuda CUDA basic SRCS elementwise_compute.cu DEPS
 add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
 add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(fetch_compute_cuda CUDA basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
 add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
 add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
......
......@@ -63,6 +63,21 @@ __global__ void ker_attention_padding_mask(T* out_data,
   }
 }
+
+template <typename Dtype>
+__global__ void ker_find_begin_data(int count,
+                                    Dtype* out,
+                                    const Dtype* src,
+                                    const Dtype pad_data,
+                                    const int offset_len) {
+  CUDA_KERNEL_LOOP(tid, count) {
+    int index = offset_len - 1;
+    const Dtype* src_data = src + offset_len * tid;
+    for (; index >= 0 && pad_data == src_data[index]; --index) {
+    }
+    out[tid] = index + 1;
+  }
+}
 
 void AttentionPaddingMaskCompute::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<CUDAContext>();
......@@ -85,34 +100,16 @@ void AttentionPaddingMaskCompute::Run() {
   auto attn_data = attn->data<float>();
   auto out_data = out->mutable_data<float>(TARGET(kCUDA));
-  std::vector<float> src_cpu(src->numel(), 0);
-  TargetWrapperCuda::MemcpyAsync(src_cpu.data(),
-                                 src->data<float>(),
-                                 sizeof(float) * src->numel(),
-                                 IoDirection::DtoH,
-                                 stream);
-  cudaStreamSynchronize(stream);
-  std::vector<float> pad_begin(src_seq_num, 0);
-  auto src_len = static_cast<int64_t>(src->lod()[0][1]);
-  int _pad_id = param.pad_id;
-  for (int i = 0; i < src_seq_num; ++i) {
-    const auto* src_data = src_cpu.data() + src_len * i;
-    int index = src_len - 1;
-    for (; index >= 0 && _pad_id == static_cast<int>(src_data[index]);
-         --index) {
-    }
-    pad_begin[i] = static_cast<float>(index + 1);
-  }
   param.pad_begin->Resize({static_cast<int64_t>(src_seq_num)});
   auto pad_begin_cuda_data =
       param.pad_begin->mutable_data<float>(TARGET(kCUDA));
-  TargetWrapperCuda::MemcpyAsync(pad_begin_cuda_data,
-                                 pad_begin.data(),
-                                 sizeof(float) * src_seq_num,
-                                 IoDirection::HtoD,
-                                 stream);
+  ker_find_begin_data<
+      float><<<CUDA_GET_BLOCKS(src_seq_num), CUDA_NUM_THREADS, 0, stream>>>(
+      src_seq_num,
+      pad_begin_cuda_data,
+      src->data<float>(),
+      static_cast<float>(param.pad_id),
+      static_cast<int>(src->lod()[0][1]));
   std::vector<int> src_offset_cpu(src_offset.size(), 0);
   for (int i = 0; i < src_offset.size(); i++) {
......
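This rewrite removes a full round trip (copy `src` to the host, scan each sequence for its trailing padding on the CPU, copy the result back) and instead computes `pad_begin` on the device, one thread per sequence via the grid-stride `CUDA_KERNEL_LOOP`. A CPU reference of what `ker_find_begin_data` computes, for illustration only:

```cpp
// For each of `count` sequences of length `offset_len`, out[tid] is one past
// the index of the last non-pad element, i.e. where trailing padding begins.
void find_begin_data_ref(int count, float* out, const float* src,
                         float pad_data, int offset_len) {
  for (int tid = 0; tid < count; ++tid) {
    const float* seq = src + static_cast<int64_t>(offset_len) * tid;
    int index = offset_len - 1;
    while (index >= 0 && seq[index] == pad_data) --index;
    out[tid] = static_cast<float>(index + 1);
  }
}
```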
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/fetch_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
void FetchCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto* fetch_list = param.fetch_list;
if (fetch_list->size() <= static_cast<size_t>(param.col)) {
fetch_list->resize(param.col + 1);
}
int num = static_cast<int>(param.input->numel());
auto& dst = fetch_list->at(param.col);
dst.Resize(param.input->dims());
auto output = dst.template mutable_data<T>();
TargetW::MemcpyAsync(output,
param.input->template data<T>(),
num * sizeof(T),
IoDirection::DtoH,
stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
typedef paddle::lite::kernels::cuda::FetchCompute<float, PRECISION(kFloat)>
FetchFp32;
REGISTER_LITE_KERNEL(fetch, kCUDA, kFloat, kNCHW, FetchFp32, nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(fetch, kCUDA, kFloat, kNHWC, FetchFp32, nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
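Because "X" is bound to kCUDA while "Out" is bound to kHost, the type-inference pass no longer needs to insert an io_copy op in front of fetch; the device-to-host transfer happens inside the kernel itself. Note that the copy is asynchronous and the kernel never synchronizes its stream; correctness relies on the `DeviceSync()` added to `RuntimeProgram::Run()` above. The caller-side effect, as a sketch (the predictor API shape here is an assumption):

```cpp
// After predictor->Run(), fetched outputs are already host tensors.
auto* out = predictor->GetOutput(0);     // lite::Tensor bound to kHost
const float* data = out->data<float>();  // plain host read; no extra io_copy
```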
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class FetchCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::FetchParam;
using TargetW = TargetWrapper<TARGET(kCUDA)>;
void Run() override;
virtual ~FetchCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -89,7 +89,6 @@ void SearchGroupPaddingCompute::Run() {
   out_new_lod.push_back(in_seq_offset);
   out_new->set_lod(out_new_lod);
   out_new->Resize({x_dims[0], 1});
-  float* out_new_data = out_new->mutable_data<float>(TARGET(kCUDA));
 
   LoD out_padding_lod;
   out_padding_lod.push_back(new_offset);
......@@ -111,12 +110,11 @@ void SearchGroupPaddingCompute::Run() {
                                  IoDirection::HtoD,
                                  cuda_stream);
-  TargetWrapperCuda::MemsetSync(
-      out_new_data, 0, out_new->dims()[0] * out_new->dims()[1] * sizeof(float));
-  TargetWrapperCuda::MemsetSync(
+  TargetWrapperCuda::MemsetAsync(
       out_padding_data,
       0,
-      out_padding->dims()[0] * out_padding->dims()[1] * sizeof(float));
+      out_padding->dims()[0] * out_padding->dims()[1] * sizeof(float),
+      cuda_stream);
   ker_search_group_padding<
       float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, cuda_stream>>>(
......
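Two changes here: the memset of `out_new` disappears along with the now-dead `out_new_data` pointer, and the remaining zero-fill becomes `MemsetAsync` on the same stream as the kernel that consumes the buffer. No extra synchronization is needed because CUDA executes work on a single stream in issue order, sketched below (`some_kernel`, `bytes`, `blocks`, and `threads` are hypothetical):

```cpp
__global__ void some_kernel(float* buf) { /* hypothetical consumer */ }

// The async zero-fill and the kernel are enqueued on the same stream, so the
// memset is guaranteed to finish before some_kernel starts touching `buf`.
TargetWrapperCuda::MemsetAsync(buf, 0, bytes, cuda_stream);
some_kernel<<<blocks, threads, 0, cuda_stream>>>(buf);
```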
......@@ -161,7 +161,6 @@ void SequencePoolCompute::Run() {
   float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
   const float* in_data = param.X->data<float>();
-  lite::Tensor seq_offset_D;
   seq_offset_D.Resize({static_cast<int64_t>(seq_offset.size())});
   TargetWrapperCuda::MemcpyAsync(
       seq_offset_D.mutable_data<uint64_t>(TARGET(kCUDA)),
......
......@@ -27,6 +27,9 @@ class SequencePoolCompute
   void Run() override;
   virtual ~SequencePoolCompute() = default;
+
+ private:
+  lite::Tensor seq_offset_D;
 };
 
 }  // namespace cuda
......
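The `lite::Tensor seq_offset_D` that used to be a local in `Run()` is promoted to a class member, so the device buffer survives across calls and repeated `Run()` invocations reuse one allocation instead of creating and destroying a temporary tensor each time. The general pattern, as a sketch (the class name and sizes are illustrative):

```cpp
class SomeCudaKernel : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
 public:
  void Run() override {
    scratch_.Resize({n_});  // cheap when the size is unchanged
    auto* buf = scratch_.mutable_data<uint64_t>(TARGET(kCUDA));
    // ... async copies / kernel launches that use buf ...
  }

 private:
  lite::Tensor scratch_;  // reused across Run() calls instead of a local
  int64_t n_{64};
};
```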