提交 c05a4e58 编写于 作者: S Superjomn

refactor Context and link with kernel

Only kernel has context, not the OpLite.
上级 8532bb4a
......@@ -25,61 +25,6 @@
namespace paddle {
namespace lite {
template <TargetType Target>
class Context {
public:
using target_wrapper_t = TargetWrapper<Target>;
using stream_t = typename TargetWrapper<Target>::stream_t;
using event_t = typename TargetWrapper<Target>::event_t;
Context() = default;
Context(int device_id, stream_t compute_stream, stream_t data_stream)
: device_id_(device_id),
compute_stream_(compute_stream),
data_stream_(data_stream) {}
void SetDeviceId(int device_id) { device_id_ = device_id; }
void SetComputeStream(stream_t x) { compute_stream_ = x; }
void SetDataStream(stream_t x) { data_stream_ = x; }
void SetDependEvents(const std::vector<event_t>& events) {
depend_events_ = events;
}
int device_id() const { return device_id_; }
stream_t compute_stream() const { return compute_stream_; }
stream_t data_stream() const { return data_stream_; }
const std::vector<event_t>& depend_events() const { return depend_events_; }
private:
int device_id_{0};
stream_t compute_stream_;
stream_t data_stream_;
std::vector<event_t> depend_events_;
};
class OpContext final {
public:
template <TargetType Target>
using target_ptr_t = std::unique_ptr<Context<Target>>;
// @param target valid target.
explicit OpContext(TargetType target)
: targets_(std::vector<TargetType>({target})) {}
// @param target valid target.
explicit OpContext(const std::vector<TargetType>& target)
: targets_(target) {}
const std::vector<TargetType>& target() const { return targets_; }
template <TargetType Target>
target_ptr_t<Target> CreateContext() {
return target_ptr_t<Target>(new Context<Target>);
}
private:
std::vector<TargetType> targets_;
};
#ifdef LITE_WITH_CUDA
// Only works with CUDA kernels.
struct CUDAContext {
......@@ -88,7 +33,7 @@ struct CUDAContext {
cudaStream_t io_stream;
// not thread-safe, should allocate for each thread.
std::shared_ptr<cuda::Blas<float>> bias_fp32;
std::shared_ptr<cuda::Blas<float>> blas_fp32;
// kernel information
std::vector<cudaEvent_t> input_events;
......@@ -108,12 +53,39 @@ struct X86Context {
class KernelContext {
public:
#ifdef LITE_WITH_CUDA
CUDAContext cuda_ctx;
CUDAContext& AsCudaContext() {
if (target_ != TARGET(kUnk)) {
CHECK(target_ == TARGET(kCUDA));
} else {
target_ = TARGET(kCUDA);
cuda_ctx_.reset(new CUDAContext);
}
return *cuda_ctx_;
}
#endif // LITE_WITH_CUDA
#ifdef LITE_WITH_X86
X86Context& AsX86Context() {
if (target_ != TARGET(kUnk)) {
CHECK(target_ == TARGET(kX86));
} else {
target_ = TARGET(kX86);
x86_ctx_.reset(new X86Context);
}
return *x86_ctx_;
}
#endif // lite_with_x86
private:
#ifdef LITE_WITH_CUDA
std::unique_ptr<CUDAContext> cuda_ctx_;
#endif
#ifdef LITE_WITH_X86
X86Context x86_ctx;
std::unique_ptr<X86Context> x86_ctx_;
#endif
TargetType target_{TARGET(kUnk)};
};
} // namespace lite
......
......@@ -15,6 +15,7 @@
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
......@@ -35,9 +36,8 @@ class KernelBase {
public:
virtual void Run() = 0;
template <TargetType Target>
void SetContext(std::unique_ptr<Context<Target>>&& ctx) {
context_.set<std::unique_ptr<Context<Target>>>(std::move(ctx));
void SetContext(std::unique_ptr<KernelContext>&& ctx) {
context_ = std::move(ctx);
}
template <typename T>
......@@ -59,13 +59,14 @@ class KernelBase {
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
const KernelContext* context() const { return context_.get(); }
virtual std::string name() const = 0;
virtual ~KernelBase() = default;
protected:
core::any_context_t context_;
std::unique_ptr<KernelContext> context_;
mutable operators::param_t param_;
// The corresponding op type.
std::string op_type_;
......
......@@ -67,8 +67,8 @@ class OpLite : public Registry {
OpLite() = default;
OpLite(const std::string &type) : op_type_(type) {}
OpLite(std::unique_ptr<OpContext> &&x, const std::vector<Place> &valid_places)
: op_context_(std::move(x)), valid_places_(valid_places) {}
OpLite(const std::vector<Place> &valid_places)
: valid_places_(valid_places) {}
void SetValidPlaces(const std::vector<Place> &places) {
valid_places_ = places;
......@@ -126,7 +126,6 @@ class OpLite : public Registry {
friend class mir::SSAGraph;
protected:
std::unique_ptr<OpContext> op_context_;
std::unique_ptr<KernelBase> kernel_;
std::string op_type_;
std::vector<Place> valid_places_;
......
......@@ -21,10 +21,6 @@ namespace paddle {
namespace lite {
namespace core {
using any_context_t = variant<Context<TARGET(kX86)>, //
Context<TARGET(kCUDA)> //
>;
// Factors that impact the kernel picking strategy. Multiple factors can be
// considered together by using statement like 'factor1 | factor2'
class KernelPickFactor {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册