From bb18512523e05c37dfa76747c4405c22132824e0 Mon Sep 17 00:00:00 2001
From: superjomn
Date: Fri, 12 Apr 2019 21:44:41 +0800
Subject: [PATCH] add cublas

---
 CMakeLists.txt                                |  3 +
 cmake/configure.cmake                         | 10 +++
 paddle/fluid/lite/api/CMakeLists.txt          |  2 +-
 paddle/fluid/lite/api/cxx_api.cc              |  8 +-
 paddle/fluid/lite/api/cxx_api.h               | 43 +++++++++-
 paddle/fluid/lite/api/cxx_api_test.cc         | 13 +++-
 paddle/fluid/lite/core/context.h              | 41 ++++++++++
 paddle/fluid/lite/core/executor.h             |  5 +-
 paddle/fluid/lite/core/executor_test.cc       |  2 +-
 paddle/fluid/lite/core/kernel.h               | 12 ++-
 paddle/fluid/lite/cuda/CMakeLists.txt         |  1 +
 paddle/fluid/lite/cuda/blas.cc                | 39 ++++++++++
 paddle/fluid/lite/cuda/blas.h                 | 78 +++++++++++++++++++
 paddle/fluid/lite/cuda/cuda_utils.h           | 76 ++++++++++++++++++
 paddle/fluid/lite/kernels/CMakeLists.txt      |  1 +
 paddle/fluid/lite/kernels/cuda/CMakeLists.txt |  1 +
 paddle/fluid/lite/kernels/cuda/mul_compute.cc | 15 ++++
 paddle/fluid/lite/kernels/cuda/mul_compute.h  | 44 +++++++++++
 paddle/fluid/lite/utils/macros.h              |  2 +
 19 files changed, 383 insertions(+), 13 deletions(-)
 create mode 100644 paddle/fluid/lite/cuda/blas.cc
 create mode 100644 paddle/fluid/lite/cuda/blas.h
 create mode 100644 paddle/fluid/lite/cuda/cuda_utils.h
 create mode 100644 paddle/fluid/lite/kernels/cuda/CMakeLists.txt
 create mode 100644 paddle/fluid/lite/kernels/cuda/mul_compute.cc
 create mode 100644 paddle/fluid/lite/kernels/cuda/mul_compute.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9b77659f6..a343a6591 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -184,6 +184,9 @@ if(WITH_BRPC_RDMA)
     endif()
 endif()
 
+# for lite
+option(LITE_WITH_CUDA "Enable CUDA in lite mode" ON)
+option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
 
 include(external/threadpool)
 include(flags)              # set paddle compile flags
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 283845541..f859fd10a 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -161,3 +161,13 @@ endif(ON_INFER)
 if(WITH_WBAES)
     add_definitions(-DPADDLE_WITH_WBAES)
 endif(WITH_WBAES)
+
+# for lite
+# TODO(Superjomn) this does not yet work well as an option
+if (LITE_WITH_CUDA)
+  add_definitions("-DLITE_WITH_CUDA")
+endif()
+
+if (LITE_WITH_X86)
+  add_definitions("-DLITE_WITH_X86")
+endif()
diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index cae0912bd..9997b83ee 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -1,3 +1,3 @@
-cc_library(cxx_api_lite SRCS cxx_api.h DEPS scope_lite executor_lite host_kernels ops_lite)
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite executor_lite host_kernels ops_lite)
 
 cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
diff --git a/paddle/fluid/lite/api/cxx_api.cc b/paddle/fluid/lite/api/cxx_api.cc
index 81450cb8d..35a0373a2 100644
--- a/paddle/fluid/lite/api/cxx_api.cc
+++ b/paddle/fluid/lite/api/cxx_api.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-//
-// Created by chunwei on 19-4-11.
-//
-
 #include "paddle/fluid/lite/api/cxx_api.h"
+
+namespace paddle {
+namespace lite {}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h
index 9c304b036..255d55ced 100644
--- a/paddle/fluid/lite/api/cxx_api.h
+++ b/paddle/fluid/lite/api/cxx_api.h
@@ -13,8 +13,49 @@
 // limitations under the License.
 
 #pragma once
+#include "paddle/fluid/lite/core/executor.h"
+#include "paddle/fluid/lite/core/op_lite.h"
 #include "paddle/fluid/lite/model_parser/model_parser.h"
 
 namespace paddle {
-namespace lite {}  // namespace lite
+namespace lite {
+
+struct Config {};
+
+class Predictor {
+ public:
+  void Build(const std::string& model_path,
+             const std::vector<OpLite::Place>& valid_places) {
+    CHECK(!executor_.get()) << "duplicate build found";
+    framework::proto::ProgramDesc prog;
+    LoadModel(model_path, &scope_, &prog);
+    framework::ProgramDesc prog_desc(prog);
+
+    executor_.reset(new Executor(&scope_, valid_places));
+    executor_->PrepareWorkspace(prog_desc);
+    executor_->Build(prog_desc);
+  }
+
+  // Get a tensor for input from the scope directly.
+  Tensor* GetInputTensor(const std::string& name) {
+    auto* var = executor_->exec_scope()->FindVar(name);
+    CHECK(var) << "no tensor called " << name << " exists";
+    return var->GetMutable<Tensor>();
+  }
+
+  // Get a tensor for output from the scope directly.
+  const Tensor* GetOutputTensor(const std::string& name) {
+    auto* var = executor_->exec_scope()->FindVar(name);
+    CHECK(var) << "no tensor called " << name << " exists";
+    return &var->Get<Tensor>();
+  }
+
+  void Run() { executor_->Run(); }
+
+ private:
+  Scope scope_;
+  std::unique_ptr<Executor> executor_;
+};
+
+}  // namespace lite
 }  // namespace paddle
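For reference, a minimal end-to-end sketch of how the Predictor added above is meant to be driven (not part of the patch; the model path and the output name "out" are placeholders, while the input name "a" follows the test below):

    #include "paddle/fluid/lite/api/cxx_api.h"

    void RunModelOnce() {
      paddle::lite::Predictor predictor;
      // Load the program once; the executor picks kernels for the valid places.
      predictor.Build("/path/to/model",
                      {paddle::lite::OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});

      // Inputs are written through tensors fetched from the execution scope.
      auto* x = predictor.GetInputTensor("a");
      x->Resize({100, 200});
      float* in = x->mutable_data<float>();
      (void)in;  // ... fill the input buffer here ...

      predictor.Run();

      // Outputs are read back the same way.
      const auto* out = predictor.GetOutputTensor("out");
      (void)out;
    }
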
diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc
index 09fd7a78b..157bf41ee 100644
--- a/paddle/fluid/lite/api/cxx_api_test.cc
+++ b/paddle/fluid/lite/api/cxx_api_test.cc
@@ -20,7 +20,7 @@
 namespace paddle {
 namespace lite {
 
-TEST(CXXApi, test) {
+TEST(CXXApi, raw) {
   Scope scope;
   framework::proto::ProgramDesc prog;
   LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
@@ -33,11 +33,20 @@ TEST(CXXApi, test) {
   x->Resize({100, 100});
   x->mutable_data<float>();
 
-  executor.PrepareWorkspace(prog_desc, &scope);
+  executor.PrepareWorkspace(prog_desc);
   executor.Build(prog_desc);
   executor.Run();
 }
 
+TEST(CXXApi, test) {
+  lite::Predictor predictor;
+  predictor.Build("/home/chunwei/project2/models/model2",
+                  {OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto* x = predictor.GetInputTensor("a");
+  x->Resize({100, 200});
+  x->mutable_data<float>();
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/context.h b/paddle/fluid/lite/core/context.h
index abf00e53d..ea954fc30 100644
--- a/paddle/fluid/lite/core/context.h
+++ b/paddle/fluid/lite/core/context.h
@@ -13,10 +13,15 @@
 // limitations under the License.
 
 #pragma once
+#include <memory>
 #include <vector>
 #include "paddle/fluid/lite/core/target_wrapper.h"
 
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/lite/cuda/cuda_utils.h"
+#endif
+
 namespace paddle {
 namespace lite {
 
@@ -75,5 +80,41 @@ class OpContext final {
   std::vector<TargetType> targets_;
 };
 
+#ifdef LITE_WITH_CUDA
+// Only works with CUDA kernels.
+struct CUDAContext {
+  // overall information
+  cudaStream_t exec_stream;
+  cudaStream_t io_stream;
+
+  // Not thread-safe; should be allocated for each thread.
+  std::shared_ptr<cuda::Blas<float>> bias_fp32;
+
+  // kernel information
+  std::vector<cudaEvent_t> input_events;
+  std::vector<cudaEvent_t> output_events;
+};
+#endif
+
+#ifdef LITE_WITH_X86
+struct X86Context {
+  // overall information
+  // kernel information
+};
+#endif
+
+// Context for running a kernel.
+// Holds the necessary resource and information.
+class KernelContext {
+ public:
+#ifdef LITE_WITH_CUDA
+  CUDAContext cuda_ctx;
+#endif
+
+#ifdef LITE_WITH_X86
+  X86Context x86_ctx;
+#endif
+};
+
 }  // namespace lite
 }  // namespace paddle
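A sketch of how the CUDA part of a KernelContext could be filled in by the framework (illustrative only, not part of the patch; whether the context owns the streams, and the assumption that the bias_fp32 member holds the per-thread Blas<float> wrapper, are both mine):

    #include "paddle/fluid/lite/core/context.h"
    #include "paddle/fluid/lite/cuda/blas.h"
    #include "paddle/fluid/lite/cuda/cuda_utils.h"

    #ifdef LITE_WITH_CUDA
    paddle::lite::KernelContext MakeCudaKernelContext() {
      paddle::lite::KernelContext ctx;
      // One stream for kernel execution, one for host<->device copies.
      CUDA_CALL(cudaStreamCreate(&ctx.cuda_ctx.exec_stream));
      CUDA_CALL(cudaStreamCreate(&ctx.cuda_ctx.io_stream));
      // The cuBLAS wrapper is per-thread because cublasHandle_t is not thread-safe.
      ctx.cuda_ctx.bias_fp32 = std::make_shared<paddle::lite::cuda::Blas<float>>();
      ctx.cuda_ctx.bias_fp32->SetStream(ctx.cuda_ctx.exec_stream);
      return ctx;
    }
    #endif
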
diff --git a/paddle/fluid/lite/core/executor.h b/paddle/fluid/lite/core/executor.h
index b87cb232d..19c9e2767 100644
--- a/paddle/fluid/lite/core/executor.h
+++ b/paddle/fluid/lite/core/executor.h
@@ -28,7 +28,7 @@ class Executor {
       : scope_(scope), valid_places_(valid_places) {}
 
   // Create temporary variables.
-  void PrepareWorkspace(framework::ProgramDesc& program, lite::Scope* scope) {
+  void PrepareWorkspace(framework::ProgramDesc& program) {
     CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
     exec_scope_ = &scope_->NewScope();
 
@@ -67,6 +67,9 @@ class Executor {
     }
   }
 
+  lite::Scope* scope() { return scope_; }
+  lite::Scope* exec_scope() { return exec_scope_; }
+
  private:
   std::vector<std::unique_ptr<OpLite>> ops_;
   lite::Scope* scope_{};
diff --git a/paddle/fluid/lite/core/executor_test.cc b/paddle/fluid/lite/core/executor_test.cc
index c52b76555..17b674108 100644
--- a/paddle/fluid/lite/core/executor_test.cc
+++ b/paddle/fluid/lite/core/executor_test.cc
@@ -53,7 +53,7 @@ TEST(executor, test) {
   w->mutable_data<float>();
   x->mutable_data<float>();
 
-  executor.PrepareWorkspace(program, &scope);
+  executor.PrepareWorkspace(program);
   executor.Build(program);
   executor.Run();
 }
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index b2208a015..3695420d6 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -47,6 +47,8 @@ class KernelBase {
     return param_.get();
   }
 
+  void Torch() {}
+
   virtual TargetType target() const = 0;
   virtual PrecisionType precision() const = 0;
 
@@ -63,16 +65,20 @@ class KernelBase {
 template <TargetType Target, PrecisionType Precision>
 class OpKernel : public KernelBase {
  public:
-  virtual void Run() { CHECK(false) << "Not Implemented"; }
+  // Set runtime context.
+  void SetContext(std::unique_ptr<KernelContext>&& ctx) { ctx_ = std::move(ctx); }
 
-  void Touch() {}
+  // Run the kernel.
+  virtual void Run() { CHECK(false) << "Not Implemented"; }
 
   TargetType target() const override { return Target; }
   PrecisionType precision() const override { return Precision; }
 
   OpKernel() = default;
-
   virtual ~OpKernel() = default;
+
+ protected:
+  std::unique_ptr<KernelContext> ctx_;
 };
 
 }  // namespace lite
diff --git a/paddle/fluid/lite/cuda/CMakeLists.txt b/paddle/fluid/lite/cuda/CMakeLists.txt
index 4ec76c2a5..1ac05bc7c 100644
--- a/paddle/fluid/lite/cuda/CMakeLists.txt
+++ b/paddle/fluid/lite/cuda/CMakeLists.txt
@@ -1 +1,2 @@
 nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
+nv_library(cuda_blas SRCS blas.cc)
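The SetContext/ctx_ pair added to OpKernel above is the hook through which a target-specific context reaches a kernel. A hypothetical kernel using it might look like the sketch below (DemoKernel and the kCUDA place are illustrative assumptions, not part of the patch):

    #include "paddle/fluid/lite/core/context.h"
    #include "paddle/fluid/lite/core/kernel.h"

    namespace paddle {
    namespace lite {

    #ifdef LITE_WITH_CUDA
    // Hypothetical kernel: reads the stream that the framework injected.
    class DemoKernel : public OpKernel<TARGET(kCUDA), PRECISION(kFloat)> {
     public:
      void Run() override {
        auto& cuda = ctx_->cuda_ctx;
        // ... launch work on cuda.exec_stream ...
        (void)cuda;
      }
    };
    #endif

    }  // namespace lite
    }  // namespace paddle

    // Framework-side wiring (sketch):
    //   std::unique_ptr<paddle::lite::KernelContext> ctx(new paddle::lite::KernelContext);
    //   kernel.SetContext(std::move(ctx));
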
+ +#include "paddle/fluid/lite/cuda/blas.h" + +namespace paddle { +namespace lite { +namespace cuda { + +template <> +class Blas : public BlasBase { + using T = float; + + void sgemm(cublasOperation_t transa, cublasOperation_t transb, // + int m, int n, int k, // + const T* alpha, // + const T* A, int lda, // + const T* B, int ldb, // + const T* beta, // + T* C, int ldc) const { + CUBLAS_CALL(cublasSgemm(handle(), transa, transb, m, n, k, alpha, A, lda, B, + ldb, beta, C, ldc)); + } +}; + +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/cuda/blas.h b/paddle/fluid/lite/cuda/blas.h new file mode 100644 index 000000000..3a8d2fd92 --- /dev/null +++ b/paddle/fluid/lite/cuda/blas.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/cuda/cuda_utils.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace cuda { + +#define CUBLAS_CHECK(xxx) CHECK_EQ((xxx), CUBLAS_STATUS_SUCCESS); + +/* + * Some basic methods. + */ +struct BlasBase { + BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); } + ~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); } + + void SetStream(cudaStream_t stream) { + CUBLAS_CHECK(cublasSetStream(handle_, stream)); + } + + cudaStream_t GetStream() const { + cudaStream_t stream; + CUBLAS_CHECK(cublasGetStream_v2(handle_, &stream)); + return stream; + } + + int GetVersion() const { + int version{}; + CUBLAS_CHECK(cublasGetVersion_v2(handle_, &version)); + return version; + } + + cublasHandle_t& handle() const { return handle_; } + + protected: + // Not thread-safe, should created for each thread. + // According to cublas doc. + mutable cublasHandle_t handle_; +}; + +// T: Scalar type. +template +class Blas : public lite::cuda::BlasBase { + public: + void sgemm(cublasOperation_t transa, cublasOperation_t transb, // + int m, int n, int k, // + const T* alpha, // + const T* A, int lda, // + const T* B, int ldb, // + const T* beta, // + T* C, int ldc) const { + LITE_UNIMPLEMENTED; + } +}; + +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/cuda/cuda_utils.h b/paddle/fluid/lite/cuda/cuda_utils.h new file mode 100644 index 000000000..bdeaaebbe --- /dev/null +++ b/paddle/fluid/lite/cuda/cuda_utils.h @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/paddle/fluid/lite/cuda/cuda_utils.h b/paddle/fluid/lite/cuda/cuda_utils.h
new file mode 100644
index 000000000..bdeaaebbe
--- /dev/null
+++ b/paddle/fluid/lite/cuda/cuda_utils.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cublasXt.h>
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <glog/logging.h>
+
+/*
+ * This file contains some CUDA-specific utils.
+ */
+
+// For quickly implementing the prototype, some of the following code snippets
+// are borrowed from project MXNet, great thanks to the original developers.
+
+#define CHECK_CUDA_ERROR(msg)                                                \
+  {                                                                          \
+    auto e = cudaGetLastError();                                             \
+    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
+  }
+
+#define CUDA_CALL(func)                                      \
+  {                                                          \
+    auto e = (func);                                         \
+    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
+        << "CUDA: " << cudaGetErrorString(e);                \
+  }
+
+#define CUBLAS_CALL(func)                                        \
+  {                                                              \
+    auto e = (func);                                             \
+    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                           \
+        << "cuBlas: " << paddle::lite::cuda::CublasErrorInfo(e); \
+  }
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+
+const char* CublasErrorInfo(int error) {
+  switch (error) {
+#define LITE_CUBLAS_ERROR_INFO(xx) \
+  case xx:                         \
+    return #xx;                    \
+    break;
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_NOT_INITIALIZED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_ALLOC_FAILED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_INVALID_VALUE);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_ARCH_MISMATCH);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_MAPPING_ERROR);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_EXECUTION_FAILED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_INTERNAL_ERROR);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_NOT_SUPPORTED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_LICENSE_ERROR);
+#undef LITE_CUBLAS_ERROR_INFO
+    default:
+      return "unknown error";
+  }
+}
+
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
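The CHECK_CUDA_ERROR / CUDA_CALL / CUBLAS_CALL macros above are meant to wrap every raw CUDA and cuBLAS call so that failures abort with a readable message; a tiny sketch (not part of the patch):

    #include "paddle/fluid/lite/cuda/cuda_utils.h"

    void AllocateOnDevice(float** dev_ptr, size_t n) {
      // Aborts with the cudaError string if the allocation fails.
      CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(dev_ptr), n * sizeof(float)));
      // Reports any error the runtime has recorded so far.
      CHECK_CUDA_ERROR("after cudaMalloc");
    }
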
diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt
index 1401be5e5..a7a894de1 100644
--- a/paddle/fluid/lite/kernels/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(host)
 add_subdirectory(arm)
+add_subdirectory(cuda)
diff --git a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt
new file mode 100644
index 000000000..6814f3f51
--- /dev/null
+++ b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt
@@ -0,0 +1 @@
+cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.cc b/paddle/fluid/lite/kernels/cuda/mul_compute.cc
new file mode 100644
index 000000000..c80851bf6
--- /dev/null
+++ b/paddle/fluid/lite/kernels/cuda/mul_compute.cc
@@ -0,0 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/cuda/mul_compute.h"
diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h
new file mode 100644
index 000000000..e8d65a93d
--- /dev/null
+++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/types.h"
+#include "paddle/fluid/lite/cuda/blas.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+template <typename T>
+void mul_compute(const lite::cuda::Blas<T>& blas, const T* x, int x_h,
+                 int x_w, const T* y, int y_h, int y_w, T* out) {
+  blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N, x_w, x_h, y_w, nullptr, x, 0, y, 0,
+             nullptr, out, 0);
+}
+
+class MulCompute : public OpKernel<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::MulParam;
+
+  void Run() override {}
+
+  virtual ~MulCompute() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/utils/macros.h b/paddle/fluid/lite/utils/macros.h
index 52ad44c70..1115c71cd 100644
--- a/paddle/fluid/lite/utils/macros.h
+++ b/paddle/fluid/lite/utils/macros.h
@@ -19,3 +19,5 @@
   class__(const class__&) = delete;            \
   class__& operator=(const class__&) = delete;
 #endif
+
+#define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
--
GitLab
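The mul_compute helper in mul_compute.h above is still a stub: it passes null alpha/beta pointers and zero leading dimensions, and MulCompute::Run() is empty. For reference, one way the row-major product Out(x_h, y_w) = X(x_h, x_w) * Y(y_h, y_w), with x_w == y_h, could be mapped onto the column-major sgemm wrapper — a sketch under that assumption, not the patch's final implementation:

    // Row-major Out = X * Y equals column-major Out^T = Y^T * X^T,
    // so y is passed as the first operand and the dimensions are swapped.
    template <typename T>
    void mul_compute_rowmajor(const paddle::lite::cuda::Blas<T>& blas, const T* x,
                              int x_h, int x_w, const T* y, int y_h, int y_w,
                              T* out) {
      (void)y_h;  // implied by x_w, since x_w == y_h
      const T alpha = static_cast<T>(1);
      const T beta = static_cast<T>(0);
      blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N, /*m=*/y_w, /*n=*/x_h, /*k=*/x_w, &alpha,
                 /*A=*/y, /*lda=*/y_w, /*B=*/x, /*ldb=*/x_w, &beta,
                 /*C=*/out, /*ldc=*/y_w);
    }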