From c05a4e5871f4a51f6aa794e061ed014e4e7a1196 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 21 Apr 2019 21:53:15 +0800 Subject: [PATCH] refactor Context and link with kernel Only kernel has context, not the OpLite. --- paddle/fluid/lite/core/context.h | 88 +++++++++++--------------------- paddle/fluid/lite/core/kernel.h | 9 ++-- paddle/fluid/lite/core/op_lite.h | 5 +- paddle/fluid/lite/core/types.h | 4 -- 4 files changed, 37 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/lite/core/context.h b/paddle/fluid/lite/core/context.h index ea954fc301..a7ec521dac 100644 --- a/paddle/fluid/lite/core/context.h +++ b/paddle/fluid/lite/core/context.h @@ -25,61 +25,6 @@ namespace paddle { namespace lite { -template -class Context { - public: - using target_wrapper_t = TargetWrapper; - using stream_t = typename TargetWrapper::stream_t; - using event_t = typename TargetWrapper::event_t; - - Context() = default; - Context(int device_id, stream_t compute_stream, stream_t data_stream) - : device_id_(device_id), - compute_stream_(compute_stream), - data_stream_(data_stream) {} - - void SetDeviceId(int device_id) { device_id_ = device_id; } - void SetComputeStream(stream_t x) { compute_stream_ = x; } - void SetDataStream(stream_t x) { data_stream_ = x; } - void SetDependEvents(const std::vector& events) { - depend_events_ = events; - } - - int device_id() const { return device_id_; } - stream_t compute_stream() const { return compute_stream_; } - stream_t data_stream() const { return data_stream_; } - const std::vector& depend_events() const { return depend_events_; } - - private: - int device_id_{0}; - stream_t compute_stream_; - stream_t data_stream_; - std::vector depend_events_; -}; - -class OpContext final { - public: - template - using target_ptr_t = std::unique_ptr>; - - // @param target valid target. - explicit OpContext(TargetType target) - : targets_(std::vector({target})) {} - // @param target valid target. - explicit OpContext(const std::vector& target) - : targets_(target) {} - - const std::vector& target() const { return targets_; } - - template - target_ptr_t CreateContext() { - return target_ptr_t(new Context); - } - - private: - std::vector targets_; -}; - #ifdef LITE_WITH_CUDA // Only works with CUDA kernels. struct CUDAContext { @@ -88,7 +33,7 @@ struct CUDAContext { cudaStream_t io_stream; // not thread-safe, should allocate for each thread. - std::shared_ptr> bias_fp32; + std::shared_ptr> blas_fp32; // kernel information std::vector input_events; @@ -108,12 +53,39 @@ struct X86Context { class KernelContext { public: #ifdef LITE_WITH_CUDA - CUDAContext cuda_ctx; + CUDAContext& AsCudaContext() { + if (target_ != TARGET(kUnk)) { + CHECK(target_ == TARGET(kCUDA)); + } else { + target_ = TARGET(kCUDA); + cuda_ctx_.reset(new CUDAContext); + } + return *cuda_ctx_; + } +#endif // LITE_WITH_CUDA + +#ifdef LITE_WITH_X86 + X86Context& AsX86Context() { + if (target_ != TARGET(kUnk)) { + CHECK(target_ == TARGET(kX86)); + } else { + target_ = TARGET(kX86); + x86_ctx_.reset(new X86Context); + } + return *x86_ctx_; + } +#endif // lite_with_x86 + + private: +#ifdef LITE_WITH_CUDA + std::unique_ptr cuda_ctx_; #endif #ifdef LITE_WITH_X86 - X86Context x86_ctx; + std::unique_ptr x86_ctx_; #endif + + TargetType target_{TARGET(kUnk)}; }; } // namespace lite diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index d3308d016f..9fb7b60b34 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -35,9 +36,8 @@ class KernelBase { public: virtual void Run() = 0; - template - void SetContext(std::unique_ptr>&& ctx) { - context_.set>>(std::move(ctx)); + void SetContext(std::unique_ptr&& ctx) { + context_ = std::move(ctx); } template @@ -59,13 +59,14 @@ class KernelBase { virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual DataLayoutType layout() const = 0; + const KernelContext* context() const { return context_.get(); } virtual std::string name() const = 0; virtual ~KernelBase() = default; protected: - core::any_context_t context_; + std::unique_ptr context_; mutable operators::param_t param_; // The corresponding op type. std::string op_type_; diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 54d973ebc7..6101366712 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -67,8 +67,8 @@ class OpLite : public Registry { OpLite() = default; OpLite(const std::string &type) : op_type_(type) {} - OpLite(std::unique_ptr &&x, const std::vector &valid_places) - : op_context_(std::move(x)), valid_places_(valid_places) {} + OpLite(const std::vector &valid_places) + : valid_places_(valid_places) {} void SetValidPlaces(const std::vector &places) { valid_places_ = places; @@ -126,7 +126,6 @@ class OpLite : public Registry { friend class mir::SSAGraph; protected: - std::unique_ptr op_context_; std::unique_ptr kernel_; std::string op_type_; std::vector valid_places_; diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h index 990ffe0007..7562a911c8 100644 --- a/paddle/fluid/lite/core/types.h +++ b/paddle/fluid/lite/core/types.h @@ -21,10 +21,6 @@ namespace paddle { namespace lite { namespace core { -using any_context_t = variant, // - Context // - >; - // Factors that impact the kernel picking strategy. Multiple factors can be // considered together by using statement like 'factor1 | factor2' class KernelPickFactor { -- GitLab