提交 44dbc0bc 编写于 作者: S sangoly 提交者: Yan Chunwei

refactor context to support both server and light (#17762)

上级 2cf061f2
......@@ -25,13 +25,6 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.")
# lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
# DEPS cxx_api_lite model_parser_lite target_wrapper_host
# PROFILE_DEPS basic_profiler_lite
# ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
# --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host
......@@ -42,7 +35,15 @@ if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif()
endif(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif(WITH_TESTING)
# if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
# endif()
lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
......
......@@ -22,6 +22,7 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
......@@ -84,6 +85,7 @@ class LightPredictor {
return it->alias() == alias;
});
CHECK(it != kernels.end());
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
......
......@@ -44,3 +44,18 @@ USE_LITE_OP(scale);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
#endif
......@@ -54,3 +54,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
#lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite)
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
......@@ -33,7 +33,7 @@ namespace lite {
#ifdef LITE_WITH_ARM
void ARMContext::SetCache(int l1size, int l2size, int l3size) {
void Context<TargetType::kARM>::SetCache(int l1size, int l2size, int l3size) {
DeviceInfo& dev = DeviceInfo::Global();
int cpu_count = arm_get_cpucount();
dev.L1_cache_.resize(cpu_count);
......@@ -47,7 +47,7 @@ void ARMContext::SetCache(int l1size, int l2size, int l3size) {
workspace_.Resize({2 * (l1size + l2size)});
}
ARMContext::ARMContext() {
Context<TargetType::kARM>::Context() {
active_ids_ = {0};
mode_ = LITE_POWER_HIGH;
DeviceInfo& dev = DeviceInfo::Global();
......@@ -62,11 +62,11 @@ ARMContext::ARMContext() {
#endif
}
PowerMode ARMContext::mode() const { return mode_; }
PowerMode Context<TargetType::kARM>::mode() const { return mode_; }
int ARMContext::threads() const { return active_ids_.size(); }
int Context<TargetType::kARM>::threads() const { return active_ids_.size(); }
ARMContext::ARMContext(const ARMContext& ctx) {
Context<TargetType::kARM>::Context(const ARMContext& ctx) {
mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_;
......@@ -74,7 +74,7 @@ ARMContext::ARMContext(const ARMContext& ctx) {
count_ = ctx.count_;
}
ARMContext& ARMContext::operator=(const ARMContext& ctx) {
ARMContext& Context<TargetType::kARM>::operator=(const ARMContext& ctx) {
mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_;
......@@ -83,7 +83,7 @@ ARMContext& ARMContext::operator=(const ARMContext& ctx) {
return *this;
}
void ARMContext::BindDev() {
void Context<TargetType::kARM>::BindDev() {
#ifdef USE_OPENMP
int num_threads = active_ids_.size();
omp_set_num_threads(num_threads);
......@@ -116,7 +116,7 @@ void ARMContext::BindDev() {
#endif // USE_OPENMP
}
void ARMContext::SetRunMode(PowerMode mode, int threads) {
void Context<TargetType::kARM>::SetRunMode(PowerMode mode, int threads) {
DeviceInfo& dev = DeviceInfo::Global();
int big_core_size = dev.big_core_ids_.size();
int small_core_size = dev.little_core_ids_.size();
......@@ -293,26 +293,26 @@ void ARMContext::SetRunMode(PowerMode mode, int threads) {
arch_ = DeviceInfo::Global().archs_[active_ids_[0]];
}
ARMArch ARMContext::arch() const { return arch_; }
ARMArch Context<TargetType::kARM>::arch() const { return arch_; }
void ARMContext::SetArch(ARMArch arch) { arch_ = arch; }
void Context<TargetType::kARM>::SetArch(ARMArch arch) { arch_ = arch; }
int ARMContext::l1_cache_size() const {
int Context<TargetType::kARM>::l1_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L1_cache_[active_ids_[0]];
}
int ARMContext::l2_cache_size() const {
int Context<TargetType::kARM>::l2_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L2_cache_[active_ids_[0]];
}
int ARMContext::l3_cache_size() const {
int Context<TargetType::kARM>::l3_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L3_cache_[active_ids_[0]];
}
bool ARMContext::ExtendWorkspace(DDimLite dims) {
bool Context<TargetType::kARM>::ExtendWorkspace(DDimLite dims) {
auto count = dims.product();
auto old = workspace_.dims();
if (count == old.product()) {
......@@ -324,5 +324,6 @@ bool ARMContext::ExtendWorkspace(DDimLite dims) {
return true;
}
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
......@@ -23,28 +23,55 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#endif
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/cpu_info.h"
#include "paddle/fluid/lite/core/lite_tensor.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
struct HostContext {};
template <TargetType Type>
class Context;
using HostContext = Context<TargetType::kHost>;
using X86Context = Context<TargetType::kX86>;
using CUDAContext = Context<TargetType::kCUDA>;
using ARMContext = Context<TargetType::kARM>;
// Context specialization for the host target. It carries no state: both
// scheduler hooks below are intentionally empty.
template <>
class Context<TargetType::kHost> {
 public:
  // Human-readable identifier of this context type.
  std::string name() const { return "HostContext"; }

  // NOTE: InitOnce should only be used by ContextScheduler.
  // Nothing needs to be set up for the host target.
  void InitOnce() {}

  // The host context has nothing to share with another context, so the
  // argument is ignored.
  void CopyShared(const HostContext* /*ctx*/) {}
};
#ifdef LITE_WITH_ARM
struct ARMContext {
template <>
class Context<TargetType::kARM> {
public:
ARMContext();
ARMContext(PowerMode mode, int threads);
ARMContext(const ARMContext& ctx);
Context();
Context(PowerMode mode, int threads);
explicit Context(const ARMContext& ctx);
ARMContext& operator=(const ARMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { DeviceInfo::Init(); }
void CopyShared(const ARMContext* ctx) {}
void SetRunMode(PowerMode mode, int threads);
void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch);
......@@ -64,6 +91,8 @@ struct ARMContext {
int l3_cache_size() const;
bool ExtendWorkspace(DDimLite dims);
std::string name() const { return "ARMContext"; }
private:
// LITE_POWER_HIGH stands for using big cores,
// LITE_POWER_LOW stands for using small core,
......@@ -78,33 +107,99 @@ struct ARMContext {
#ifdef LITE_WITH_CUDA
// Only works with CUDA kernels.
struct CUDAContext {
template <>
class Context<TargetType::kCUDA> {
public:
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
void CopyShared(const CUDAContext* ctx) {
CHECK(ctx);
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
ctx->cublas_fp32_ = cublas_fp32_;
}
const cudaStream_t exec_stream() { return exec_stream_; }
void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
const cudaStream_t io_stream() { return io_stream_; }
void SetIoStream(cudaStream_t stream) { io_stream_ = stream; }
std::shared_ptr<cuda::Blas<float>> cublas_fp32() { return cublas_fp32_; }
void SetCuBlasFP32(std::shared_ptr<cuda::Blas<float>> cublas_fp32) {
cublas_fp32_ = cublas_fp32;
}
const std::vector<cudaEvent_t>& input_events() { return input_events_; }
void SetInputEvents(const std::vector<cudaEvent_t>& input_events) {
input_events_.clear();
input_events_.assign(input_events.begin(), input_events.end());
}
const std::vector<cudaEvent_t>& output_events() { return output_events_; }
void SetOutputEvents(const std::vector<cudaEvent_t>& output_events) {
output_events_.clear();
output_events_.assign(output_events.begin(), output_events.end());
}
std::string name() const { return "CUDAContext"; }
private:
// overall information
cudaStream_t exec_stream;
cudaStream_t io_stream;
cudaStream_t exec_stream_;
cudaStream_t io_stream_;
// not thread-safe, should allocate for each thread.
std::shared_ptr<cuda::Blas<float>> blas_fp32;
std::shared_ptr<cuda::Blas<float>> cublas_fp32_;
// kernel information
std::vector<cudaEvent_t> input_events;
std::vector<cudaEvent_t> output_events;
std::vector<cudaEvent_t> input_events_;
std::vector<cudaEvent_t> output_events_;
};
#endif
#ifdef LITE_WITH_X86
struct X86Context {
// overall information
X86Context() {
x86_device_context.reset(new ::paddle::platform::CPUDeviceContext);
x86_execution_context.reset(
new ::paddle::framework::ExecutionContext(*x86_device_context));
template <>
class Context<TargetType::kX86> {
public:
using device_ctx_t = ::paddle::platform::CPUDeviceContext;
using execution_ctx_t = ::paddle::framework::ExecutionContext;
Context() {
x86_device_context_.reset(new ::paddle::platform::CPUDeviceContext);
x86_execution_context_.reset(
new ::paddle::framework::ExecutionContext(*x86_device_context_));
}
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopyShared(const X86Context* ctx) {}
const device_ctx_t* x86_device_context() { return x86_device_context_.get(); }
void SetX86DeviceContext(std::unique_ptr<device_ctx_t>&& ctx) {
x86_device_context_ = std::move(ctx);
}
const execution_ctx_t* x86_execution_context() {
return x86_execution_context_.get();
}
void SetX86ExecutionContext(std::unique_ptr<execution_ctx_t>&& ctx) {
x86_execution_context_ = std::move(ctx);
}
std::string name() const { return "X86Context"; }
private:
// overall information
//
// kernel information
// legacy info.
std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context;
std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context;
std::unique_ptr<device_ctx_t> x86_device_context_;
std::unique_ptr<execution_ctx_t> x86_execution_context_;
};
#endif
......@@ -124,5 +219,67 @@ class KernelContext {
Any ctx_;
};
// The ContextScheduler helps to assign different context for each kernel.
class ContextScheduler {
public:
static ContextScheduler& Global() {
static auto* x = new ContextScheduler;
return *x;
}
std::unique_ptr<KernelContext> NewContext(TargetType target) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
switch (target) {
case TARGET(kHost):
kernel_contexts_[TargetType::kHost].As<HostContext>().CopyShared(
&ctx->As<HostContext>());
break;
#ifdef LITE_WITH_X86
case TARGET(kX86):
kernel_contexts_[TargetType::kX86].As<X86Context>().CopyShared(
&ctx->As<X86Context>());
break;
#endif
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopyShared(
&ctx->As<CUDAContext>());
break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
kernel_contexts_[TargetType::kARM].As<ARMContext>().CopyShared(
&ctx->As<ARMContext>());
break;
#endif
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
}
return ctx;
}
private:
template <TargetType Type, typename ContextT>
void InitContext() {
kernel_contexts_[Type].As<ContextT>().InitOnce();
}
ContextScheduler() {
InitContext<TargetType::kHost, HostContext>();
#ifdef LITE_WITH_X86
InitContext<TargetType::kX86, X86Context>();
#endif
#ifdef LITE_WITH_CUDA
InitContext<TargetType::kCUDA, CUDAContext>();
#endif
#ifdef LITE_WITH_ARM
InitContext<TargetType::kARM, ARMContext>();
#endif
}
private:
std::map<TargetType, KernelContext> kernel_contexts_;
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/context.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
#ifdef LITE_WITH_X86
// Verifies that two NewContext() calls for the X86 target return distinct
// KernelContext objects, each with its own non-null device and execution
// context, and that both can be replaced through the setters.
TEST(ContextScheduler, NewContext) {
  using device_ctx_t = ::paddle::platform::CPUDeviceContext;
  using exec_ctx_t = ::paddle::framework::ExecutionContext;

  auto& scheduler = ContextScheduler::Global();
  auto first_holder = scheduler.NewContext(TargetType::kX86);
  auto second_holder = scheduler.NewContext(TargetType::kX86);
  // Every call must produce a separate KernelContext.
  ASSERT_TRUE(first_holder.get() != second_holder.get());

  auto& first = first_holder->As<X86Context>();
  auto& second = second_holder->As<X86Context>();
  ASSERT_EQ(first.name(), "X86Context");
  ASSERT_EQ(second.name(), "X86Context");

  // Both contexts own their legacy device / execution contexts ...
  ASSERT_TRUE(first.x86_device_context() != nullptr &&
              second.x86_device_context() != nullptr);
  ASSERT_TRUE(first.x86_execution_context() != nullptr &&
              second.x86_execution_context() != nullptr);
  // ... and do not share them with each other.
  ASSERT_FALSE(first.x86_device_context() == second.x86_device_context());
  ASSERT_FALSE(first.x86_execution_context() ==
               second.x86_execution_context());

  // Swap in a new device context, then an execution context built on it.
  std::unique_ptr<device_ctx_t> replacement(new device_ctx_t);
  device_ctx_t* const replacement_raw = replacement.get();
  first.SetX86DeviceContext(std::move(replacement));
  first.SetX86ExecutionContext(
      std::unique_ptr<exec_ctx_t>(new exec_ctx_t(*replacement_raw)));
}
#endif
} // namespace lite
} // namespace paddle
......@@ -21,85 +21,16 @@ namespace mir {
class RuntimeContextAssignPass : public StmtPass {
public:
RuntimeContextAssignPass() {
#ifdef LITE_WITH_CUDA
InitCudaBlas();
#endif
}
RuntimeContextAssignPass() {}
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
auto& inst = node.AsStmt();
switch (inst.picked_kernel().target()) {
case TARGET(kHost):
inst.picked_kernel().SetContext(NewHostContext());
break;
#ifdef LITE_WITH_X86
case TARGET(kX86):
inst.picked_kernel().SetContext(NewX86Context());
break;
#endif
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
inst.picked_kernel().SetContext(NewCudaContext());
break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
inst.picked_kernel().SetContext(NewARMContext());
break;
#endif
default:
LOG(FATAL) << "unsupported target "
<< TargetToStr(inst.picked_kernel().target());
}
}
}
std::unique_ptr<KernelContext> NewHostContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<HostContext>();
// Some initialization here.
return ctx;
inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
}
#ifdef LITE_WITH_X86
std::unique_ptr<KernelContext> NewX86Context() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
return ctx;
}
#endif
#ifdef LITE_WITH_ARM
std::unique_ptr<KernelContext> NewARMContext() {
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
return ctx;
}
#endif
#ifdef LITE_WITH_CUDA
std::unique_ptr<KernelContext> NewCudaContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& cuda = ctx->As<CUDAContext>();
// Some initialization here.
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
cuda.blas_fp32 = cublas_fp32_;
return ctx;
}
void InitCudaBlas() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
#endif
private:
#ifdef LITE_WITH_CUDA
std::shared_ptr<lite::cuda::Blas<float>> cublas_fp32_;
#endif
};
} // namespace mir
......
......@@ -37,9 +37,9 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
void Run() override {
CHECK(ctx_) << "running context should be set first";
auto& context = ctx_->As<CUDAContext>();
CHECK(context.blas_fp32) << "blas should init first";
CHECK(context.cublas_fp32()) << "blas should init first";
/*
auto& blas = *context.blas_fp32;
auto& blas = *context.cublas_fp32();
CHECK(param.x->target() == TARGET(kCUDA));
auto* x = param.x->data<float>();
int x_h = param.x->dims()[0];
......
......@@ -62,10 +62,10 @@ class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationParam>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context,
Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context(),
&param.X->raw_tensor(),
&param.Out->raw_tensor());
}
......@@ -81,11 +81,11 @@ class SquareGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationGradParam>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>();
ActivateGrad<paddle::operators::SquareGradFunctor<T>>(
*context.x86_device_context, &param.X->raw_tensor(),
*context.x86_device_context(), &param.X->raw_tensor(),
&param.Out->raw_tensor(), &param.Out_grad->raw_tensor(),
&param.X_grad->raw_tensor());
}
......
......@@ -44,12 +44,12 @@ class ElementwiseSubCompute
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<SubFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context, &param.X->raw_tensor(),
*context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, SubFunctor<T>(),
&param.Out->raw_tensor());
}
......@@ -75,7 +75,7 @@ class ElementwiseSubGradCompute
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>();
param.Y_grad->template mutable_data<T>();
......@@ -86,8 +86,8 @@ class ElementwiseSubGradCompute
auto& skip = dout;
paddle::operators::ElemwiseExplicitGradCompute<
platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
*context.x86_execution_context, skip, skip, skip, dout, param.axis, &dx,
&dy, SubGradDX<T>(), SubGradDY<T>());
*context.x86_execution_context(), skip, skip, skip, dout, param.axis,
&dx, &dy, SubGradDX<T>(), SubGradDY<T>());
}
virtual ~ElementwiseSubGradCompute() = default;
......@@ -101,11 +101,11 @@ class ElementwiseAddCompute
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<AddFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context, &param.X->raw_tensor(),
*context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, AddFunctor<T>(),
&param.Out->raw_tensor());
}
......
......@@ -32,12 +32,12 @@ class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
paddle::operators::math::set_constant(
*context.x86_device_context, &param.Out->raw_tensor(), param.value);
*context.x86_device_context(), &param.Out->raw_tensor(), param.value);
}
virtual ~FillConstantCompute() = default;
......
......@@ -38,13 +38,13 @@ class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
auto X = EigenVector<T>::Flatten(param.X->raw_tensor());
auto y = EigenScalar<T>::From(param.Out->raw_tensor());
const auto& place = *(context.x86_device_context->eigen_device());
const auto& place = *(context.x86_device_context()->eigen_device());
y.device(place) = X.mean();
}
......@@ -61,13 +61,13 @@ class MeanGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1);
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>();
T x_grad_size = static_cast<T>(param.X_grad->raw_tensor().numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(x_grad_size));
EigenVector<T>::Flatten(param.X_grad->raw_tensor())
.device(*(context.x86_device_context->eigen_device())) =
.device(*(context.x86_device_context()->eigen_device())) =
(EigenVector<T>::From(param.Out_grad->raw_tensor()) / x_grad_size)
.broadcast(bcast);
}
......
......@@ -32,7 +32,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulParam>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
param.output->template mutable_data<T>();
......@@ -53,7 +53,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context);
*context.x86_device_context());
blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
......@@ -70,7 +70,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulGradParam>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
......@@ -99,7 +99,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context);
*context.x86_device_context());
if (dx) {
// dx->mutable_data<T>(context.x86_device_context->GetPlace());
param.x_grad->template mutable_data<T>();
......
......@@ -32,5 +32,8 @@ set(ops_lite
dropout_op_lite
PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite X86_DEPS fc_compute_x86)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite
X86_DEPS fc_compute_x86
ARM_DEPS fc_compute_arm)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
......@@ -20,7 +20,7 @@ namespace paddle {
namespace lite {
namespace operators {
TEST(fc_op_lite, test) {
TEST(fc_op_lite, TestX86) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
......@@ -57,9 +57,11 @@ TEST(fc_op_lite, test) {
FcOpLite fc("fc");
fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
fc.Attach(desc, &scope);
auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)}});
auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
ASSERT_FALSE(kernels.empty());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册