diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 1e6f34a62129e2ca0a717ceb489d98b56b78d47a..26e8d28fcd430e7642e07c4375f90c67c10cbaba 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -6,7 +6,7 @@ set(paddle_known_gpu_archs "30 35 50 52 60 61 70") set(paddle_known_gpu_archs7 "30 35 50 52") set(paddle_known_gpu_archs8 "30 35 50 52 60 61") set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") -set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") +set(paddle_known_gpu_archs10 "30 35 50 52 60 61 62 70 75") ###################################################################################### # A function for automatic detection of GPUs installed (if autodetection is enabled) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 3775d6cc2bdaa617f225b4cff9a03092bd9a19cc..a8a4d34fe5f01a4f0fedf2b9c2a09d7c4383bd25 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -34,6 +34,14 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 ) + +if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0)) + find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH) + set(CUBLAS_LIBRARIES ${CUBLAS_LIBRARY}) +else() + set(CUBLAS_LIBRARIES ${CUDA_CUBLAS_LIBRARIES}) +endif() + set(CUDNN_LIB_NAME "libcudnn.so") if(WIN32) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 36b533aa4f7815896fb48c33fefad892b8d0d29c..903c70fbbff285bc90697281f9703b544fd00186 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -146,8 +146,11 @@ set(GPU_COMMON_FLAGS -Wno-error=unused-local-typedefs -Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=array-bounds # Warnings in Eigen::array + -gencode arch=compute_62,code=sm_62 ) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +if(NOT LITE_WITH_CUDA) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +endif() endif(NOT WIN32) if (APPLE) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index f1b73f1d882afc3a2352514906e5958c70eee542..415eb451a986cd7e59829b9a8f2c744ecf464bd6 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -507,7 +507,7 @@ function(nv_test TARGET_NAME) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest -gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY}) +gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} ) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog) common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) diff --git a/lite/backends/cuda/blas.h b/lite/backends/cuda/blas.h index f73bb576b8dd5ecad178ba69a9208b2286c050ab..058b961f3678197a3a6719a3337e0decac78564f 100644 --- a/lite/backends/cuda/blas.h +++ b/lite/backends/cuda/blas.h @@ -30,10 +30,8 @@ namespace cuda { * Some basic methods. */ struct BlasBase { - /* BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); } ~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); } - */ void SetStream(cudaStream_t stream) { CUBLAS_CHECK(cublasSetStream(handle_, stream)); diff --git a/lite/backends/cuda/math/scale.cu b/lite/backends/cuda/math/scale.cu index 9ab8f91779ebcc5259a99cf5a415ce13d4cfcebb..806a3697a2eb19354a81056f0a7ab6272ed991a1 100644 --- a/lite/backends/cuda/math/scale.cu +++ b/lite/backends/cuda/math/scale.cu @@ -13,6 +13,7 @@ // limitations under the License. #include "iostream" +#include "lite/backends/cuda/cuda_utils.h" #include "lite/backends/cuda/math/scale.h" #include "lite/backends/cuda/math/utils.h" @@ -21,18 +22,36 @@ namespace lite { namespace cuda { namespace math { +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + template -__global__ void scale_kernel(int num, const T* in, T* out, const float scale) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < num) { -#if __CUDA_ARCH__ >= 350 - out[tid] = __ldg(in + tid) * scale; -#else - out[tid] = in[tid] * scale; -#endif +__global__ void scale_kernel(int count, + const T* in_data, + T* out_data, + const T* scale_data, + const T* bias_data, + const int scale_dim, + const int inner_dim) { + CUDA_KERNEL_LOOP(tid, count) { + int scale_id = (tid / inner_dim) % scale_dim; + T scale = scale_data[scale_id]; + if (bias_data == nullptr) { + out_data[tid] = scale * in_data[tid]; + } else { + out_data[tid] = scale * in_data[tid] + bias_data[scale_id]; + } } } +template +__global__ void scale_kernel( + int count, const T* in_data, T* out_data, const T scale, const T bias) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; } +} + __global__ void fp32_scale_nhwc4_kernel(int num, const float4* in, float4* out, @@ -114,21 +133,25 @@ void fp32_scale_nhwc(int num, } template -void scale(int num, const T* in, T* out, float scale, cudaStream_t stream) { +void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) { int thread = 256; int block = (num + thread - 1) / thread; - scale_kernel<<>>(num, in, out, scale); + scale_kernel<<>>(num, in, out, scale, bias); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); } template -void scale(int num, const T* in, T* out, float scale) { +void scale(int num, const T* in, T* out, T scale, T bias) { int thread = 256; int block = (num + thread - 1) / thread; - scale_kernel<<>>(num, in, out, scale); + scale_kernel<<>>(num, in, out, scale, bias); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); } -template void scale(int num, const float*, float*, float, cudaStream_t); -template void scale(int num, const float*, float*, float); +template void scale(int num, const float*, float*, float, cudaStream_t, float); +template void scale(int num, const float*, float*, float, float); } // namespace math } // namespace cuda diff --git a/lite/backends/cuda/math/scale.h b/lite/backends/cuda/math/scale.h index 83af600ba8c68a236fdb2a5c9f8521199b46f633..52ed1d38ae79ce11cac50a9abef0f57e6de1352c 100644 --- a/lite/backends/cuda/math/scale.h +++ b/lite/backends/cuda/math/scale.h @@ -32,10 +32,11 @@ void fp32_scale_nhwc(int num, cudaStream_t stream); template -void scale(int num, const T* in, T* out, float scale, cudaStream_t stream); +void scale( + int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0); template -void scale(int num, const T* in, T* out, float scale); +void scale(int num, const T* in, T* out, T scale, T bias = 0); } // namespace math } // namespace cuda diff --git a/lite/core/mir/node.cc b/lite/core/mir/node.cc index 61d3d317e7b7bbbfc4064cfbe0f2503f8fbe7a31..4a90e530a46c4d42d2ba032da1828973dfc1bcef 100644 --- a/lite/core/mir/node.cc +++ b/lite/core/mir/node.cc @@ -54,11 +54,6 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, valid_kernels_ = op_->CreateKernels(valid_places); } -std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) { - os << "Statement " << other.op_type() << " " << other.place().DebugString(); - return os; -} - mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) { auto &x = AsArg(); x.name = name; diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h index 9c7d441ca3811d39b8ba9f5b49746c9a31c1d449..60fa1fb1ebe49e1be38a7d84cb82545389ea4aac 100644 --- a/lite/core/mir/node.h +++ b/lite/core/mir/node.h @@ -74,7 +74,11 @@ class Node { KernelBase& picked_kernel(); - friend std::ostream& operator<<(std::ostream& os, const Stmt& other); + friend std::ostream& operator<<(std::ostream& os, const Stmt& other) { + os << "Statement " << other.op_type() << " " + << other.place().DebugString(); + return os; + } // Description. std::string desc; diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index dd5676e6430069297cdd3527900bce69c59f3dfb..67f55881ce4010d1179d9b6013aa560c56dd949e 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels") add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps}) add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose) add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps}) @@ -16,6 +17,8 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute. add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps}) add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose) add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale) +add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale) add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) @@ -24,6 +27,7 @@ lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_ nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) +nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda) nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda) @@ -31,4 +35,7 @@ nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc D nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda) nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda) #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda) +nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda) +nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda ) +nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda ) nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) diff --git a/lite/kernels/cuda/conv_compute.cc b/lite/kernels/cuda/conv_compute.cc index cee7d9593c576365d62f271d42c0bc4ad356650f..eea81602ddf94158250aecf01fe5e95193bf58c1 100644 --- a/lite/kernels/cuda/conv_compute.cc +++ b/lite/kernels/cuda/conv_compute.cc @@ -111,6 +111,28 @@ REGISTER_LITE_KERNEL( DATALAYOUT(kNCHW))}) .Finalize(); +REGISTER_LITE_KERNEL(depthwise_conv2d, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ConvCompute, + def) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("Bias", + {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + REGISTER_LITE_KERNEL( conv2d, kCUDA, diff --git a/lite/kernels/cuda/dropout_compute.cc b/lite/kernels/cuda/dropout_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e3a3a62432f3bc5f2e62112b2b220abc17ee2bd --- /dev/null +++ b/lite/kernels/cuda/dropout_compute.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/dropout_compute.h" +#include +#include "lite/backends/cuda/math/scale.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void DropoutCompute::Run() { + auto& param = Param(); + const float* x_data = param.x->data(); + float* out_data = param.output->mutable_data(TARGET(kCUDA)); + int num = param.x->dims().production(); + const float prob_data = param.dropout_prob; + float scale = 1.0f; + if (param.dropout_implementation == "downgrade_in_infer") { + scale = 1.0f - prob_data; + } + lite::cuda::math::scale(num, x_data, out_data, scale, 0); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(dropout, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::DropoutCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/dropout_compute.h b/lite/kernels/cuda/dropout_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..aec0d1966bc8b368484b5c810f133a8e9a6fb410 --- /dev/null +++ b/lite/kernels/cuda/dropout_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class DropoutCompute : public KernelLite { + public: + void Run() override; + + virtual ~DropoutCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/dropout_compute_test.cc b/lite/kernels/cuda/dropout_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e6ed54330c0a109091934ebe48ed341afcae96f9 --- /dev/null +++ b/lite/kernels/cuda/dropout_compute_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/dropout_compute.h" +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void dropout_compute_ref(const operators::DropoutParam& param) { + const float* x_data = param.x->data(); + float* output_data = param.output->mutable_data(); + int num = param.x->dims().production(); + const float prob_data = param.dropout_prob; + if (param.dropout_implementation.compare( + std::string({"downgrade_in_infer"})) == 0) { + float scale = 1.0 - prob_data; + for (int i = 0; i < num; i++) { + output_data[i] = x_data[i] * scale; + } + } else { + for (int i = 0; i < num; i++) { + output_data[i] = x_data[i]; + } + } +} + +TEST(dropout_cuda, normal) { + DropoutCompute dropout_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::DropoutParam param; + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor x_ref; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + + for (auto n : {1, 3, 4}) { + for (auto c : {1, 3, 4, 256}) { + for (auto h : {1, 3, 4, 6}) { + for (auto w : {1, 3, 4, 6}) { + for (auto prob : {0.2f, 0.8f}) + for (auto impl : {std::string({"downgrade_in_infer"})}) { + x.Resize(DDim(std::vector({n, c, h, w}))); + x_cpu.Resize(DDim(std::vector({n, c, h, w}))); + x_ref.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, h, w}))); + output_cpu.Resize(DDim(std::vector({n, c, h, w}))); + output_ref.Resize(DDim(std::vector({n, c, h, w}))); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* x_ref_data = x_ref.mutable_data(); + auto* output_data = output.mutable_data(TARGET(kCUDA)); + auto* output_cpu_data = output_cpu.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + + for (int i = 0; i < x.dims().production(); i++) { + x_cpu_data[i] = i; + x_ref_data[i] = i; + } + + x.Assign(x_cpu_data, + x_cpu.dims()); + + param.x = &x; + param.output = &output; + param.dropout_prob = prob; + param.dropout_implementation = impl; + dropout_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + dropout_kernel.SetContext(std::move(ctx)); + dropout_kernel.Launch(); + + CopySync(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + + param.x = &x_ref; + param.output = &output_ref; + dropout_compute_ref(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_add_compute.cu b/lite/kernels/cuda/elementwise_add_compute.cu index 390dacc7bcc96b65f9de5c4bb4c2aaea090e977e..4bacf532a2b67168679449200b1af721b7a282c8 100644 --- a/lite/kernels/cuda/elementwise_add_compute.cu +++ b/lite/kernels/cuda/elementwise_add_compute.cu @@ -29,12 +29,7 @@ void ElementwiseAddCompute::Run() { const lite::Tensor* y = param.Y; lite::Tensor* out = param.Out; - CHECK(x->dims() == y->dims()); - - const int n = x->dims()[0]; - const int c = x->dims()[1]; - const int h = x->dims()[2]; - const int w = x->dims()[3]; + CHECK(x->dims().production() == y->dims().production()); auto* x_data = x->data(); auto* y_data = y->data(); @@ -57,12 +52,7 @@ void ElementwiseAddComputeNHWC::Run() { const lite::Tensor* y = param.Y; lite::Tensor* out = param.Out; - CHECK(x->dims() == y->dims()); - - const int n = x->dims()[0]; - const int c = x->dims()[1]; - const int h = x->dims()[2]; - const int w = x->dims()[3]; + CHECK(x->dims().production() == y->dims().production()); auto* x_data = x->data(); auto* y_data = y->data(); @@ -85,7 +75,7 @@ void ElementwiseAddComputeInt8::Run() { const lite::Tensor* y = param.Y; lite::Tensor* out = param.Out; - CHECK(x->dims() == y->dims()); + CHECK(x->dims().production() == y->dims().production()); const int c = x->dims()[3]; diff --git a/lite/kernels/cuda/mul_compute.h b/lite/kernels/cuda/mul_compute.h index 4a542104d6743b52758cbecfb11c025628e46333..c2fc4364ef77742858b143734d2ecf4d13e201e9 100644 --- a/lite/kernels/cuda/mul_compute.h +++ b/lite/kernels/cuda/mul_compute.h @@ -33,19 +33,36 @@ void mul_compute(const lite::cuda::Blas& blas, int y_h, int y_w, T* out) { + float alpha = 1.0; + float beta = 0.0; + /* blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N, x_h, y_w, x_w, - nullptr, + &alpha, x, x_w, y, y_w, - nullptr, + &beta, out, x_h); + */ + blas.sgemm(CUBLAS_OP_N, + CUBLAS_OP_N, + y_w, + x_h, + y_h, + &alpha, + y, + y_w, + x, + x_w, + &beta, + out, + y_w); } class MulCompute : public KernelLite { @@ -56,23 +73,29 @@ class MulCompute : public KernelLite { CHECK(ctx_) << "running context should be set first"; auto& context = this->ctx_->template As(); CHECK(context.cublas_fp32()) << "blas should init first"; - /* auto& blas = *context.cublas_fp32(); - CHECK(param.x->target() == TARGET(kCUDA)); - auto* x = param.x->data(); - int x_h = param.x->dims()[0]; - int x_w = param.x->dims()[1]; - auto* y = param.y->data(); - int y_h = param.y->dims()[0]; - int y_w = param.y->dims()[1]; - */ + auto& param = this->Param(); + const auto* x_data = param.x->data(); + const auto* y_data = param.y->data(); + auto* out_data = param.output->mutable_data(TARGET(kCUDA)); - const auto& param = Param(); - param.output->mutable_data(TARGET(kCUDA)); - LOG(INFO) << "mul output memory size " << param.output->data_size(); + int x_h = static_cast( + param.x->dims().Slice(0, param.x_num_col_dims).production()); + int x_w = static_cast( + param.x->dims() + .Slice(param.x_num_col_dims, param.x->dims().size()) + .production()); + int y_h = static_cast( + param.y->dims().Slice(0, param.y_num_col_dims).production()); + int y_w = static_cast( + param.y->dims() + .Slice(param.y_num_col_dims, param.y->dims().size()) + .production()); + CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; + LOG(INFO) << x_h << " " << x_w << " " << y_h << " " << y_w; - // mul_compute(blas, x, x_h, x_w, y, y_h, y_w, out); + mul_compute(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data); } virtual ~MulCompute() = default; diff --git a/lite/kernels/cuda/mul_compute_test.cc b/lite/kernels/cuda/mul_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1c1d63e7dcd46f84cd128fc5b855da2098e179d --- /dev/null +++ b/lite/kernels/cuda/mul_compute_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/mul_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(mul_compute, normal) { + MulCompute mul_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + Tensor x, y, out, x_cpu, y_cpu, out_cpu; + int x_h = 2, x_w_y_h = 3, y_w = 4; + out.Resize({x_h, y_w}); + x_cpu.Resize({x_h, x_w_y_h}); + y_cpu.Resize({x_w_y_h, y_w}); + out_cpu.Resize({x_h, y_w}); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + float* out_cpu_data = out_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i + 1.0; + } + for (int i = 0; i < y_cpu.numel(); i++) { + y_cpu_data[i] = i + 1.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + y.Assign(y_cpu_data, y_cpu.dims()); + + operators::MulParam param; + param.x = &x; + param.y = &y; + param.output = &out; + mul_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + mul_kernel.SetContext(std::move(ctx)); + mul_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + for (int i = 0; i < out_cpu.numel(); i++) { + LOG(INFO) << out_cpu_data[i]; + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/relu_compute.cu b/lite/kernels/cuda/relu_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..7c6623a4fe3bc68408a90c7ed2a2e9e35d7061fb --- /dev/null +++ b/lite/kernels/cuda/relu_compute.cu @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/relu_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void ReluKernel(const int num, const T* input, T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { +#if __CUDA_ARCH__ >= 350 + output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) : 0; +#else + output[index] = input[index] >= 0 ? input[index] : 0; +#endif + } +} + +void ReluCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + int num = static_cast(param.X->numel()); + auto input = param.X->data(); + auto output = param.Out->mutable_data(TARGET(kCUDA)); + + int threads = 1024; + int blocks = (num + threads - 1) / threads; + ReluKernel<<>>(num, input, output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + relu, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ReluCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/relu_compute.h b/lite/kernels/cuda/relu_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..b0fd500ff4369fc4a4ca256153aa5f0d21cf1e8e --- /dev/null +++ b/lite/kernels/cuda/relu_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ReluCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~ReluCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/relu_compute_test.cc b/lite/kernels/cuda/relu_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..39144bfda13a9eac4ac7ad65d3d426d528fc2beb --- /dev/null +++ b/lite/kernels/cuda/relu_compute_test.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/relu_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(relu, normal) { + ReluCompute relu_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam param; + + Tensor x, y, x_cpu, y_cpu; + int h = 256, w = 256; + y.Resize({h, w}); + x_cpu.Resize({h, w}); + y_cpu.Resize({h, w}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = x_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i - 5.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + relu_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + relu_kernel.SetContext(std::move(ctx)); + relu_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + // for (int i = 0; i < y.numel(); i++) { + // LOG(INFO) << y_cpu_data[i]; + // } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/scale_compute.cc b/lite/kernels/cuda/scale_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bf7414d8c85383a834159678cdd5a09e0b434d9 --- /dev/null +++ b/lite/kernels/cuda/scale_compute.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/kernels/cuda/scale_compute.h" +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void ScaleCompute::Run() { + auto& param = Param(); + const float* x_data = param.x->data(); + float* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; + } + lite::cuda::math::scale( + x_dims.production(), x_data, output_data, scale, bias); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + scale, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ScaleCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/scale_compute.h b/lite/kernels/cuda/scale_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..0dd082122a7e16762c790c8f360e2e0d7939496c --- /dev/null +++ b/lite/kernels/cuda/scale_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/backends/cuda/math/scale.h" +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ScaleCompute : public KernelLite { + public: + void Run() override; + + virtual ~ScaleCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle