提交 bb185125 编写于 作者: S superjomn

add cublas

上级 4fdb49e8
......@@ -184,6 +184,9 @@ if(WITH_BRPC_RDMA)
endif()
endif()
# for lite
option(LITE_WITH_CUDA "Enable CUDA in lite mode" ON)
option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
include(external/threadpool)
include(flags) # set paddle compile flags
......
......@@ -161,3 +161,13 @@ endif(ON_INFER)
if(WITH_WBAES)
add_definitions(-DPADDLE_WITH_WBAES)
endif(WITH_WBAES)
# for lite
# TODO(Superjomn): this does not work well with the option() switches yet
if (LITE_WITH_CUDA)
add_definitions("-DLITE_WITH_CUDA")
endif()
if (LITE_WITH_X86)
add_definitions("-DLITE_WITH_X86")
endif()
cc_library(cxx_api_lite SRCS cxx_api.h DEPS scope_lite executor_lite host_kernels ops_lite)
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite executor_lite host_kernels ops_lite)
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-4-11.
//
#include "paddle/fluid/lite/api/cxx_api.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
......@@ -13,8 +13,49 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/executor.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace paddle {
namespace lite {} // namespace lite
namespace lite {
// Placeholder for predictor configuration; it currently has no members.
struct Config {};
// Thin C++ front-end over Executor: loads a model from disk into an owned
// root scope, builds an executor for it, and exposes the tensors that live in
// the executor's execution scope.
//
// Usage order: Build() exactly once, then GetInputTensor()/GetOutputTensor()
// and Run(). Calling any accessor or Run() before Build() aborts with a clear
// message instead of dereferencing a null executor.
class Predictor {
 public:
  // Loads the model stored at `model_path` into the root scope, creates the
  // executor restricted to `valid_places`, then calls PrepareWorkspace() and
  // Build() on it. Aborts if called a second time on the same Predictor.
  void Build(const std::string& model_path,
             const std::vector<OpLite::Place>& valid_places) {
    CHECK(!executor_) << "duplicate build found";
    framework::proto::ProgramDesc prog;
    LoadModel(model_path, &scope_, &prog);
    framework::ProgramDesc prog_desc(prog);

    executor_.reset(new Executor(&scope_, valid_places));
    executor_->PrepareWorkspace(prog_desc);
    executor_->Build(prog_desc);
  }

  // Get a tensor for input from scope directly.
  // Returns a mutable pointer to the tensor held by variable `name` in the
  // execution scope; aborts if no such variable exists.
  Tensor* GetInputTensor(const std::string& name) {
    CHECK(executor_) << "Build should be called before GetInputTensor";
    auto* var = executor_->exec_scope()->FindVar(name);
    CHECK(var) << "no tensor called " << name << " exists";
    return var->GetMutable<Tensor>();
  }

  // Get a tensor for output from scope directly.
  // Returns a read-only pointer to the tensor held by variable `name` in the
  // execution scope; aborts if no such variable exists.
  const Tensor* GetOutputTensor(const std::string& name) {
    CHECK(executor_) << "Build should be called before GetOutputTensor";
    auto* var = executor_->exec_scope()->FindVar(name);
    CHECK(var) << "no tensor called " << name << " exists";
    return &var->Get<Tensor>();
  }

  // Runs the built executor once.
  void Run() {
    CHECK(executor_) << "Build should be called before Run";
    executor_->Run();
  }

 private:
  // Root scope handed to LoadModel() and to the executor.
  Scope scope_;
  // Null until Build() succeeds.
  std::unique_ptr<lite::Executor> executor_;
};
} // namespace lite
} // namespace paddle
......@@ -20,7 +20,7 @@
namespace paddle {
namespace lite {
TEST(CXXApi, test) {
TEST(CXXApi, raw) {
Scope scope;
framework::proto::ProgramDesc prog;
LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
......@@ -33,11 +33,20 @@ TEST(CXXApi, test) {
x->Resize({100, 100});
x->mutable_data<float>();
executor.PrepareWorkspace(prog_desc, &scope);
executor.PrepareWorkspace(prog_desc);
executor.Build(prog_desc);
executor.Run();
}
// Smoke test for the Predictor wrapper: builds a predictor for the host/float
// place, fetches the input tensor named "a", resizes it to 100x200 and
// allocates float storage for it. The predictor is not run.
// NOTE(review): the model path is an absolute path on a developer machine, so
// this test presumably cannot pass elsewhere — confirm before enabling in CI.
TEST(CXXApi, test) {
  lite::Predictor predictor;
  predictor.Build("/home/chunwei/project2/models/model2",
                  {OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});

  Tensor* input = predictor.GetInputTensor("a");
  input->Resize({100, 200});
  input->mutable_data<float>();
}
} // namespace lite
} // namespace paddle
......
......@@ -13,10 +13,15 @@
// limitations under the License.
#pragma once
#include <paddle/fluid/lite/cuda/blas.h>
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/target_wrapper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/lite/cuda/cuda_utils.h"
#endif
namespace paddle {
namespace lite {
......@@ -75,5 +80,41 @@ class OpContext final {
std::vector<TargetType> targets_;
};
#ifdef LITE_WITH_CUDA
// Per-kernel CUDA resources. Only works with CUDA kernels; the struct is only
// compiled when LITE_WITH_CUDA is defined.
struct CUDAContext {
// overall information
// Two CUDA streams. NOTE(review): presumably one for kernel execution and one
// for data transfer, judging by the names — confirm against the users.
cudaStream_t exec_stream;
cudaStream_t io_stream;
// not thread-safe, should allocate for each thread.
// NOTE(review): the member is named `bias_fp32` but holds a cuda::Blas<float>;
// presumably `blas_fp32` was intended — confirm before renaming.
std::shared_ptr<cuda::Blas<float>> bias_fp32;
// kernel information
// Events associated with the kernel's inputs and outputs.
// NOTE(review): who records/waits on these is not visible here — confirm.
std::vector<cudaEvent_t> input_events;
std::vector<cudaEvent_t> output_events;
};
#endif
#ifdef LITE_WITH_X86
// Placeholder for X86-specific kernel resources; no members are defined yet.
// Only compiled when LITE_WITH_X86 is defined.
struct X86Context {
// overall information
// kernel information
};
#endif
// Context for running a kernel.
// Holds the necessary resources and information.
// Which members exist depends on the build: each target-specific context is
// present only when the matching LITE_WITH_* macro is defined, so the class
// is empty when neither macro is set.
class KernelContext {
public:
#ifdef LITE_WITH_CUDA
// CUDA streams, BLAS handle and events.
CUDAContext cuda_ctx;
#endif
#ifdef LITE_WITH_X86
// X86 resources (currently an empty struct).
X86Context x86_ctx;
#endif
};
} // namespace lite
} // namespace paddle
......@@ -28,7 +28,7 @@ class Executor {
: scope_(scope), valid_places_(valid_places) {}
// Create temporary variables.
void PrepareWorkspace(framework::ProgramDesc& program, lite::Scope* scope) {
void PrepareWorkspace(framework::ProgramDesc& program) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope();
......@@ -67,6 +67,9 @@ class Executor {
}
}
// Returns the scope pointer this executor was constructed with (not owned).
lite::Scope* scope() { return scope_; }
// Returns the execution scope, a child of scope() created by
// PrepareWorkspace() via NewScope().
// NOTE(review): presumably null until PrepareWorkspace() has run — confirm the
// member's initializer.
lite::Scope* exec_scope() { return exec_scope_; }
private:
std::vector<std::unique_ptr<OpLite>> ops_;
lite::Scope* scope_{};
......
......@@ -53,7 +53,7 @@ TEST(executor, test) {
w->mutable_data<float>();
x->mutable_data<float>();
executor.PrepareWorkspace(program, &scope);
executor.PrepareWorkspace(program);
executor.Build(program);
executor.Run();
}
......
......@@ -47,6 +47,8 @@ class KernelBase {
return param_.get<Param>();
}
// Empty member function; it does nothing.
// NOTE(review): OpKernel declares an equally empty `Touch()`; presumably
// `Torch` is a typo for `Touch` — confirm before renaming.
void Torch() {}
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
......@@ -63,16 +65,20 @@ class KernelBase {
template <TargetType Target, PrecisionType Precision>
class OpKernel : public KernelBase {
public:
virtual void Run() { CHECK(false) << "Not Implemented"; }
// Set runtime context.
// Takes ownership of `ctx`; any context previously held is destroyed.
// std::move is required: `ctx` is a named rvalue reference and therefore an
// lvalue, and std::unique_ptr has no copy assignment.
void SetContext(std::unique_ptr<KernelContext>&& ctx) { ctx_ = std::move(ctx); }
// Empty member function; it does nothing.
// NOTE(review): KernelBase declares an equally empty `Torch()`; presumably one
// of the two names is a typo — confirm.
void Touch() {}
// Run the kernel.
virtual void Run() { CHECK(false) << "Not Implemented"; }
TargetType target() const override { return Target; }
PrecisionType precision() const override { return Precision; }
OpKernel() = default;
virtual ~OpKernel() = default;
protected:
std::unique_ptr<KernelContext> ctx_;
};
} // namespace lite
......
nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
nv_library(cuda_blas SRCS blas.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/cuda/blas.h"
namespace paddle {
namespace lite {
namespace cuda {
// Single-precision specialization of Blas: forwards GEMM to cublasSgemm using
// the cuBLAS handle owned by BlasBase. Arguments are passed through unchanged,
// so cuBLAS conventions apply to them.
//
// NOTE(review): this explicit specialization is defined in a .cc file, while
// the primary template lives in blas.h. Translation units that include only
// the header will not see it and will instantiate the unimplemented primary
// template instead — confirm whether it should move to (or be declared in)
// the header.
template <>
class Blas<float> : public BlasBase {
  using T = float;

 public:
  // `sgemm` must be public to be callable from outside the class, matching
  // the primary template.
  void sgemm(cublasOperation_t transa, cublasOperation_t transb,  //
             int m, int n, int k,                                 //
             const T* alpha,                                      //
             const T* A, int lda,                                 //
             const T* B, int ldb,                                 //
             const T* beta,                                       //
             T* C, int ldc) const {
    CUBLAS_CALL(cublasSgemm(handle(), transa, transb, m, n, k, alpha, A, lda, B,
                            ldb, beta, C, ldc));
  }
};
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cublasXt.h>
#include <cublas_api.h>
#include <cublas_v2.h>
#include <glog/logging.h>
#include <library_types.h>
#include "paddle/fluid/lite/cuda/cuda_utils.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace cuda {
// Aborts (via glog CHECK_EQ) unless the cuBLAS call `xxx` returns
// CUBLAS_STATUS_SUCCESS. The macro carries no trailing semicolon, so
// `CUBLAS_CHECK(call);` is exactly one statement (safe in un-braced if/else)
// and extra context can be streamed with `<<`.
#define CUBLAS_CHECK(xxx) CHECK_EQ((xxx), CUBLAS_STATUS_SUCCESS)
/*
 * Some basic methods.
 *
 * BlasBase owns one cuBLAS handle: it is created in the constructor and
 * destroyed in the destructor. Because the handle is uniquely owned, the type
 * is non-copyable; copying would destroy the same handle twice.
 */
struct BlasBase {
  BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); }
  ~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); }

  BlasBase(const BlasBase&) = delete;
  BlasBase& operator=(const BlasBase&) = delete;

  // Binds the handle to `stream`.
  void SetStream(cudaStream_t stream) {
    CUBLAS_CHECK(cublasSetStream(handle_, stream));
  }

  // Returns the stream currently bound to the handle.
  cudaStream_t GetStream() const {
    cudaStream_t stream{};
    CUBLAS_CHECK(cublasGetStream_v2(handle_, &stream));
    return stream;
  }

  // Returns the cuBLAS library version reported for this handle.
  int GetVersion() const {
    int version{};
    CUBLAS_CHECK(cublasGetVersion_v2(handle_, &version));
    return version;
  }

  // Gives direct access to the underlying handle. It is reachable from const
  // methods because cuBLAS calls take the handle by value and `handle_` is
  // declared mutable.
  cublasHandle_t& handle() const { return handle_; }

 protected:
  // Not thread-safe; one should be created for each thread,
  // according to the cuBLAS documentation.
  mutable cublasHandle_t handle_;
};
// BLAS front-end parameterised on the scalar type.
//   T: Scalar type.
// The primary template has no real implementation: its sgemm only triggers
// LITE_UNIMPLEMENTED. Scalar types that are supported are expected to provide
// their own specialization.
template <typename T>
class Blas : public lite::cuda::BlasBase {
 public:
  // Same parameter list as cublasSgemm minus the handle; always reports
  // "Not Implemented" in this generic version.
  void sgemm(cublasOperation_t transa, cublasOperation_t transb,
             int m, int n, int k,
             const T* alpha,
             const T* A, int lda,
             const T* B, int ldb,
             const T* beta,
             T* C, int ldc) const {
    LITE_UNIMPLEMENTED;
  }
};
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cublasXt.h>
#include <cublas_api.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <glog/logging.h>
/*
* This file contains some CUDA specific utils.
*/
// To implement the prototype quickly, some of the following code snippets
// are borrowed from the MXNet project; many thanks to the original developers.
//
// All three macros expand to a braced block (not do { } while (0)), so take
// care when using them as the body of an un-braced if/else.

// Fetches the last CUDA runtime error with cudaGetLastError() and aborts
// (glog CHECK_EQ) with `msg` and the CUDA error string if it is not
// cudaSuccess.
#define CHECK_CUDA_ERROR(msg) \
{ \
auto e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}
// Evaluates the CUDA runtime call `func` once and aborts with the CUDA error
// string unless it returned cudaSuccess or cudaErrorCudartUnloading (the
// latter is tolerated).
#define CUDA_CALL(func) \
{ \
auto e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
}
// Evaluates the cuBLAS call `func` once and aborts unless it returned
// CUBLAS_STATUS_SUCCESS; the message names the status via
// paddle::lite::cuda::CublasErrorInfo(), which is defined below in this file.
#define CUBLAS_CALL(func) \
{ \
auto e = (func); \
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
<< "cuBlas: " << paddle::lite::cuda::CublasErrorInfo(e); \
}
namespace paddle {
namespace lite {
namespace cuda {
// Maps a cuBLAS status code to the name of its enumerator, for error
// messages. Codes not listed below (including CUBLAS_STATUS_SUCCESS) yield
// "unknown error". The returned pointer refers to a string literal.
//
// Declared `inline` because the definition lives in a header that is included
// from several translation units; without it every includer would emit its
// own external definition and linking would fail.
inline const char* CublasErrorInfo(int error) {
  switch (error) {
// Expands to a `case` label that returns the enumerator's own spelling.
#define LITE_CUBLAS_ERROR_INFO(xx) \
  case xx:                         \
    return #xx;
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_NOT_INITIALIZED);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_ALLOC_FAILED);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_INVALID_VALUE);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_ARCH_MISMATCH);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_MAPPING_ERROR);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_EXECUTION_FAILED);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_INTERNAL_ERROR);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_NOT_SUPPORTED);
    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_LICENSE_ERROR);
