提交 44dbc0bc 编写于 作者: S sangoly 提交者: Yan Chunwei

refactor context to support both server and light (#17762)

上级 2cf061f2
...@@ -25,13 +25,6 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc ...@@ -25,13 +25,6 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.") "A path setting inference demo download directories.")
# lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
# DEPS cxx_api_lite model_parser_lite target_wrapper_host
# PROFILE_DEPS basic_profiler_lite
# ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
# --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host DEPS cxx_api_lite model_parser_lite target_wrapper_host
...@@ -42,7 +35,15 @@ if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) ...@@ -42,7 +35,15 @@ if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif() endif(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif(WITH_TESTING)
# if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
# endif()
lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h" #include "paddle/fluid/lite/model_parser/model_parser.h"
...@@ -84,6 +85,7 @@ class LightPredictor { ...@@ -84,6 +85,7 @@ class LightPredictor {
return it->alias() == alias; return it->alias() == alias;
}); });
CHECK(it != kernels.end()); CHECK(it != kernels.end());
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
insts.emplace_back(op, std::move(*it)); insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); program_.reset(new RuntimeProgram(std::move(insts)));
......
...@@ -44,3 +44,18 @@ USE_LITE_OP(scale); ...@@ -44,3 +44,18 @@ USE_LITE_OP(scale);
USE_LITE_OP(feed); USE_LITE_OP(feed);
USE_LITE_OP(fetch); USE_LITE_OP(fetch);
USE_LITE_OP(io_copy); USE_LITE_OP(io_copy);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
#endif
...@@ -54,3 +54,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li ...@@ -54,3 +54,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
#lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite) #lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite)
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
...@@ -33,7 +33,7 @@ namespace lite { ...@@ -33,7 +33,7 @@ namespace lite {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
void ARMContext::SetCache(int l1size, int l2size, int l3size) { void Context<TargetType::kARM>::SetCache(int l1size, int l2size, int l3size) {
DeviceInfo& dev = DeviceInfo::Global(); DeviceInfo& dev = DeviceInfo::Global();
int cpu_count = arm_get_cpucount(); int cpu_count = arm_get_cpucount();
dev.L1_cache_.resize(cpu_count); dev.L1_cache_.resize(cpu_count);
...@@ -47,7 +47,7 @@ void ARMContext::SetCache(int l1size, int l2size, int l3size) { ...@@ -47,7 +47,7 @@ void ARMContext::SetCache(int l1size, int l2size, int l3size) {
workspace_.Resize({2 * (l1size + l2size)}); workspace_.Resize({2 * (l1size + l2size)});
} }
ARMContext::ARMContext() { Context<TargetType::kARM>::Context() {
active_ids_ = {0}; active_ids_ = {0};
mode_ = LITE_POWER_HIGH; mode_ = LITE_POWER_HIGH;
DeviceInfo& dev = DeviceInfo::Global(); DeviceInfo& dev = DeviceInfo::Global();
...@@ -62,11 +62,11 @@ ARMContext::ARMContext() { ...@@ -62,11 +62,11 @@ ARMContext::ARMContext() {
#endif #endif
} }
PowerMode ARMContext::mode() const { return mode_; } PowerMode Context<TargetType::kARM>::mode() const { return mode_; }
int ARMContext::threads() const { return active_ids_.size(); } int Context<TargetType::kARM>::threads() const { return active_ids_.size(); }
ARMContext::ARMContext(const ARMContext& ctx) { Context<TargetType::kARM>::Context(const ARMContext& ctx) {
mode_ = ctx.mode_; mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_; active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_; workspace_ = ctx.workspace_;
...@@ -74,7 +74,7 @@ ARMContext::ARMContext(const ARMContext& ctx) { ...@@ -74,7 +74,7 @@ ARMContext::ARMContext(const ARMContext& ctx) {
count_ = ctx.count_; count_ = ctx.count_;
} }
ARMContext& ARMContext::operator=(const ARMContext& ctx) { ARMContext& Context<TargetType::kARM>::operator=(const ARMContext& ctx) {
mode_ = ctx.mode_; mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_; active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_; workspace_ = ctx.workspace_;
...@@ -83,7 +83,7 @@ ARMContext& ARMContext::operator=(const ARMContext& ctx) { ...@@ -83,7 +83,7 @@ ARMContext& ARMContext::operator=(const ARMContext& ctx) {
return *this; return *this;
} }
void ARMContext::BindDev() { void Context<TargetType::kARM>::BindDev() {
#ifdef USE_OPENMP #ifdef USE_OPENMP
int num_threads = active_ids_.size(); int num_threads = active_ids_.size();
omp_set_num_threads(num_threads); omp_set_num_threads(num_threads);
...@@ -116,7 +116,7 @@ void ARMContext::BindDev() { ...@@ -116,7 +116,7 @@ void ARMContext::BindDev() {
#endif // USE_OPENMP #endif // USE_OPENMP
} }
void ARMContext::SetRunMode(PowerMode mode, int threads) { void Context<TargetType::kARM>::SetRunMode(PowerMode mode, int threads) {
DeviceInfo& dev = DeviceInfo::Global(); DeviceInfo& dev = DeviceInfo::Global();
int big_core_size = dev.big_core_ids_.size(); int big_core_size = dev.big_core_ids_.size();
int small_core_size = dev.little_core_ids_.size(); int small_core_size = dev.little_core_ids_.size();
...@@ -293,26 +293,26 @@ void ARMContext::SetRunMode(PowerMode mode, int threads) { ...@@ -293,26 +293,26 @@ void ARMContext::SetRunMode(PowerMode mode, int threads) {
arch_ = DeviceInfo::Global().archs_[active_ids_[0]]; arch_ = DeviceInfo::Global().archs_[active_ids_[0]];
} }
ARMArch ARMContext::arch() const { return arch_; } ARMArch Context<TargetType::kARM>::arch() const { return arch_; }
void ARMContext::SetArch(ARMArch arch) { arch_ = arch; } void Context<TargetType::kARM>::SetArch(ARMArch arch) { arch_ = arch; }
int ARMContext::l1_cache_size() const { int Context<TargetType::kARM>::l1_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global(); DeviceInfo& dev = DeviceInfo::Global();
return dev.L1_cache_[active_ids_[0]]; return dev.L1_cache_[active_ids_[0]];
} }
int ARMContext::l2_cache_size() const { int Context<TargetType::kARM>::l2_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global(); DeviceInfo& dev = DeviceInfo::Global();
return dev.L2_cache_[active_ids_[0]]; return dev.L2_cache_[active_ids_[0]];
} }
int ARMContext::l3_cache_size() const { int Context<TargetType::kARM>::l3_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global(); DeviceInfo& dev = DeviceInfo::Global();
return dev.L3_cache_[active_ids_[0]]; return dev.L3_cache_[active_ids_[0]];
} }
bool ARMContext::ExtendWorkspace(DDimLite dims) { bool Context<TargetType::kARM>::ExtendWorkspace(DDimLite dims) {
auto count = dims.product(); auto count = dims.product();
auto old = workspace_.dims(); auto old = workspace_.dims();
if (count == old.product()) { if (count == old.product()) {
...@@ -324,5 +324,6 @@ bool ARMContext::ExtendWorkspace(DDimLite dims) { ...@@ -324,5 +324,6 @@ bool ARMContext::ExtendWorkspace(DDimLite dims) {
return true; return true;
} }
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -23,28 +23,55 @@ ...@@ -23,28 +23,55 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#endif #endif
#include <map>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/cpu_info.h" #include "paddle/fluid/lite/core/cpu_info.h"
#include "paddle/fluid/lite/core/lite_tensor.h" #include "paddle/fluid/lite/core/lite_tensor.h"
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
struct HostContext {}; template <TargetType Type>
class Context;
using HostContext = Context<TargetType::kHost>;
using X86Context = Context<TargetType::kX86>;
using CUDAContext = Context<TargetType::kCUDA>;
using ARMContext = Context<TargetType::kARM>;
template <>
class Context<TargetType::kHost> {
public:
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopyShared(const HostContext* ctx) {}
std::string name() const { return "HostContext"; }
};
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
struct ARMContext { template <>
class Context<TargetType::kARM> {
public: public:
ARMContext(); Context();
ARMContext(PowerMode mode, int threads); Context(PowerMode mode, int threads);
ARMContext(const ARMContext& ctx); explicit Context(const ARMContext& ctx);
ARMContext& operator=(const ARMContext& ctx); ARMContext& operator=(const ARMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { DeviceInfo::Init(); }
void CopyShared(const ARMContext* ctx) {}
void SetRunMode(PowerMode mode, int threads); void SetRunMode(PowerMode mode, int threads);
void SetCache(int l1size, int l2size, int l3size); void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch); void SetArch(ARMArch arch);
...@@ -64,6 +91,8 @@ struct ARMContext { ...@@ -64,6 +91,8 @@ struct ARMContext {
int l3_cache_size() const; int l3_cache_size() const;
bool ExtendWorkspace(DDimLite dims); bool ExtendWorkspace(DDimLite dims);
std::string name() const { return "ARMContext"; }
private: private:
// LITE_POWER_HIGH stands for using big cores, // LITE_POWER_HIGH stands for using big cores,
// LITE_POWER_LOW stands for using small core, // LITE_POWER_LOW stands for using small core,
...@@ -78,33 +107,99 @@ struct ARMContext { ...@@ -78,33 +107,99 @@ struct ARMContext {
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// Only works with CUDA kernels. // Only works with CUDA kernels.
struct CUDAContext { template <>
class Context<TargetType::kCUDA> {
public:
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
void CopyShared(const CUDAContext* ctx) {
CHECK(ctx);
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
ctx->cublas_fp32_ = cublas_fp32_;
}
const cudaStream_t exec_stream() { return exec_stream_; }
void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
const cudaStream_t io_stream() { return io_stream_; }
void SetIoStream(cudaStream_t stream) { io_stream_ = stream; }
std::shared_ptr<cuda::Blas<float>> cublas_fp32() { return cublas_fp32_; }
void SetCuBlasFP32(std::shared_ptr<cuda::Blas<float>> cublas_fp32) {
cublas_fp32_ = cublas_fp32;
}
const std::vector<cudaEvent_t>& input_events() { return input_events_; }
void SetInputEvents(const std::vector<cudaEvent_t>& input_events) {
input_events_.clear();
input_events_.assign(input_events.begin(), input_events.end());
}
const std::vector<cudaEvent_t>& output_events() { return output_events_; }
void SetOutputEvents(const std::vector<cudaEvent_t>& output_events) {
output_events_.clear();
output_events_.assign(output_events.begin(), output_events.end());
}
std::string name() const { return "CUDAContext"; }
private:
// overall information // overall information
cudaStream_t exec_stream; cudaStream_t exec_stream_;
cudaStream_t io_stream; cudaStream_t io_stream_;
// not thread-safe, should allocate for each thread. // not thread-safe, should allocate for each thread.
std::shared_ptr<cuda::Blas<float>> blas_fp32; std::shared_ptr<cuda::Blas<float>> cublas_fp32_;
// kernel information // kernel information
std::vector<cudaEvent_t> input_events; std::vector<cudaEvent_t> input_events_;
std::vector<cudaEvent_t> output_events; std::vector<cudaEvent_t> output_events_;
}; };
#endif #endif
#ifdef LITE_WITH_X86 #ifdef LITE_WITH_X86
struct X86Context { template <>
// overall information class Context<TargetType::kX86> {
X86Context() { public:
x86_device_context.reset(new ::paddle::platform::CPUDeviceContext); using device_ctx_t = ::paddle::platform::CPUDeviceContext;
x86_execution_context.reset( using execution_ctx_t = ::paddle::framework::ExecutionContext;
new ::paddle::framework::ExecutionContext(*x86_device_context));
Context() {
x86_device_context_.reset(new ::paddle::platform::CPUDeviceContext);
x86_execution_context_.reset(
new ::paddle::framework::ExecutionContext(*x86_device_context_));
}
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopyShared(const X86Context* ctx) {}
const device_ctx_t* x86_device_context() { return x86_device_context_.get(); }
void SetX86DeviceContext(std::unique_ptr<device_ctx_t>&& ctx) {
x86_device_context_ = std::move(ctx);
}
const execution_ctx_t* x86_execution_context() {
return x86_execution_context_.get();
}
void SetX86ExecutionContext(std::unique_ptr<execution_ctx_t>&& ctx) {
x86_execution_context_ = std::move(ctx);
} }
std::string name() const { return "X86Context"; }
private:
// overall information
//
// kernel information // kernel information
// legacy info. // legacy info.
std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context; std::unique_ptr<device_ctx_t> x86_device_context_;
std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context; std::unique_ptr<execution_ctx_t> x86_execution_context_;
}; };
#endif #endif
...@@ -124,5 +219,67 @@ class KernelContext { ...@@ -124,5 +219,67 @@ class KernelContext {
Any ctx_; Any ctx_;
}; };
// The ContextScheduler helps to assign different context for each kernel.
class ContextScheduler {
public:
static ContextScheduler& Global() {
static auto* x = new ContextScheduler;
return *x;
}
std::unique_ptr<KernelContext> NewContext(TargetType target) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
switch (target) {
case TARGET(kHost):
kernel_contexts_[TargetType::kHost].As<HostContext>().CopyShared(
&ctx->As<HostContext>());
break;
#ifdef LITE_WITH_X86
case TARGET(kX86):
kernel_contexts_[TargetType::kX86].As<X86Context>().CopyShared(
&ctx->As<X86Context>());
break;
#endif
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopyShared(
&ctx->As<CUDAContext>());
break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
kernel_contexts_[TargetType::kARM].As<ARMContext>().CopyShared(
&ctx->As<ARMContext>());
break;
#endif
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
}
return ctx;
}
private:
template <TargetType Type, typename ContextT>
void InitContext() {
kernel_contexts_[Type].As<ContextT>().InitOnce();
}
ContextScheduler() {
InitContext<TargetType::kHost, HostContext>();
#ifdef LITE_WITH_X86
InitContext<TargetType::kX86, X86Context>();
#endif
#ifdef LITE_WITH_CUDA
InitContext<TargetType::kCUDA, CUDAContext>();
#endif
#ifdef LITE_WITH_ARM
InitContext<TargetType::kARM, ARMContext>();
#endif
}
private:
std::map<TargetType, KernelContext> kernel_contexts_;
};
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/context.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
#ifdef LITE_WITH_X86
TEST(ContextScheduler, NewContext) {
auto ctx1_p = ContextScheduler::Global().NewContext(TargetType::kX86);
auto ctx2_p = ContextScheduler::Global().NewContext(TargetType::kX86);
ASSERT_FALSE(ctx1_p.get() == ctx2_p.get());
auto& ctx1 = ctx1_p->As<X86Context>();
auto& ctx2 = ctx2_p->As<X86Context>();
ASSERT_EQ(ctx1.name(), "X86Context");
ASSERT_EQ(ctx2.name(), "X86Context");
ASSERT_FALSE(ctx1.x86_device_context() == nullptr ||
ctx2.x86_device_context() == nullptr);
ASSERT_FALSE(ctx1.x86_execution_context() == nullptr ||
ctx2.x86_execution_context() == nullptr);
ASSERT_TRUE(ctx1.x86_device_context() != ctx2.x86_device_context());
ASSERT_TRUE(ctx1.x86_execution_context() != ctx2.x86_execution_context());
using device_ctx_t = ::paddle::platform::CPUDeviceContext;
using exec_ctx_t = ::paddle::framework::ExecutionContext;
auto* device_ctx = new device_ctx_t;
ctx1.SetX86DeviceContext(std::unique_ptr<device_ctx_t>(device_ctx));
ctx1.SetX86ExecutionContext(
std::unique_ptr<exec_ctx_t>(new exec_ctx_t(*device_ctx)));
}
#endif
} // namespace lite
} // namespace paddle
...@@ -21,85 +21,16 @@ namespace mir { ...@@ -21,85 +21,16 @@ namespace mir {
class RuntimeContextAssignPass : public StmtPass { class RuntimeContextAssignPass : public StmtPass {
public: public:
RuntimeContextAssignPass() { RuntimeContextAssignPass() {}
#ifdef LITE_WITH_CUDA
InitCudaBlas();
#endif
}
void Apply(const std::unique_ptr<SSAGraph>& graph) override { void Apply(const std::unique_ptr<SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue; if (!node.IsStmt()) continue;
auto& inst = node.AsStmt(); auto& inst = node.AsStmt();
switch (inst.picked_kernel().target()) { inst.picked_kernel().SetContext(
case TARGET(kHost): ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
inst.picked_kernel().SetContext(NewHostContext());
break;
#ifdef LITE_WITH_X86
case TARGET(kX86):
inst.picked_kernel().SetContext(NewX86Context());
break;
#endif
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
inst.picked_kernel().SetContext(NewCudaContext());
break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
inst.picked_kernel().SetContext(NewARMContext());
break;
#endif
default:
LOG(FATAL) << "unsupported target "
<< TargetToStr(inst.picked_kernel().target());
}
}
}
std::unique_ptr<KernelContext> NewHostContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<HostContext>();
// Some initialization here.
return ctx;
} }
#ifdef LITE_WITH_X86
std::unique_ptr<KernelContext> NewX86Context() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
return ctx;
} }
#endif
#ifdef LITE_WITH_ARM
std::unique_ptr<KernelContext> NewARMContext() {
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
return ctx;
}
#endif
#ifdef LITE_WITH_CUDA
std::unique_ptr<KernelContext> NewCudaContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& cuda = ctx->As<CUDAContext>();
// Some initialization here.
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
cuda.blas_fp32 = cublas_fp32_;
return ctx;
}
void InitCudaBlas() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
#endif
private:
#ifdef LITE_WITH_CUDA
std::shared_ptr<lite::cuda::Blas<float>> cublas_fp32_;
#endif
}; };
} // namespace mir } // namespace mir
......
...@@ -37,9 +37,9 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { ...@@ -37,9 +37,9 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
void Run() override { void Run() override {
CHECK(ctx_) << "running context should be set first"; CHECK(ctx_) << "running context should be set first";
auto& context = ctx_->As<CUDAContext>(); auto& context = ctx_->As<CUDAContext>();
CHECK(context.blas_fp32) << "blas should init first"; CHECK(context.cublas_fp32()) << "blas should init first";
/* /*
auto& blas = *context.blas_fp32; auto& blas = *context.cublas_fp32();
CHECK(param.x->target() == TARGET(kCUDA)); CHECK(param.x->target() == TARGET(kCUDA));
auto* x = param.x->data<float>(); auto* x = param.x->data<float>();
int x_h = param.x->dims()[0]; int x_h = param.x->dims()[0];
......
...@@ -62,10 +62,10 @@ class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -62,10 +62,10 @@ class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationParam>(); auto& param = *param_.get_mutable<operators::ActivationParam>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.Out->template mutable_data<T>(); param.Out->template mutable_data<T>();
Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context, Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context(),
&param.X->raw_tensor(), &param.X->raw_tensor(),
&param.Out->raw_tensor()); &param.Out->raw_tensor());
} }
...@@ -81,11 +81,11 @@ class SquareGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -81,11 +81,11 @@ class SquareGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationGradParam>(); auto& param = *param_.get_mutable<operators::ActivationGradParam>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>(); param.X_grad->template mutable_data<T>();
ActivateGrad<paddle::operators::SquareGradFunctor<T>>( ActivateGrad<paddle::operators::SquareGradFunctor<T>>(
*context.x86_device_context, &param.X->raw_tensor(), *context.x86_device_context(), &param.X->raw_tensor(),
&param.Out->raw_tensor(), &param.Out_grad->raw_tensor(), &param.Out->raw_tensor(), &param.Out_grad->raw_tensor(),
&param.X_grad->raw_tensor()); &param.X_grad->raw_tensor());
} }
......
...@@ -44,12 +44,12 @@ class ElementwiseSubCompute ...@@ -44,12 +44,12 @@ class ElementwiseSubCompute
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.Out->template mutable_data<T>(); param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<SubFunctor<T>, paddle::operators::ElementwiseComputeEx<SubFunctor<T>,
platform::CPUDeviceContext, T>( platform::CPUDeviceContext, T>(
*context.x86_execution_context, &param.X->raw_tensor(), *context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, SubFunctor<T>(), &param.Y->raw_tensor(), param.axis, SubFunctor<T>(),
&param.Out->raw_tensor()); &param.Out->raw_tensor());
} }
...@@ -75,7 +75,7 @@ class ElementwiseSubGradCompute ...@@ -75,7 +75,7 @@ class ElementwiseSubGradCompute
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>(); param.X_grad->template mutable_data<T>();
param.Y_grad->template mutable_data<T>(); param.Y_grad->template mutable_data<T>();
...@@ -86,8 +86,8 @@ class ElementwiseSubGradCompute ...@@ -86,8 +86,8 @@ class ElementwiseSubGradCompute
auto& skip = dout; auto& skip = dout;
paddle::operators::ElemwiseExplicitGradCompute< paddle::operators::ElemwiseExplicitGradCompute<
platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>( platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
*context.x86_execution_context, skip, skip, skip, dout, param.axis, &dx, *context.x86_execution_context(), skip, skip, skip, dout, param.axis,
&dy, SubGradDX<T>(), SubGradDY<T>()); &dx, &dy, SubGradDX<T>(), SubGradDY<T>());
} }
virtual ~ElementwiseSubGradCompute() = default; virtual ~ElementwiseSubGradCompute() = default;
...@@ -101,11 +101,11 @@ class ElementwiseAddCompute ...@@ -101,11 +101,11 @@ class ElementwiseAddCompute
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.Out->template mutable_data<T>(); param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<AddFunctor<T>, paddle::operators::ElementwiseComputeEx<AddFunctor<T>,
platform::CPUDeviceContext, T>( platform::CPUDeviceContext, T>(
*context.x86_execution_context, &param.X->raw_tensor(), *context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, AddFunctor<T>(), &param.Y->raw_tensor(), param.axis, AddFunctor<T>(),
&param.Out->raw_tensor()); &param.Out->raw_tensor());
} }
......
...@@ -32,12 +32,12 @@ class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -32,12 +32,12 @@ class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.Out->template mutable_data<T>(); param.Out->template mutable_data<T>();
paddle::operators::math::set_constant( paddle::operators::math::set_constant(
*context.x86_device_context, &param.Out->raw_tensor(), param.value); *context.x86_device_context(), &param.Out->raw_tensor(), param.value);
} }
virtual ~FillConstantCompute() = default; virtual ~FillConstantCompute() = default;
......
...@@ -38,13 +38,13 @@ class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -38,13 +38,13 @@ class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.Out->template mutable_data<T>(); param.Out->template mutable_data<T>();
auto X = EigenVector<T>::Flatten(param.X->raw_tensor()); auto X = EigenVector<T>::Flatten(param.X->raw_tensor());
auto y = EigenScalar<T>::From(param.Out->raw_tensor()); auto y = EigenScalar<T>::From(param.Out->raw_tensor());
const auto& place = *(context.x86_device_context->eigen_device()); const auto& place = *(context.x86_device_context()->eigen_device());
y.device(place) = X.mean(); y.device(place) = X.mean();
} }
...@@ -61,13 +61,13 @@ class MeanGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -61,13 +61,13 @@ class MeanGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1); CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1);
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>(); param.X_grad->template mutable_data<T>();
T x_grad_size = static_cast<T>(param.X_grad->raw_tensor().numel()); T x_grad_size = static_cast<T>(param.X_grad->raw_tensor().numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(x_grad_size)); Eigen::DSizes<int, 1> bcast(static_cast<int>(x_grad_size));
EigenVector<T>::Flatten(param.X_grad->raw_tensor()) EigenVector<T>::Flatten(param.X_grad->raw_tensor())
.device(*(context.x86_device_context->eigen_device())) = .device(*(context.x86_device_context()->eigen_device())) =
(EigenVector<T>::From(param.Out_grad->raw_tensor()) / x_grad_size) (EigenVector<T>::From(param.Out_grad->raw_tensor()) / x_grad_size)
.broadcast(bcast); .broadcast(bcast);
} }
......
...@@ -32,7 +32,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -32,7 +32,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulParam>(); auto& param = *param_.get_mutable<operators::MulParam>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
param.output->template mutable_data<T>(); param.output->template mutable_data<T>();
...@@ -53,7 +53,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -53,7 +53,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
} }
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>( auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context); *context.x86_device_context());
blas.MatMul(x_matrix, y_matrix, z); blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) { if (z_dim.size() != 2) {
...@@ -70,7 +70,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -70,7 +70,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulGradParam>(); auto& param = *param_.get_mutable<operators::MulGradParam>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
auto* x = &param.x->raw_tensor(); auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor(); auto* y = &param.y->raw_tensor();
...@@ -99,7 +99,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -99,7 +99,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
} }
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>( auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context); *context.x86_device_context());
if (dx) { if (dx) {
// dx->mutable_data<T>(context.x86_device_context->GetPlace()); // dx->mutable_data<T>(context.x86_device_context->GetPlace());
param.x_grad->template mutable_data<T>(); param.x_grad->template mutable_data<T>();
......
...@@ -32,5 +32,8 @@ set(ops_lite ...@@ -32,5 +32,8 @@ set(ops_lite
dropout_op_lite dropout_op_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite X86_DEPS fc_compute_x86) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite
X86_DEPS fc_compute_x86
ARM_DEPS fc_compute_arm)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
TEST(fc_op_lite, test) { TEST(fc_op_lite, TestX86) {
// prepare variables // prepare variables
Scope scope; Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>(); auto* x = scope.Var("x")->GetMutable<Tensor>();
...@@ -57,9 +57,11 @@ TEST(fc_op_lite, test) { ...@@ -57,9 +57,11 @@ TEST(fc_op_lite, test) {
FcOpLite fc("fc"); FcOpLite fc("fc");
fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}}); fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
fc.Attach(desc, &scope); fc.Attach(desc, &scope);
auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)}}); auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
ASSERT_FALSE(kernels.empty()); ASSERT_FALSE(kernels.empty());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册