Unverified commit da328594, authored by Zhen Wang, committed by GitHub

add calib cuda kernel. (#1977)

* add calib cuda kernel.

* add unit test for calib cuda kernel. test=develop
Parent 7014a76b
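Summary: this change adds CUDA kernels for the calib op, which converts tensors between fp32 and int8 using a single per-tensor scale (symmetric linear quantization), plus a unit test. A minimal CPU sketch of the math the kernels implement (the names here are illustrative, not taken from the diff):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize: q = round(clamp(x / scale, -128, 127)); dequantize: x = q * scale.
// Note: the CUDA kernel below rounds ties to even via __float2int_rn, while
// std::lround rounds ties away from zero, so this sketch can differ on exact
// .5 values.
inline int8_t QuantizeToInt8(float x, float scale) {
  float v = x / scale;
  v = std::min(std::max(v, -128.0f), 127.0f);  // clamp to the int8 range
  return static_cast<int8_t>(std::lround(v));  // round to nearest integer
}

inline float DequantizeToFp32(int8_t q, float scale) { return q * scale; }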
@@ -105,6 +105,7 @@ KernelRegistry::KernelRegistry()
                    DATALAYOUT(layout__)>::Global());
  // Currently, just register 2 kernel targets.
  INIT_FOR(kCUDA, kFloat, kNCHW);
  INIT_FOR(kCUDA, kInt8, kNCHW);
  INIT_FOR(kCUDA, kAny, kNCHW);
  INIT_FOR(kCUDA, kAny, kAny);
  INIT_FOR(kCUDA, kInt8, kNHWC);
@@ -17,6 +17,9 @@ nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEP
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_library(calib_compute_cuda SRCS calib_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)

set(cuda_kernels
    conv2d_cuda
    mul_compute_cuda
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/cuda/calib_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
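// Clamp x to the int8 range [-128, 127], then round to the nearest integer.
// __float2int_rn rounds ties to the nearest even integer.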
__device__ __forceinline__ int8_t float2int8(float x) {
  x = fmaxf(x, INT8_MIN);
  x = fminf(x, INT8_MAX);
  return __float2int_rn(x);
}

__global__ void Fp32ToInt8Kernel(const int num,
                                 const float scale,
                                 const float* input,
                                 int8_t* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) {
    output[index] = float2int8(input[index] / scale);
  }
}

__global__ void Int8ToFp32Kernel(const int num,
                                 const float scale,
                                 const int8_t* input,
                                 float* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) {
    output[index] = input[index] * scale;
  }
}
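// Host-side entry points: each launches one thread per element (1024 threads
// per block) on the CUDA context's exec stream, then checks the launch result.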
void CalibComputeFp32ToInt8::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->As<CUDAContext>();
  auto stream = ctx.exec_stream();

  auto scale = param.scale;
  const auto* din = param.input->data<float>();
  auto* dout = param.output->mutable_data<int8_t>(TARGET(kCUDA));

  int num = static_cast<int>(param.input->numel());
  int threads = 1024;
  int blocks = (num + threads - 1) / threads;
  Fp32ToInt8Kernel<<<blocks, threads, 0, stream>>>(num, scale, din, dout);
  cudaError_t error = cudaGetLastError();
  CHECK(error == cudaSuccess) << cudaGetErrorString(error);
}

void CalibComputeInt8ToFp32::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->As<CUDAContext>();
  auto stream = ctx.exec_stream();

  auto scale = param.scale;
  const auto* din = param.input->data<int8_t>();
  auto* dout = param.output->mutable_data<float>(TARGET(kCUDA));

  int num = static_cast<int>(param.input->numel());
  int threads = 1024;
  int blocks = (num + threads - 1) / threads;
  Int8ToFp32Kernel<<<blocks, threads, 0, stream>>>(num, scale, din, dout);
  cudaError_t error = cudaGetLastError();
  CHECK(error == cudaSuccess) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
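Four registrations follow: the calib and calib_once ops each get an fp32_to_int8 and an int8_to_fp32 kernel, all keyed on kCUDA/kInt8/kNCHW. The diff does not explain calib_once; judging by the name, it is presumably the variant whose conversion is performed once and reused (for example, for weights), sharing the same compute classes.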
REGISTER_LITE_KERNEL(calib,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::CalibComputeFp32ToInt8,
                     fp32_to_int8)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
    .Finalize();

REGISTER_LITE_KERNEL(calib,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::CalibComputeInt8ToFp32,
                     int8_to_fp32)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .Finalize();

REGISTER_LITE_KERNEL(calib_once,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::CalibComputeFp32ToInt8,
                     fp32_to_int8)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
    .Finalize();

REGISTER_LITE_KERNEL(calib_once,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::CalibComputeInt8ToFp32,
                     int8_to_fp32)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/kernel.h"
#include "lite/operators/calib_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class CalibComputeFp32ToInt8
    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
 public:
  using param_t = operators::CalibParam;

  void Run() override;

  virtual ~CalibComputeFp32ToInt8() = default;

  std::string doc() const override { return "Fp32 --> Int8"; }
};

class CalibComputeInt8ToFp32
    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
 public:
  using param_t = operators::CalibParam;

  void Run() override;

  virtual ~CalibComputeInt8ToFp32() = default;

  std::string doc() const override { return "Int8 --> Fp32"; }
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
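Note that both classes declare PRECISION(kInt8) even though CalibComputeFp32ToInt8 consumes float input; the precision in the kernel key appears to describe the quantized side of the conversion, while the actual tensor types are pinned by the BindInput/BindOutput declarations in the registrations above.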
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
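// CPU reference implementations, used below to validate the CUDA kernels.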
static void int8_to_fp32_basic(const int8_t* din,
                               float* dout,
                               const float scale,
                               int num) {
  for (int j = 0; j < num; ++j) {
    dout[j] = din[j] * scale;
  }
}

static void fp32_to_int8_basic(const float* din,
                               int8_t* dout,
                               const float scale,
                               int num) {
  for (int j = 0; j < num; ++j) {
    auto v = din[j] / scale;
    v = std::max(v, static_cast<float>(INT8_MIN));
    v = std::min(v, static_cast<float>(INT8_MAX));
    v = roundf(v);
    dout[j] = static_cast<int8_t>(v);
  }
}
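// Note: roundf rounds halfway cases away from zero, while the CUDA kernel's
// __float2int_rn rounds them to nearest even, so the reference and the kernel
// could disagree on exact .5 ties; the test inputs below avoid such ties.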
void calib_ref(const operators::CalibParam& param, bool to_float = true) {
  auto scale = param.scale;
  if (to_float) {
    const auto* din = param.input->data<int8_t>();
    auto* dout = param.output->mutable_data<float>();
    int8_to_fp32_basic(din, dout, scale, param.input->numel());
  } else {
    const auto* din = param.input->data<float>();
    auto* dout = param.output->mutable_data<int8_t>();
    fp32_to_int8_basic(din, dout, scale, param.input->numel());
  }
}
TEST(calib_cuda, int8_to_fp32) {
  LOG(INFO) << "to get kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
  ASSERT_FALSE(kernels.empty());
  auto calib = std::move(*std::next(kernels.begin(), 1));
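  // Create() returns every kernel registered under this key; index 1 is
  // expected to be the int8_to_fp32 variant (the fp32_to_int8 test below
  // takes front()), which relies on registration order.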
LOG(INFO) << "get kernel: " << calib->doc();
const int n = 64, c = 32, h = 18, w = 18;
Tensor x;
Tensor x_cpu;
Tensor output;
Tensor output_cpu;
// set the dims of input, output tensors
x.Resize({n, c, h, w});
x_cpu.Resize({n, c, h, w});
output.Resize({n, c, h, w});
output_cpu.Resize({n, c, h, w});
// initialize the data of input tensors
auto* x_data = x.mutable_data<int8_t>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
for (int i = 0; i < x.dims().production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_cpu_data[i] = static_cast<int8_t>(sign * (i % 127));
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
// prepare kernel params and run
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
calib->SetContext(std::move(ctx));
operators::CalibParam param;
param.scale = 0.013f;
param.input = &x;
param.output = &output;
calib->SetParam(param);
calib->Launch();
cudaDeviceSynchronize();
// invoking ref implementation and compare results
param.input = &x_cpu;
param.output = &output_cpu;
calib_ref(param);
auto* output_data = output.mutable_data<float>();
std::unique_ptr<float[]> output_gpu_copy(new float[output.numel()]);
CopySync<TARGET(kCUDA)>(output_gpu_copy.get(),
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
const auto* output_cpu_data = output_cpu.data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_gpu_copy[i], output_cpu_data[i], 1e-5);
}
}
TEST(calib_cuda, fp32_to_int8) {
  LOG(INFO) << "to get kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
  ASSERT_FALSE(kernels.empty());
  auto calib = std::move(kernels.front());
  LOG(INFO) << "get kernel: " << calib->doc();
  const int n = 64, c = 32, h = 18, w = 18;
  Tensor x;
  Tensor x_cpu;
  Tensor output;
  Tensor output_cpu;
  // set the dims of the input and output tensors
  x.Resize({n, c, h, w});
  x_cpu.Resize({n, c, h, w});
  output.Resize({n, c, h, w});
  output_cpu.Resize({n, c, h, w});
  // initialize the data of the input tensors
  auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
  auto* x_cpu_data = x_cpu.mutable_data<float>();
  for (int i = 0; i < x.dims().production(); i++) {
    float sign = i % 3 == 0 ? -1.0f : 1.0f;
    x_cpu_data[i] = sign * (i % 127) * 0.013f;
  }
  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
  // prepare kernel params and run
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  context.SetExecStream(stream);
  calib->SetContext(std::move(ctx));
  operators::CalibParam param;
  param.scale = 0.013f;
  param.input = &x;
  param.output = &output;
  calib->SetParam(param);
  calib->Launch();
  cudaDeviceSynchronize();
  // invoke the reference implementation and compare results
  param.input = &x_cpu;
  param.output = &output_cpu;
  calib_ref(param, false);
  auto* output_data = output.mutable_data<int8_t>();
  std::unique_ptr<int8_t[]> output_gpu_copy(new int8_t[output.numel()]);
  CopySync<TARGET(kCUDA)>(output_gpu_copy.get(),
                          output_data,
                          sizeof(int8_t) * output.numel(),
                          IoDirection::DtoH);
  const auto* output_cpu_data = output_cpu.data<int8_t>();
  for (int i = 0; i < output.dims().production(); i++) {
    EXPECT_EQ(output_gpu_copy[i], output_cpu_data[i]);
  }
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(calib, kCUDA, kInt8, kNCHW, int8_to_fp32);
USE_LITE_KERNEL(calib, kCUDA, kInt8, kNCHW, fp32_to_int8);
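The USE_LITE_KERNEL declarations pull the registration symbols for both calib kernel variants into the test binary, so that the KernelRegistry::Global().Create calls above can find them at run time.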