未验证 提交 57d8e42e 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

to support yolov3 unet alexnet can run on tx2 (#2216)

* add gpu kernel mul pool relu scale softmax dropout bilinear_interp and can run in tx2

* rm GREATER_EQUAL
上级 fda4d42c
......@@ -6,7 +6,7 @@ set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 62 70 75")
######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
......
......@@ -34,6 +34,14 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
)
if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0))
find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH)
set(CUBLAS_LIBRARIES ${CUBLAS_LIBRARY})
else()
set(CUBLAS_LIBRARIES ${CUDA_CUBLAS_LIBRARIES})
endif()
set(CUDNN_LIB_NAME "libcudnn.so")
if(WIN32)
......
......@@ -146,8 +146,11 @@ set(GPU_COMMON_FLAGS
-Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
-gencode arch=compute_62,code=sm_62
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
if(NOT LITE_WITH_CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
endif()
endif(NOT WIN32)
if (APPLE)
......
......@@ -507,7 +507,7 @@ function(nv_test TARGET_NAME)
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY})
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} )
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
......
......@@ -30,10 +30,8 @@ namespace cuda {
* Some basic methods.
*/
struct BlasBase {
/*
BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); }
~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); }
*/
void SetStream(cudaStream_t stream) {
CUBLAS_CHECK(cublasSetStream(handle_, stream));
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "iostream"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/utils.h"
......@@ -21,18 +22,36 @@ namespace lite {
namespace cuda {
namespace math {
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void scale_kernel(int num, const T* in, T* out, const float scale) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
#if __CUDA_ARCH__ >= 350
out[tid] = __ldg(in + tid) * scale;
#else
out[tid] = in[tid] * scale;
#endif
__global__ void scale_kernel(int count,
const T* in_data,
T* out_data,
const T* scale_data,
const T* bias_data,
const int scale_dim,
const int inner_dim) {
CUDA_KERNEL_LOOP(tid, count) {
int scale_id = (tid / inner_dim) % scale_dim;
T scale = scale_data[scale_id];
if (bias_data == nullptr) {
out_data[tid] = scale * in_data[tid];
} else {
out_data[tid] = scale * in_data[tid] + bias_data[scale_id];
}
}
}
template <typename T>
__global__ void scale_kernel(
int count, const T* in_data, T* out_data, const T scale, const T bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; }
}
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
float4* out,
......@@ -114,21 +133,25 @@ void fp32_scale_nhwc(int num,
}
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream) {
void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale);
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale, bias);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template <typename T>
void scale(int num, const T* in, T* out, float scale) {
void scale(int num, const T* in, T* out, T scale, T bias) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread>>>(num, in, out, scale);
scale_kernel<<<block, thread>>>(num, in, out, scale, bias);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template void scale(int num, const float*, float*, float, cudaStream_t);
template void scale(int num, const float*, float*, float);
template void scale(int num, const float*, float*, float, cudaStream_t, float);
template void scale(int num, const float*, float*, float, float);
} // namespace math
} // namespace cuda
......
......@@ -32,10 +32,11 @@ void fp32_scale_nhwc(int num,
cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
void scale(
int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0);
template <typename T>
void scale(int num, const T* in, T* out, float scale);
void scale(int num, const T* in, T* out, T scale, T bias = 0);
} // namespace math
} // namespace cuda
......
......@@ -54,11 +54,6 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
valid_kernels_ = op_->CreateKernels(valid_places);
}
std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
os << "Statement " << other.op_type() << " " << other.place().DebugString();
return os;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
auto &x = AsArg();
x.name = name;
......
......@@ -74,7 +74,11 @@ class Node {
KernelBase& picked_kernel();
friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
os << "Statement " << other.op_type() << " "
<< other.place().DebugString();
return os;
}
// Description.
std::string desc;
......
......@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
......@@ -16,6 +17,8 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
......@@ -24,6 +27,7 @@ lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
......@@ -31,4 +35,7 @@ nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc D
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
......@@ -111,6 +111,28 @@ REGISTER_LITE_KERNEL(
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ConvCompute,
def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
conv2d,
kCUDA,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/dropout_compute.h"
#include <string>
#include "lite/backends/cuda/math/scale.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void DropoutCompute::Run() {
auto& param = Param<operators::DropoutParam>();
const float* x_data = param.x->data<float>();
float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
int num = param.x->dims().production();
const float prob_data = param.dropout_prob;
float scale = 1.0f;
if (param.dropout_implementation == "downgrade_in_infer") {
scale = 1.0f - prob_data;
}
lite::cuda::math::scale(num, x_data, out_data, scale, 0);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(dropout,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::DropoutCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class DropoutCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~DropoutCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/dropout_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename dtype>
void dropout_compute_ref(const operators::DropoutParam& param) {
const float* x_data = param.x->data<float>();
float* output_data = param.output->mutable_data<float>();
int num = param.x->dims().production();
const float prob_data = param.dropout_prob;
if (param.dropout_implementation.compare(
std::string({"downgrade_in_infer"})) == 0) {
float scale = 1.0 - prob_data;
for (int i = 0; i < num; i++) {
output_data[i] = x_data[i] * scale;
}
} else {
for (int i = 0; i < num; i++) {
output_data[i] = x_data[i];
}
}
}
TEST(dropout_cuda, normal) {
DropoutCompute dropout_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::DropoutParam param;
lite::Tensor x;
lite::Tensor x_cpu;
lite::Tensor x_ref;
lite::Tensor output;
lite::Tensor output_cpu;
lite::Tensor output_ref;
for (auto n : {1, 3, 4}) {
for (auto c : {1, 3, 4, 256}) {
for (auto h : {1, 3, 4, 6}) {
for (auto w : {1, 3, 4, 6}) {
for (auto prob : {0.2f, 0.8f})
for (auto impl : {std::string({"downgrade_in_infer"})}) {
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
x_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* output_data = output.mutable_data<float>(TARGET(kCUDA));
auto* output_cpu_data = output_cpu.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_cpu_data[i] = i;
x_ref_data[i] = i;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data,
x_cpu.dims());
param.x = &x;
param.output = &output;
param.dropout_prob = prob;
param.dropout_implementation = impl;
dropout_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
dropout_kernel.SetContext(std::move(ctx));
dropout_kernel.Launch();
CopySync<TARGET(kCUDA)>(output_cpu_data,
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
param.x = &x_ref;
param.output = &output_ref;
dropout_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -29,12 +29,7 @@ void ElementwiseAddCompute::Run() {
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
const int n = x->dims()[0];
const int c = x->dims()[1];
const int h = x->dims()[2];
const int w = x->dims()[3];
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
......@@ -57,12 +52,7 @@ void ElementwiseAddComputeNHWC::Run() {
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
const int n = x->dims()[0];
const int c = x->dims()[1];
const int h = x->dims()[2];
const int w = x->dims()[3];
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
......@@ -85,7 +75,7 @@ void ElementwiseAddComputeInt8::Run() {
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
CHECK(x->dims().production() == y->dims().production());
const int c = x->dims()[3];
......
......@@ -33,19 +33,36 @@ void mul_compute(const lite::cuda::Blas<float>& blas,
int y_h,
int y_w,
T* out) {
float alpha = 1.0;
float beta = 0.0;
/*
blas.sgemm(CUBLAS_OP_N,
CUBLAS_OP_N,
x_h,
y_w,
x_w,
nullptr,
&alpha,
x,
x_w,
y,
y_w,
nullptr,
&beta,
out,
x_h);
*/
blas.sgemm(CUBLAS_OP_N,
CUBLAS_OP_N,
y_w,
x_h,
y_h,
&alpha,
y,
y_w,
x,
x_w,
&beta,
out,
y_w);
}
class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
......@@ -56,23 +73,29 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
CHECK(ctx_) << "running context should be set first";
auto& context = this->ctx_->template As<CUDAContext>();
CHECK(context.cublas_fp32()) << "blas should init first";
/*
auto& blas = *context.cublas_fp32();
CHECK(param.x->target() == TARGET(kCUDA));
auto* x = param.x->data<float>();
int x_h = param.x->dims()[0];
int x_w = param.x->dims()[1];
auto* y = param.y->data<float>();
int y_h = param.y->dims()[0];
int y_w = param.y->dims()[1];
*/
auto& param = this->Param<param_t>();
const auto* x_data = param.x->data<float>();
const auto* y_data = param.y->data<float>();
auto* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
const auto& param = Param<operators::MulParam>();
param.output->mutable_data<float>(TARGET(kCUDA));
LOG(INFO) << "mul output memory size " << param.output->data_size();
int x_h = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w = static_cast<int>(
param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production());
int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production());
int y_w = static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
LOG(INFO) << x_h << " " << x_w << " " << y_h << " " << y_w;
// mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out);
mul_compute<float>(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data);
}
virtual ~MulCompute() = default;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/mul_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(mul_compute, normal) {
MulCompute mul_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
Tensor x, y, out, x_cpu, y_cpu, out_cpu;
int x_h = 2, x_w_y_h = 3, y_w = 4;
out.Resize({x_h, y_w});
x_cpu.Resize({x_h, x_w_y_h});
y_cpu.Resize({x_w_y_h, y_w});
out_cpu.Resize({x_h, y_w});
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
float* out_cpu_data = out_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i + 1.0;
}
for (int i = 0; i < y_cpu.numel(); i++) {
y_cpu_data[i] = i + 1.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
operators::MulParam param;
param.x = &x;
param.y = &y;
param.output = &out;
mul_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
mul_kernel.SetContext(std::move(ctx));
mul_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
for (int i = 0; i < out_cpu.numel(); i++) {
LOG(INFO) << out_cpu_data[i];
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/relu_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void ReluKernel(const int num, const T* input, T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) : 0;
#else
output[index] = input[index] >= 0 ? input[index] : 0;
#endif
}
}
void ReluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
int threads = 1024;
int blocks = (num + threads - 1) / threads;
ReluKernel<<<blocks, threads, 0, stream>>>(num, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
relu, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ReluCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ReluCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~ReluCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/relu_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(relu, normal) {
ReluCompute relu_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 256, w = 256;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
relu_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
relu_kernel.SetContext(std::move(ctx));
relu_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
// for (int i = 0; i < y.numel(); i++) {
// LOG(INFO) << y_cpu_data[i];
// }
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/kernels/cuda/scale_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void ScaleCompute::Run() {
auto& param = Param<operators::ScaleParam>();
const float* x_data = param.x->data<float>();
float* output_data = param.output->mutable_data<float>();
DDim x_dims = param.x->dims();
bool bias_after_scale = param.bias_after_scale;
float scale = param.scale;
float bias = param.bias;
if (!bias_after_scale) {
bias *= scale;
}
lite::cuda::math::scale(
x_dims.production(), x_data, output_data, scale, bias);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
scale, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ScaleCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/cuda/math/scale.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ScaleCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ScaleCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册