#undef LITE_CUBLAS_ERROR_INFO
    default:
      return "unknown error";
  }
}
} // namespace cuda
} // namespace lite
} // namespace paddle
add_subdirectory(host)
add_subdirectory(arm)
add_subdirectory(cuda)
cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/cuda/mul_compute.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/cuda/blas.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
// Computes out = x * y with a single GEMM call on `blas`.
//
//   x:   x_h rows by x_w columns
//   y:   y_h rows by y_w columns (y_h must equal x_w)
//   out: x_h rows by y_w columns
//
// NOTE(review): the argument mapping assumes all three buffers are dense and
// row-major — confirm against the tensor layout used by the callers. cuBLAS
// is column-major, so a row-major R x C buffer is seen by cuBLAS as a C x R
// matrix with leading dimension C. The product is therefore issued as
// out^T = y^T * x^T, which needs no explicit transposes.
//
// alpha and beta are passed as host pointers to 1 and 0. The previous code
// passed nullptr for both and 0 for every leading dimension, neither of which
// cuBLAS accepts.
// NOTE(review): relies on the handle using host pointer mode (the cuBLAS
// default) — confirm nothing changes it.
template <typename T>
void mul_compute(const lite::cuda::Blas<float>& blas, const T* x, int x_h,
                 int x_w, const T* y, int y_h, int y_w, T* out) {
  CHECK_EQ(x_w, y_h) << "mul_compute: inner dimensions of x and y must match";

  const T alpha = static_cast<T>(1);
  const T beta = static_cast<T>(0);

  // Column-major view: A = y^T (y_w x y_h), B = x^T (x_w x x_h),
  // C = out^T (y_w x x_h); m = y_w, n = x_h, k = x_w.
  blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N,  //
             y_w, x_h, x_w,             //
             &alpha,                    //
             y, y_w,                    //
             x, x_w,                    //
             &beta,                     //
             out, y_w);
}
// Kernel class for the `mul` operator; Run() is currently an empty stub.
// NOTE(review): the class lives in namespace kernels::cuda but is registered
// with TARGET(kHost) — presumably a placeholder until a CUDA target is wired
// in; confirm.
class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 public:
  // Parameter bundle type used by this kernel.
  using param_t = operators::MulParam;

  ~MulCompute() override = default;

  // No computation is performed yet.
  void Run() override {}
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -19,3 +19,5 @@
class__(const class__&) = delete; \
class__& operator=(const class__&) = delete;
#endif
// Marks a code path as not implemented: expands to a failing glog CHECK with
// the message "Not Implemented". The expansion already ends with a semicolon,
// so `LITE_UNIMPLEMENTED;` yields an extra (harmless) empty statement.
#define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册