From 44dbc0bc8a902648def8bf7c7bcad3eae4fde93a Mon Sep 17 00:00:00 2001 From: sangoly Date: Mon, 3 Jun 2019 08:34:26 +0800 Subject: [PATCH] refactor context to support both server and light (#17762) --- paddle/fluid/lite/api/CMakeLists.txt | 25 +-- paddle/fluid/lite/api/light_api.h | 2 + paddle/fluid/lite/api/light_api_test.cc | 15 ++ paddle/fluid/lite/core/CMakeLists.txt | 1 + paddle/fluid/lite/core/context.cc | 29 +-- paddle/fluid/lite/core/context.h | 195 ++++++++++++++++-- paddle/fluid/lite/core/context_test.cc | 51 +++++ .../core/mir/runtime_context_assign_pass.cc | 75 +------ paddle/fluid/lite/kernels/cuda/mul_compute.h | 4 +- .../lite/kernels/x86/activation_compute.cc | 8 +- .../lite/kernels/x86/elementwise_compute.cc | 14 +- .../lite/kernels/x86/fill_constant_compute.cc | 4 +- paddle/fluid/lite/kernels/x86/mean_compute.cc | 8 +- paddle/fluid/lite/kernels/x86/mul_compute.cc | 8 +- paddle/fluid/lite/operators/CMakeLists.txt | 5 +- paddle/fluid/lite/operators/fc_op_test.cc | 8 +- 16 files changed, 308 insertions(+), 144 deletions(-) create mode 100644 paddle/fluid/lite/core/context_test.cc diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt index c95e7f65b86..7f3bc8c2f36 100644 --- a/paddle/fluid/lite/api/CMakeLists.txt +++ b/paddle/fluid/lite/api/CMakeLists.txt @@ -25,24 +25,25 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING "A path setting inference demo download directories.") -# lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc -# DEPS cxx_api_lite model_parser_lite target_wrapper_host -# PROFILE_DEPS basic_profiler_lite -# ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model -# --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) - - if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc - DEPS cxx_api_lite model_parser_lite target_wrapper_host - ${ops_lite} ${host_kernels} ${x86_kernels} - PROFILE_DEPS basic_profiler_lite - ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model + DEPS cxx_api_lite model_parser_lite target_wrapper_host + ${ops_lite} ${host_kernels} ${x86_kernels} + PROFILE_DEPS basic_profiler_lite + ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz") add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) -endif() +endif(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + +if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) + add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) +endif(WITH_TESTING) + +# if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) +# endif() lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc diff --git a/paddle/fluid/lite/api/light_api.h b/paddle/fluid/lite/api/light_api.h index 9cd9f62a0b0..e79bc588dec 100644 --- a/paddle/fluid/lite/api/light_api.h +++ b/paddle/fluid/lite/api/light_api.h @@ -22,6 +22,7 @@ #include #include #include +#include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/model_parser/model_parser.h" @@ -84,6 +85,7 @@ class LightPredictor { return it->alias() == alias; }); CHECK(it != kernels.end()); + (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); insts.emplace_back(op, std::move(*it)); } program_.reset(new RuntimeProgram(std::move(insts))); diff --git a/paddle/fluid/lite/api/light_api_test.cc b/paddle/fluid/lite/api/light_api_test.cc index 600a7b62c6b..b1e6741e09e 100644 --- a/paddle/fluid/lite/api/light_api_test.cc +++ b/paddle/fluid/lite/api/light_api_test.cc @@ -44,3 +44,18 @@ USE_LITE_OP(scale); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); + +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); +#endif diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index e1a23fa31e6..d37e44c733f 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -54,3 +54,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li #lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) +lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator) diff --git a/paddle/fluid/lite/core/context.cc b/paddle/fluid/lite/core/context.cc index 2c36c60159c..cd7006f4724 100644 --- a/paddle/fluid/lite/core/context.cc +++ b/paddle/fluid/lite/core/context.cc @@ -33,7 +33,7 @@ namespace lite { #ifdef LITE_WITH_ARM -void ARMContext::SetCache(int l1size, int l2size, int l3size) { +void Context::SetCache(int l1size, int l2size, int l3size) { DeviceInfo& dev = DeviceInfo::Global(); int cpu_count = arm_get_cpucount(); dev.L1_cache_.resize(cpu_count); @@ -47,7 +47,7 @@ void ARMContext::SetCache(int l1size, int l2size, int l3size) { workspace_.Resize({2 * (l1size + l2size)}); } -ARMContext::ARMContext() { +Context::Context() { active_ids_ = {0}; mode_ = LITE_POWER_HIGH; DeviceInfo& dev = DeviceInfo::Global(); @@ -62,11 +62,11 @@ ARMContext::ARMContext() { #endif } -PowerMode ARMContext::mode() const { return mode_; } +PowerMode Context::mode() const { return mode_; } -int ARMContext::threads() const { return active_ids_.size(); } +int Context::threads() const { return active_ids_.size(); } -ARMContext::ARMContext(const ARMContext& ctx) { +Context::Context(const ARMContext& ctx) { mode_ = ctx.mode_; active_ids_ = ctx.active_ids_; workspace_ = ctx.workspace_; @@ -74,7 +74,7 @@ ARMContext::ARMContext(const ARMContext& ctx) { count_ = ctx.count_; } -ARMContext& ARMContext::operator=(const ARMContext& ctx) { +ARMContext& Context::operator=(const ARMContext& ctx) { mode_ = ctx.mode_; active_ids_ = ctx.active_ids_; workspace_ = ctx.workspace_; @@ -83,7 +83,7 @@ ARMContext& ARMContext::operator=(const ARMContext& ctx) { return *this; } -void ARMContext::BindDev() { +void Context::BindDev() { #ifdef USE_OPENMP int num_threads = active_ids_.size(); omp_set_num_threads(num_threads); @@ -116,7 +116,7 @@ void ARMContext::BindDev() { #endif // USE_OPENMP } -void ARMContext::SetRunMode(PowerMode mode, int threads) { +void Context::SetRunMode(PowerMode mode, int threads) { DeviceInfo& dev = DeviceInfo::Global(); int big_core_size = dev.big_core_ids_.size(); int small_core_size = dev.little_core_ids_.size(); @@ -293,26 +293,26 @@ void ARMContext::SetRunMode(PowerMode mode, int threads) { arch_ = DeviceInfo::Global().archs_[active_ids_[0]]; } -ARMArch ARMContext::arch() const { return arch_; } +ARMArch Context::arch() const { return arch_; } -void ARMContext::SetArch(ARMArch arch) { arch_ = arch; } +void Context::SetArch(ARMArch arch) { arch_ = arch; } -int ARMContext::l1_cache_size() const { +int Context::l1_cache_size() const { DeviceInfo& dev = DeviceInfo::Global(); return dev.L1_cache_[active_ids_[0]]; } -int ARMContext::l2_cache_size() const { +int Context::l2_cache_size() const { DeviceInfo& dev = DeviceInfo::Global(); return dev.L2_cache_[active_ids_[0]]; } -int ARMContext::l3_cache_size() const { +int Context::l3_cache_size() const { DeviceInfo& dev = DeviceInfo::Global(); return dev.L3_cache_[active_ids_[0]]; } -bool ARMContext::ExtendWorkspace(DDimLite dims) { +bool Context::ExtendWorkspace(DDimLite dims) { auto count = dims.product(); auto old = workspace_.dims(); if (count == old.product()) { @@ -324,5 +324,6 @@ bool ARMContext::ExtendWorkspace(DDimLite dims) { return true; } #endif // LITE_WITH_ARM + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/context.h b/paddle/fluid/lite/core/context.h index 9d5decfdbed..4702512af3a 100644 --- a/paddle/fluid/lite/core/context.h +++ b/paddle/fluid/lite/core/context.h @@ -23,28 +23,55 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/device_context.h" #endif +#include #include #include +#include +#include #include #include "paddle/fluid/lite/core/cpu_info.h" #include "paddle/fluid/lite/core/lite_tensor.h" #include "paddle/fluid/lite/core/target_wrapper.h" +#include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { -struct HostContext {}; +template +class Context; + +using HostContext = Context; +using X86Context = Context; +using CUDAContext = Context; +using ARMContext = Context; + +template <> +class Context { + public: + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + + void CopyShared(const HostContext* ctx) {} + + std::string name() const { return "HostContext"; } +}; #ifdef LITE_WITH_ARM -struct ARMContext { +template <> +class Context { public: - ARMContext(); - ARMContext(PowerMode mode, int threads); - ARMContext(const ARMContext& ctx); + Context(); + Context(PowerMode mode, int threads); + explicit Context(const ARMContext& ctx); ARMContext& operator=(const ARMContext& ctx); + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() { DeviceInfo::Init(); } + + void CopyShared(const ARMContext* ctx) {} + void SetRunMode(PowerMode mode, int threads); void SetCache(int l1size, int l2size, int l3size); void SetArch(ARMArch arch); @@ -64,6 +91,8 @@ struct ARMContext { int l3_cache_size() const; bool ExtendWorkspace(DDimLite dims); + std::string name() const { return "ARMContext"; } + private: // LITE_POWER_HIGH stands for using big cores, // LITE_POWER_LOW stands for using small core, @@ -78,33 +107,99 @@ struct ARMContext { #ifdef LITE_WITH_CUDA // Only works with CUDA kernels. -struct CUDAContext { +template <> +class Context { + public: + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() { + cublas_fp32_ = std::make_shared>(); + } + + void CopyShared(const CUDAContext* ctx) { + CHECK(ctx); + CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; + ctx->cublas_fp32_ = cublas_fp32_; + } + + const cudaStream_t exec_stream() { return exec_stream_; } + void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; } + + const cudaStream_t io_stream() { return io_stream_; } + void SetIoStream(cudaStream_t stream) { io_stream_ = stream; } + + std::shared_ptr> cublas_fp32() { return cublas_fp32_; } + void SetCuBlasFP32(std::shared_ptr> cublas_fp32) { + cublas_fp32_ = cublas_fp32; + } + + const std::vector& input_events() { return input_events_; } + void SetInputEvents(const std::vector& input_events) { + input_events_.clear(); + input_events_.assign(input_events.begin(), input_events.end()); + } + + const std::vector& output_events() { return output_events_; } + void SetOutputEvents(const std::vector& output_events) { + output_events_.clear(); + output_events_.assign(output_events.begin(), output_events.end()); + } + + std::string name() const { return "CUDAContext"; } + + private: // overall information - cudaStream_t exec_stream; - cudaStream_t io_stream; + cudaStream_t exec_stream_; + cudaStream_t io_stream_; // not thread-safe, should allocate for each thread. - std::shared_ptr> blas_fp32; + std::shared_ptr> cublas_fp32_; // kernel information - std::vector input_events; - std::vector output_events; + std::vector input_events_; + std::vector output_events_; }; #endif #ifdef LITE_WITH_X86 -struct X86Context { - // overall information - X86Context() { - x86_device_context.reset(new ::paddle::platform::CPUDeviceContext); - x86_execution_context.reset( - new ::paddle::framework::ExecutionContext(*x86_device_context)); +template <> +class Context { + public: + using device_ctx_t = ::paddle::platform::CPUDeviceContext; + using execution_ctx_t = ::paddle::framework::ExecutionContext; + + Context() { + x86_device_context_.reset(new ::paddle::platform::CPUDeviceContext); + x86_execution_context_.reset( + new ::paddle::framework::ExecutionContext(*x86_device_context_)); + } + + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + + void CopyShared(const X86Context* ctx) {} + + const device_ctx_t* x86_device_context() { return x86_device_context_.get(); } + void SetX86DeviceContext(std::unique_ptr&& ctx) { + x86_device_context_ = std::move(ctx); + } + + const execution_ctx_t* x86_execution_context() { + return x86_execution_context_.get(); + } + void SetX86ExecutionContext(std::unique_ptr&& ctx) { + x86_execution_context_ = std::move(ctx); } + + std::string name() const { return "X86Context"; } + + private: + // overall information + // // kernel information // legacy info. - std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context; - std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context; + std::unique_ptr x86_device_context_; + std::unique_ptr x86_execution_context_; }; #endif @@ -124,5 +219,67 @@ class KernelContext { Any ctx_; }; +// The ContextScheduler helps to assign different context for each kernel. +class ContextScheduler { + public: + static ContextScheduler& Global() { + static auto* x = new ContextScheduler; + return *x; + } + + std::unique_ptr NewContext(TargetType target) { + std::unique_ptr ctx(new KernelContext); + switch (target) { + case TARGET(kHost): + kernel_contexts_[TargetType::kHost].As().CopyShared( + &ctx->As()); + break; +#ifdef LITE_WITH_X86 + case TARGET(kX86): + kernel_contexts_[TargetType::kX86].As().CopyShared( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_CUDA + case TARGET(kCUDA): + kernel_contexts_[TargetType::kCUDA].As().CopyShared( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_ARM + case TARGET(kARM): + kernel_contexts_[TargetType::kARM].As().CopyShared( + &ctx->As()); + break; +#endif + default: + LOG(FATAL) << "unsupported target " << TargetToStr(target); + } + return ctx; + } + + private: + template + void InitContext() { + kernel_contexts_[Type].As().InitOnce(); + } + + ContextScheduler() { + InitContext(); +#ifdef LITE_WITH_X86 + InitContext(); +#endif +#ifdef LITE_WITH_CUDA + InitContext(); +#endif +#ifdef LITE_WITH_ARM + InitContext(); +#endif + } + + private: + std::map kernel_contexts_; +}; + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/context_test.cc b/paddle/fluid/lite/core/context_test.cc new file mode 100644 index 00000000000..0952aec33f3 --- /dev/null +++ b/paddle/fluid/lite/core/context_test.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/context.h" +#include + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_X86 +TEST(ContextScheduler, NewContext) { + auto ctx1_p = ContextScheduler::Global().NewContext(TargetType::kX86); + auto ctx2_p = ContextScheduler::Global().NewContext(TargetType::kX86); + ASSERT_FALSE(ctx1_p.get() == ctx2_p.get()); + + auto& ctx1 = ctx1_p->As(); + auto& ctx2 = ctx2_p->As(); + + ASSERT_EQ(ctx1.name(), "X86Context"); + ASSERT_EQ(ctx2.name(), "X86Context"); + + ASSERT_FALSE(ctx1.x86_device_context() == nullptr || + ctx2.x86_device_context() == nullptr); + ASSERT_FALSE(ctx1.x86_execution_context() == nullptr || + ctx2.x86_execution_context() == nullptr); + + ASSERT_TRUE(ctx1.x86_device_context() != ctx2.x86_device_context()); + ASSERT_TRUE(ctx1.x86_execution_context() != ctx2.x86_execution_context()); + + using device_ctx_t = ::paddle::platform::CPUDeviceContext; + using exec_ctx_t = ::paddle::framework::ExecutionContext; + auto* device_ctx = new device_ctx_t; + ctx1.SetX86DeviceContext(std::unique_ptr(device_ctx)); + ctx1.SetX86ExecutionContext( + std::unique_ptr(new exec_ctx_t(*device_ctx))); +} +#endif + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc index f7c983b675f..257766945af 100644 --- a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc +++ b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc @@ -21,85 +21,16 @@ namespace mir { class RuntimeContextAssignPass : public StmtPass { public: - RuntimeContextAssignPass() { -#ifdef LITE_WITH_CUDA - InitCudaBlas(); -#endif - } + RuntimeContextAssignPass() {} void Apply(const std::unique_ptr& graph) override { for (auto& node : graph->mutable_nodes()) { if (!node.IsStmt()) continue; - auto& inst = node.AsStmt(); - switch (inst.picked_kernel().target()) { - case TARGET(kHost): - inst.picked_kernel().SetContext(NewHostContext()); - break; -#ifdef LITE_WITH_X86 - case TARGET(kX86): - inst.picked_kernel().SetContext(NewX86Context()); - break; -#endif -#ifdef LITE_WITH_CUDA - case TARGET(kCUDA): - inst.picked_kernel().SetContext(NewCudaContext()); - break; -#endif -#ifdef LITE_WITH_ARM - case TARGET(kARM): - inst.picked_kernel().SetContext(NewARMContext()); - break; -#endif - default: - LOG(FATAL) << "unsupported target " - << TargetToStr(inst.picked_kernel().target()); - } + inst.picked_kernel().SetContext( + ContextScheduler::Global().NewContext(inst.picked_kernel().target())); } } - - std::unique_ptr NewHostContext() { - std::unique_ptr ctx(new KernelContext); - ctx->As(); - // Some initialization here. - - return ctx; - } -#ifdef LITE_WITH_X86 - std::unique_ptr NewX86Context() { - std::unique_ptr ctx(new KernelContext); - ctx->As(); - return ctx; - } -#endif - -#ifdef LITE_WITH_ARM - std::unique_ptr NewARMContext() { - DeviceInfo::Init(); - std::unique_ptr ctx(new KernelContext); - ctx->As(); - return ctx; - } -#endif -#ifdef LITE_WITH_CUDA - std::unique_ptr NewCudaContext() { - std::unique_ptr ctx(new KernelContext); - auto& cuda = ctx->As(); - // Some initialization here. - CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; - cuda.blas_fp32 = cublas_fp32_; - return ctx; - } - - void InitCudaBlas() { - cublas_fp32_ = std::make_shared>(); - } -#endif - - private: -#ifdef LITE_WITH_CUDA - std::shared_ptr> cublas_fp32_; -#endif }; } // namespace mir diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h index 5eb30bf8dfd..43ad6ba5f96 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.h +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h @@ -37,9 +37,9 @@ class MulCompute : public KernelLite { void Run() override { CHECK(ctx_) << "running context should be set first"; auto& context = ctx_->As(); - CHECK(context.blas_fp32) << "blas should init first"; + CHECK(context.cublas_fp32()) << "blas should init first"; /* - auto& blas = *context.blas_fp32; + auto& blas = *context.cublas_fp32(); CHECK(param.x->target() == TARGET(kCUDA)); auto* x = param.x->data(); int x_h = param.x->dims()[0]; diff --git a/paddle/fluid/lite/kernels/x86/activation_compute.cc b/paddle/fluid/lite/kernels/x86/activation_compute.cc index 4ea1c0f6504..a07a69af2d1 100644 --- a/paddle/fluid/lite/kernels/x86/activation_compute.cc +++ b/paddle/fluid/lite/kernels/x86/activation_compute.cc @@ -62,10 +62,10 @@ class SquareCompute : public KernelLite { void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); - Activate>(*context.x86_device_context, + Activate>(*context.x86_device_context(), ¶m.X->raw_tensor(), ¶m.Out->raw_tensor()); } @@ -81,11 +81,11 @@ class SquareGradCompute : public KernelLite { void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.X_grad->template mutable_data(); ActivateGrad>( - *context.x86_device_context, ¶m.X->raw_tensor(), + *context.x86_device_context(), ¶m.X->raw_tensor(), ¶m.Out->raw_tensor(), ¶m.Out_grad->raw_tensor(), ¶m.X_grad->raw_tensor()); } diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc index b1326ee730f..8e2ea92d6de 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc @@ -44,12 +44,12 @@ class ElementwiseSubCompute void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); paddle::operators::ElementwiseComputeEx, platform::CPUDeviceContext, T>( - *context.x86_execution_context, ¶m.X->raw_tensor(), + *context.x86_execution_context(), ¶m.X->raw_tensor(), ¶m.Y->raw_tensor(), param.axis, SubFunctor(), ¶m.Out->raw_tensor()); } @@ -75,7 +75,7 @@ class ElementwiseSubGradCompute void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.X_grad->template mutable_data(); param.Y_grad->template mutable_data(); @@ -86,8 +86,8 @@ class ElementwiseSubGradCompute auto& skip = dout; paddle::operators::ElemwiseExplicitGradCompute< platform::CPUDeviceContext, T, SubGradDX, SubGradDY>( - *context.x86_execution_context, skip, skip, skip, dout, param.axis, &dx, - &dy, SubGradDX(), SubGradDY()); + *context.x86_execution_context(), skip, skip, skip, dout, param.axis, + &dx, &dy, SubGradDX(), SubGradDY()); } virtual ~ElementwiseSubGradCompute() = default; @@ -101,11 +101,11 @@ class ElementwiseAddCompute void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); paddle::operators::ElementwiseComputeEx, platform::CPUDeviceContext, T>( - *context.x86_execution_context, ¶m.X->raw_tensor(), + *context.x86_execution_context(), ¶m.X->raw_tensor(), ¶m.Y->raw_tensor(), param.axis, AddFunctor(), ¶m.Out->raw_tensor()); } diff --git a/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc b/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc index d0b03c78ee0..5a5a719af3b 100644 --- a/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc +++ b/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc @@ -32,12 +32,12 @@ class FillConstantCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); paddle::operators::math::set_constant( - *context.x86_device_context, ¶m.Out->raw_tensor(), param.value); + *context.x86_device_context(), ¶m.Out->raw_tensor(), param.value); } virtual ~FillConstantCompute() = default; diff --git a/paddle/fluid/lite/kernels/x86/mean_compute.cc b/paddle/fluid/lite/kernels/x86/mean_compute.cc index 95cb0c89e03..ac1a37707ad 100644 --- a/paddle/fluid/lite/kernels/x86/mean_compute.cc +++ b/paddle/fluid/lite/kernels/x86/mean_compute.cc @@ -38,13 +38,13 @@ class MeanCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); auto X = EigenVector::Flatten(param.X->raw_tensor()); auto y = EigenScalar::From(param.Out->raw_tensor()); - const auto& place = *(context.x86_device_context->eigen_device()); + const auto& place = *(context.x86_device_context()->eigen_device()); y.device(place) = X.mean(); } @@ -61,13 +61,13 @@ class MeanGradCompute : public KernelLite { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.X_grad->template mutable_data(); T x_grad_size = static_cast(param.X_grad->raw_tensor().numel()); Eigen::DSizes bcast(static_cast(x_grad_size)); EigenVector::Flatten(param.X_grad->raw_tensor()) - .device(*(context.x86_device_context->eigen_device())) = + .device(*(context.x86_device_context()->eigen_device())) = (EigenVector::From(param.Out_grad->raw_tensor()) / x_grad_size) .broadcast(bcast); } diff --git a/paddle/fluid/lite/kernels/x86/mul_compute.cc b/paddle/fluid/lite/kernels/x86/mul_compute.cc index a099a2fdf13..ad009893c8a 100644 --- a/paddle/fluid/lite/kernels/x86/mul_compute.cc +++ b/paddle/fluid/lite/kernels/x86/mul_compute.cc @@ -32,7 +32,7 @@ class MulCompute : public KernelLite { void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.output->template mutable_data(); @@ -53,7 +53,7 @@ class MulCompute : public KernelLite { } auto blas = paddle::operators::math::GetBlas( - *context.x86_device_context); + *context.x86_device_context()); blas.MatMul(x_matrix, y_matrix, z); if (z_dim.size() != 2) { @@ -70,7 +70,7 @@ class MulGradCompute : public KernelLite { void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); @@ -99,7 +99,7 @@ class MulGradCompute : public KernelLite { } auto blas = paddle::operators::math::GetBlas( - *context.x86_device_context); + *context.x86_device_context()); if (dx) { // dx->mutable_data(context.x86_device_context->GetPlace()); param.x_grad->template mutable_data(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 86bc014821c..d17ff90ecf0 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -32,5 +32,8 @@ set(ops_lite dropout_op_lite PARENT_SCOPE) -lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite X86_DEPS fc_compute_x86) +lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc + DEPS fc_op_lite memory_lite + X86_DEPS fc_compute_x86 + ARM_DEPS fc_compute_arm) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/fc_op_test.cc b/paddle/fluid/lite/operators/fc_op_test.cc index 9d8c4be0438..9ef91dbc147 100644 --- a/paddle/fluid/lite/operators/fc_op_test.cc +++ b/paddle/fluid/lite/operators/fc_op_test.cc @@ -20,7 +20,7 @@ namespace paddle { namespace lite { namespace operators { -TEST(fc_op_lite, test) { +TEST(fc_op_lite, TestX86) { // prepare variables Scope scope; auto* x = scope.Var("x")->GetMutable(); @@ -57,9 +57,11 @@ TEST(fc_op_lite, test) { FcOpLite fc("fc"); - fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}}); + fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); fc.Attach(desc, &scope); - auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)}}); + auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); ASSERT_FALSE(kernels.empty()); } -- GitLab