Commit 95bc0ce7

refactor context to support both server and light (#17762)

Authored by sangoly on Jun 03, 2019; committed by Yan Chunwei on Jun 03, 2019.
Parent: a5501e3a
Showing 16 changed files with 308 additions and 144 deletions (+308 −144).
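At a high level, the commit folds the previously unrelated per-target context structs (HostContext, ARMContext, X86Context, CUDAContext) into specializations of a single `Context<TargetType>` template, and introduces a `ContextScheduler` singleton that hands every kernel its own `KernelContext` while sharing expensive one-time state (the cuBLAS handle, DeviceInfo setup) across them. The standalone C++ sketch below illustrates that pattern; it is a simplified illustration with toy stand-ins, not the Paddle Lite code (which follows in the diffs):

#include <memory>

enum class TargetType { kHost, kX86 };

// One template, specialized per target, replaces unrelated per-target structs.
template <TargetType Type>
class Context;

template <>
class Context<TargetType::kHost> {
 public:
  void InitOnce() {}                  // one-time, possibly expensive setup
  void CopyShared(Context* other) {}  // seed `other` with shared handles
};

template <>
class Context<TargetType::kX86> {
 public:
  void InitOnce() {}
  void CopyShared(Context* other) {}
};

using HostContext = Context<TargetType::kHost>;
using X86Context = Context<TargetType::kX86>;

// A scheduler owns one long-lived "master" context per target; each kernel
// receives a fresh context seeded from the master, so per-kernel state stays
// private while heavyweight handles are created exactly once.
class Scheduler {
 public:
  static Scheduler& Global() {
    static Scheduler s;
    return s;
  }
  std::unique_ptr<X86Context> NewX86Context() {
    std::unique_ptr<X86Context> ctx(new X86Context);
    x86_.CopyShared(ctx.get());
    return ctx;
  }

 private:
  Scheduler() {
    host_.InitOnce();
    x86_.InitOnce();
  }
  HostContext host_;
  X86Context x86_;
};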
paddle/fluid/lite/api/CMakeLists.txt                         +13 −12
paddle/fluid/lite/api/light_api.h                             +2 −0
paddle/fluid/lite/api/light_api_test.cc                      +15 −0
paddle/fluid/lite/core/CMakeLists.txt                         +1 −0
paddle/fluid/lite/core/context.cc                            +15 −14
paddle/fluid/lite/core/context.h                            +176 −19
paddle/fluid/lite/core/context_test.cc                       +51 −0
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc     +3 −72
paddle/fluid/lite/kernels/cuda/mul_compute.h                  +2 −2
paddle/fluid/lite/kernels/x86/activation_compute.cc           +4 −4
paddle/fluid/lite/kernels/x86/elementwise_compute.cc          +7 −7
paddle/fluid/lite/kernels/x86/fill_constant_compute.cc        +2 −2
paddle/fluid/lite/kernels/x86/mean_compute.cc                 +4 −4
paddle/fluid/lite/kernels/x86/mul_compute.cc                  +4 −4
paddle/fluid/lite/operators/CMakeLists.txt                    +4 −1
paddle/fluid/lite/operators/fc_op_test.cc                     +5 −3
paddle/fluid/lite/api/CMakeLists.txt
@@ -25,24 +25,25 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc
 set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
     "A path setting inference demo download directories.")
+# lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
+#    DEPS cxx_api_lite model_parser_lite target_wrapper_host
+#    PROFILE_DEPS basic_profiler_lite
+#    ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
+#    --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
 if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
     lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
        DEPS cxx_api_lite model_parser_lite target_wrapper_host
        ${ops_lite} ${host_kernels} ${x86_kernels}
        PROFILE_DEPS basic_profiler_lite
        ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
             --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
     add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
-endif(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-
-if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
-    add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
-endif(WITH_TESTING)
+endif()
+# if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+#     lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
+# endif()
 
 lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
paddle/fluid/lite/api/light_api.h
@@ -22,6 +22,7 @@
 #include <string>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/program.h"
 #include "paddle/fluid/lite/core/types.h"
 #include "paddle/fluid/lite/model_parser/model_parser.h"
@@ -84,6 +85,7 @@ class LightPredictor {
         return it->alias() == alias;
       });
       CHECK(it != kernels.end());
+      (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
       insts.emplace_back(op, std::move(*it));
     }
     program_.reset(new RuntimeProgram(std::move(insts)));
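With this change the light (mobile) predictor assigns a context to every picked kernel as it builds the runtime program, the same way the server-side pass does; the MIR pass further below shrinks accordingly.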
paddle/fluid/lite/api/light_api_test.cc
@@ -44,3 +44,18 @@ USE_LITE_OP(scale);
 USE_LITE_OP(feed);
 USE_LITE_OP(fetch);
 USE_LITE_OP(io_copy);
+USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
+USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
+
+#ifdef LITE_WITH_X86
+USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
+#endif
paddle/fluid/lite/core/CMakeLists.txt
@@ -54,3 +54,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
 #lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite)
 lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
 lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
+lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
paddle/fluid/lite/core/context.cc
@@ -33,7 +33,7 @@ namespace lite {
 
 #ifdef LITE_WITH_ARM
 
-void ARMContext::SetCache(int l1size, int l2size, int l3size) {
+void Context<TargetType::kARM>::SetCache(int l1size, int l2size, int l3size) {
   DeviceInfo& dev = DeviceInfo::Global();
   int cpu_count = arm_get_cpucount();
   dev.L1_cache_.resize(cpu_count);
@@ -47,7 +47,7 @@ void ARMContext::SetCache(int l1size, int l2size, int l3size) {
   workspace_.Resize({2 * (l1size + l2size)});
 }
 
-ARMContext::ARMContext() {
+Context<TargetType::kARM>::Context() {
   active_ids_ = {0};
   mode_ = LITE_POWER_HIGH;
   DeviceInfo& dev = DeviceInfo::Global();
@@ -62,11 +62,11 @@ ARMContext::ARMContext() {
 #endif
 }
 
-PowerMode ARMContext::mode() const { return mode_; }
+PowerMode Context<TargetType::kARM>::mode() const { return mode_; }
 
-int ARMContext::threads() const { return active_ids_.size(); }
+int Context<TargetType::kARM>::threads() const { return active_ids_.size(); }
 
-ARMContext::ARMContext(const ARMContext& ctx) {
+Context<TargetType::kARM>::Context(const ARMContext& ctx) {
   mode_ = ctx.mode_;
   active_ids_ = ctx.active_ids_;
   workspace_ = ctx.workspace_;
@@ -74,7 +74,7 @@ ARMContext::ARMContext(const ARMContext& ctx) {
   count_ = ctx.count_;
 }
 
-ARMContext& ARMContext::operator=(const ARMContext& ctx) {
+ARMContext& Context<TargetType::kARM>::operator=(const ARMContext& ctx) {
   mode_ = ctx.mode_;
   active_ids_ = ctx.active_ids_;
   workspace_ = ctx.workspace_;
@@ -83,7 +83,7 @@ ARMContext& ARMContext::operator=(const ARMContext& ctx) {
   return *this;
 }
 
-void ARMContext::BindDev() {
+void Context<TargetType::kARM>::BindDev() {
 #ifdef USE_OPENMP
   int num_threads = active_ids_.size();
   omp_set_num_threads(num_threads);
@@ -116,7 +116,7 @@ void ARMContext::BindDev() {
 #endif  // USE_OPENMP
 }
 
-void ARMContext::SetRunMode(PowerMode mode, int threads) {
+void Context<TargetType::kARM>::SetRunMode(PowerMode mode, int threads) {
   DeviceInfo& dev = DeviceInfo::Global();
   int big_core_size = dev.big_core_ids_.size();
   int small_core_size = dev.little_core_ids_.size();
@@ -293,26 +293,26 @@ void ARMContext::SetRunMode(PowerMode mode, int threads) {
   arch_ = DeviceInfo::Global().archs_[active_ids_[0]];
 }
 
-ARMArch ARMContext::arch() const { return arch_; }
+ARMArch Context<TargetType::kARM>::arch() const { return arch_; }
 
-void ARMContext::SetArch(ARMArch arch) { arch_ = arch; }
+void Context<TargetType::kARM>::SetArch(ARMArch arch) { arch_ = arch; }
 
-int ARMContext::l1_cache_size() const {
+int Context<TargetType::kARM>::l1_cache_size() const {
   DeviceInfo& dev = DeviceInfo::Global();
   return dev.L1_cache_[active_ids_[0]];
 }
 
-int ARMContext::l2_cache_size() const {
+int Context<TargetType::kARM>::l2_cache_size() const {
   DeviceInfo& dev = DeviceInfo::Global();
   return dev.L2_cache_[active_ids_[0]];
 }
 
-int ARMContext::l3_cache_size() const {
+int Context<TargetType::kARM>::l3_cache_size() const {
   DeviceInfo& dev = DeviceInfo::Global();
   return dev.L3_cache_[active_ids_[0]];
 }
 
-bool ARMContext::ExtendWorkspace(DDimLite dims) {
+bool Context<TargetType::kARM>::ExtendWorkspace(DDimLite dims) {
   auto count = dims.product();
   auto old = workspace_.dims();
   if (count == old.product()) {
@@ -324,5 +324,6 @@ bool ARMContext::ExtendWorkspace(DDimLite dims) {
   return true;
 }
 
 #endif  // LITE_WITH_ARM
 
 }  // namespace lite
 }  // namespace paddle
paddle/fluid/lite/core/context.h
@@ -23,28 +23,55 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/device_context.h"
 #endif
 
+#include <map>
 #include <memory>
 #include <set>
+#include <string>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/lite/core/cpu_info.h"
 #include "paddle/fluid/lite/core/lite_tensor.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
+#include "paddle/fluid/lite/utils/all.h"
 
 namespace paddle {
 namespace lite {
 
-struct HostContext {};
+template <TargetType Type>
+class Context;
+
+using HostContext = Context<TargetType::kHost>;
+using X86Context = Context<TargetType::kX86>;
+using CUDAContext = Context<TargetType::kCUDA>;
+using ARMContext = Context<TargetType::kARM>;
+
+template <>
+class Context<TargetType::kHost> {
+ public:
+  // NOTE: InitOnce should only be used by ContextScheduler
+  void InitOnce() {}
+
+  void CopyShared(const HostContext* ctx) {}
+
+  std::string name() const { return "HostContext"; }
+};
 
 #ifdef LITE_WITH_ARM
-struct ARMContext {
+template <>
+class Context<TargetType::kARM> {
  public:
-  ARMContext();
-  ARMContext(PowerMode mode, int threads);
-  ARMContext(const ARMContext& ctx);
+  Context();
+  Context(PowerMode mode, int threads);
+  explicit Context(const ARMContext& ctx);
 
   ARMContext& operator=(const ARMContext& ctx);
 
+  // NOTE: InitOnce should only be used by ContextScheduler
+  void InitOnce() { DeviceInfo::Init(); }
+  void CopyShared(const ARMContext* ctx) {}
+
   void SetRunMode(PowerMode mode, int threads);
   void SetCache(int l1size, int l2size, int l3size);
   void SetArch(ARMArch arch);
@@ -64,6 +91,8 @@ struct ARMContext {
   int l3_cache_size() const;
   bool ExtendWorkspace(DDimLite dims);
 
+  std::string name() const { return "ARMContext"; }
+
  private:
   // LITE_POWER_HIGH stands for using big cores,
   // LITE_POWER_LOW stands for using small core,
@@ -78,33 +107,99 @@ struct ARMContext {
 
 #ifdef LITE_WITH_CUDA
 // Only works with CUDA kernels.
-struct CUDAContext {
+template <>
+class Context<TargetType::kCUDA> {
+ public:
+  // NOTE: InitOnce should only be used by ContextScheduler
+  void InitOnce() {
+    cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
+  }
+  void CopyShared(const CUDAContext* ctx) {
+    CHECK(ctx);
+    CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
+    ctx->cublas_fp32_ = cublas_fp32_;
+  }
+
+  const cudaStream_t exec_stream() { return exec_stream_; }
+  void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
+
+  const cudaStream_t io_stream() { return io_stream_; }
+  void SetIoStream(cudaStream_t stream) { io_stream_ = stream; }
+
+  std::shared_ptr<cuda::Blas<float>> cublas_fp32() { return cublas_fp32_; }
+  void SetCuBlasFP32(std::shared_ptr<cuda::Blas<float>> cublas_fp32) {
+    cublas_fp32_ = cublas_fp32;
+  }
+
+  const std::vector<cudaEvent_t>& input_events() { return input_events_; }
+  void SetInputEvents(const std::vector<cudaEvent_t>& input_events) {
+    input_events_.clear();
+    input_events_.assign(input_events.begin(), input_events.end());
+  }
+
+  const std::vector<cudaEvent_t>& output_events() { return output_events_; }
+  void SetOutputEvents(const std::vector<cudaEvent_t>& output_events) {
+    output_events_.clear();
+    output_events_.assign(output_events.begin(), output_events.end());
+  }
+
+  std::string name() const { return "CUDAContext"; }
+
+ private:
   // overall information
-  cudaStream_t exec_stream;
-  cudaStream_t io_stream;
+  cudaStream_t exec_stream_;
+  cudaStream_t io_stream_;
 
   // not thread-safe, should allocate for each thread.
-  std::shared_ptr<cuda::Blas<float>> blas_fp32;
+  std::shared_ptr<cuda::Blas<float>> cublas_fp32_;
 
   // kernel information
-  std::vector<cudaEvent_t> input_events;
-  std::vector<cudaEvent_t> output_events;
+  std::vector<cudaEvent_t> input_events_;
+  std::vector<cudaEvent_t> output_events_;
 };
 #endif
 
 #ifdef LITE_WITH_X86
-struct X86Context {
-  // overall information
-  X86Context() {
-    x86_device_context.reset(new ::paddle::platform::CPUDeviceContext);
-    x86_execution_context.reset(
-        new ::paddle::framework::ExecutionContext(*x86_device_context));
-  }
+template <>
+class Context<TargetType::kX86> {
+ public:
+  using device_ctx_t = ::paddle::platform::CPUDeviceContext;
+  using execution_ctx_t = ::paddle::framework::ExecutionContext;
+
+  Context() {
+    x86_device_context_.reset(new ::paddle::platform::CPUDeviceContext);
+    x86_execution_context_.reset(
+        new ::paddle::framework::ExecutionContext(*x86_device_context_));
+  }
+
+  // NOTE: InitOnce should only be used by ContextScheduler
+  void InitOnce() {}
+
+  void CopyShared(const X86Context* ctx) {}
+
+  const device_ctx_t* x86_device_context() { return x86_device_context_.get(); }
+  void SetX86DeviceContext(std::unique_ptr<device_ctx_t>&& ctx) {
+    x86_device_context_ = std::move(ctx);
+  }
+
+  const execution_ctx_t* x86_execution_context() {
+    return x86_execution_context_.get();
+  }
+  void SetX86ExecutionContext(std::unique_ptr<execution_ctx_t>&& ctx) {
+    x86_execution_context_ = std::move(ctx);
+  }
+
+  std::string name() const { return "X86Context"; }
+
+ private:
   // overall information
   //
   // kernel information
   // legacy info.
-  std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context;
-  std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context;
+  std::unique_ptr<device_ctx_t> x86_device_context_;
+  std::unique_ptr<execution_ctx_t> x86_execution_context_;
 };
 #endif
@@ -124,5 +219,67 @@ class KernelContext {
   Any ctx_;
 };
 
+// The ContextScheduler helps to assign different context for each kernel.
+class ContextScheduler {
+ public:
+  static ContextScheduler& Global() {
+    static auto* x = new ContextScheduler;
+    return *x;
+  }
+
+  std::unique_ptr<KernelContext> NewContext(TargetType target) {
+    std::unique_ptr<KernelContext> ctx(new KernelContext);
+    switch (target) {
+      case TARGET(kHost):
+        kernel_contexts_[TargetType::kHost].As<HostContext>().CopyShared(
+            &ctx->As<HostContext>());
+        break;
+#ifdef LITE_WITH_X86
+      case TARGET(kX86):
+        kernel_contexts_[TargetType::kX86].As<X86Context>().CopyShared(
+            &ctx->As<X86Context>());
+        break;
+#endif
+#ifdef LITE_WITH_CUDA
+      case TARGET(kCUDA):
+        kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopyShared(
+            &ctx->As<CUDAContext>());
+        break;
+#endif
+#ifdef LITE_WITH_ARM
+      case TARGET(kARM):
+        kernel_contexts_[TargetType::kARM].As<ARMContext>().CopyShared(
+            &ctx->As<ARMContext>());
+        break;
+#endif
+      default:
+        LOG(FATAL) << "unsupported target " << TargetToStr(target);
+    }
+    return ctx;
+  }
+
+ private:
+  template <TargetType Type, typename ContextT>
+  void InitContext() {
+    kernel_contexts_[Type].As<ContextT>().InitOnce();
+  }
+
+  ContextScheduler() {
+    InitContext<TargetType::kHost, HostContext>();
+#ifdef LITE_WITH_X86
+    InitContext<TargetType::kX86, X86Context>();
+#endif
+#ifdef LITE_WITH_CUDA
+    InitContext<TargetType::kCUDA, CUDAContext>();
+#endif
+#ifdef LITE_WITH_ARM
+    InitContext<TargetType::kARM, ARMContext>();
+#endif
+  }
+
+ private:
+  std::map<TargetType, KernelContext> kernel_contexts_;
+};
+
 }  // namespace lite
 }  // namespace paddle
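Taken together, the call sites in this commit attach and consume contexts in two steps. A minimal sketch using the calls visible in the diffs above (`kernel` stands for any picked kernel instance):

// Attach: done once per kernel by the light predictor or the MIR pass.
kernel->SetContext(ContextScheduler::Global().NewContext(kernel->target()));

// Consume: inside a kernel's Run(), the type-erased KernelContext is viewed
// through the target-specific specialization, e.g. on x86:
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());  // accessor replaces raw member access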
paddle/fluid/lite/core/context_test.cc (new file, mode 100644)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/core/context.h"
#include <gtest/gtest.h>

namespace paddle {
namespace lite {

#ifdef LITE_WITH_X86
TEST(ContextScheduler, NewContext) {
  auto ctx1_p = ContextScheduler::Global().NewContext(TargetType::kX86);
  auto ctx2_p = ContextScheduler::Global().NewContext(TargetType::kX86);
  ASSERT_FALSE(ctx1_p.get() == ctx2_p.get());

  auto& ctx1 = ctx1_p->As<X86Context>();
  auto& ctx2 = ctx2_p->As<X86Context>();

  ASSERT_EQ(ctx1.name(), "X86Context");
  ASSERT_EQ(ctx2.name(), "X86Context");

  ASSERT_FALSE(ctx1.x86_device_context() == nullptr ||
               ctx2.x86_device_context() == nullptr);
  ASSERT_FALSE(ctx1.x86_execution_context() == nullptr ||
               ctx2.x86_execution_context() == nullptr);

  ASSERT_TRUE(ctx1.x86_device_context() != ctx2.x86_device_context());
  ASSERT_TRUE(ctx1.x86_execution_context() != ctx2.x86_execution_context());

  using device_ctx_t = ::paddle::platform::CPUDeviceContext;
  using exec_ctx_t = ::paddle::framework::ExecutionContext;
  auto* device_ctx = new device_ctx_t;
  ctx1.SetX86DeviceContext(std::unique_ptr<device_ctx_t>(device_ctx));
  ctx1.SetX86ExecutionContext(
      std::unique_ptr<exec_ctx_t>(new exec_ctx_t(*device_ctx)));
}
#endif

}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
@@ -21,85 +21,16 @@ namespace mir {
 
 class RuntimeContextAssignPass : public StmtPass {
  public:
-  RuntimeContextAssignPass() {
-#ifdef LITE_WITH_CUDA
-    InitCudaBlas();
-#endif
-  }
+  RuntimeContextAssignPass() {}
 
   void Apply(const std::unique_ptr<SSAGraph>& graph) override {
     for (auto& node : graph->mutable_nodes()) {
       if (!node.IsStmt()) continue;
       auto& inst = node.AsStmt();
-      switch (inst.picked_kernel().target()) {
-        case TARGET(kHost):
-          inst.picked_kernel().SetContext(NewHostContext());
-          break;
-#ifdef LITE_WITH_X86
-        case TARGET(kX86):
-          inst.picked_kernel().SetContext(NewX86Context());
-          break;
-#endif
-#ifdef LITE_WITH_CUDA
-        case TARGET(kCUDA):
-          inst.picked_kernel().SetContext(NewCudaContext());
-          break;
-#endif
-#ifdef LITE_WITH_ARM
-        case TARGET(kARM):
-          inst.picked_kernel().SetContext(NewARMContext());
-          break;
-#endif
-        default:
-          LOG(FATAL) << "unsupported target "
-                     << TargetToStr(inst.picked_kernel().target());
-      }
+      inst.picked_kernel().SetContext(
+          ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
     }
   }
-
-  std::unique_ptr<KernelContext> NewHostContext() {
-    std::unique_ptr<KernelContext> ctx(new KernelContext);
-    ctx->As<HostContext>();
-    // Some initialization here.
-    return ctx;
-  }
-
-#ifdef LITE_WITH_X86
-  std::unique_ptr<KernelContext> NewX86Context() {
-    std::unique_ptr<KernelContext> ctx(new KernelContext);
-    ctx->As<X86Context>();
-    return ctx;
-  }
-#endif
-
-#ifdef LITE_WITH_ARM
-  std::unique_ptr<KernelContext> NewARMContext() {
-    DeviceInfo::Init();
-    std::unique_ptr<KernelContext> ctx(new KernelContext);
-    ctx->As<ARMContext>();
-    return ctx;
-  }
-#endif
-
-#ifdef LITE_WITH_CUDA
-  std::unique_ptr<KernelContext> NewCudaContext() {
-    std::unique_ptr<KernelContext> ctx(new KernelContext);
-    auto& cuda = ctx->As<CUDAContext>();
-    // Some initialization here.
-    CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
-    cuda.blas_fp32 = cublas_fp32_;
-    return ctx;
-  }
-
-  void InitCudaBlas() {
-    cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
-  }
-#endif
-
- private:
-#ifdef LITE_WITH_CUDA
-  std::shared_ptr<lite::cuda::Blas<float>> cublas_fp32_;
-#endif
 };
 
 }  // namespace mir
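The pass previously owned per-target factory methods (NewHostContext, NewX86Context, NewARMContext, NewCudaContext) plus its own cuBLAS handle; that per-target initialization now lives in the contexts' InitOnce() and is triggered once by ContextScheduler, so the pass reduces to a single scheduler call.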
paddle/fluid/lite/kernels/cuda/mul_compute.h
@@ -37,9 +37,9 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
 
   void Run() override {
     CHECK(ctx_) << "running context should be set first";
     auto& context = ctx_->As<CUDAContext>();
-    CHECK(context.blas_fp32) << "blas should init first";
+    CHECK(context.cublas_fp32()) << "blas should init first";
     /*
-    auto& blas = *context.blas_fp32;
+    auto& blas = *context.cublas_fp32();
     CHECK(param.x->target() == TARGET(kCUDA));
     auto* x = param.x->data<float>();
     int x_h = param.x->dims()[0];
paddle/fluid/lite/kernels/x86/activation_compute.cc
@@ -62,10 +62,10 @@ class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<operators::ActivationParam>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.Out->template mutable_data<T>();
-    Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context,
+    Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context(),
                                                   &param.X->raw_tensor(),
                                                   &param.Out->raw_tensor());
   }
@@ -81,11 +81,11 @@ class SquareGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<operators::ActivationGradParam>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.X_grad->template mutable_data<T>();
     ActivateGrad<paddle::operators::SquareGradFunctor<T>>(
-        *context.x86_device_context, &param.X->raw_tensor(),
+        *context.x86_device_context(), &param.X->raw_tensor(),
         &param.Out->raw_tensor(), &param.Out_grad->raw_tensor(),
         &param.X_grad->raw_tensor());
   }
paddle/fluid/lite/kernels/x86/elementwise_compute.cc
@@ -44,12 +44,12 @@ class ElementwiseSubCompute
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     auto& context = ctx_->As<X86Context>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.Out->template mutable_data<T>();
     paddle::operators::ElementwiseComputeEx<SubFunctor<T>,
                                             platform::CPUDeviceContext, T>(
-        *context.x86_execution_context, &param.X->raw_tensor(),
+        *context.x86_execution_context(), &param.X->raw_tensor(),
         &param.Y->raw_tensor(), param.axis, SubFunctor<T>(),
         &param.Out->raw_tensor());
   }
@@ -75,7 +75,7 @@ class ElementwiseSubGradCompute
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     auto& context = ctx_->As<X86Context>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.X_grad->template mutable_data<T>();
     param.Y_grad->template mutable_data<T>();
@@ -86,8 +86,8 @@ class ElementwiseSubGradCompute
     auto& skip = dout;
     paddle::operators::ElemwiseExplicitGradCompute<
         platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
-        *context.x86_execution_context, skip, skip, skip, dout, param.axis, &dx,
-        &dy, SubGradDX<T>(), SubGradDY<T>());
+        *context.x86_execution_context(), skip, skip, skip, dout, param.axis,
+        &dx, &dy, SubGradDX<T>(), SubGradDY<T>());
   }
 
   virtual ~ElementwiseSubGradCompute() = default;
@@ -101,11 +101,11 @@ class ElementwiseAddCompute
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     auto& context = ctx_->As<X86Context>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.Out->template mutable_data<T>();
     paddle::operators::ElementwiseComputeEx<AddFunctor<T>,
                                             platform::CPUDeviceContext, T>(
-        *context.x86_execution_context, &param.X->raw_tensor(),
+        *context.x86_execution_context(), &param.X->raw_tensor(),
         &param.Y->raw_tensor(), param.axis, AddFunctor<T>(),
         &param.Out->raw_tensor());
   }
paddle/fluid/lite/kernels/x86/fill_constant_compute.cc
@@ -32,12 +32,12 @@ class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     auto& context = ctx_->As<X86Context>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.Out->template mutable_data<T>();
 
     paddle::operators::math::set_constant(
-        *context.x86_device_context, &param.Out->raw_tensor(), param.value);
+        *context.x86_device_context(), &param.Out->raw_tensor(), param.value);
   }
 
   virtual ~FillConstantCompute() = default;
paddle/fluid/lite/kernels/x86/mean_compute.cc
@@ -38,13 +38,13 @@ class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     auto& context = ctx_->As<X86Context>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.Out->template mutable_data<T>();
 
     auto X = EigenVector<T>::Flatten(param.X->raw_tensor());
     auto y = EigenScalar<T>::From(param.Out->raw_tensor());
-    const auto& place = *(context.x86_device_context->eigen_device());
+    const auto& place = *(context.x86_device_context()->eigen_device());
 
     y.device(place) = X.mean();
   }
@@ -61,13 +61,13 @@ class MeanGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     auto& param = *param_.get_mutable<param_t>();
     auto& context = ctx_->As<X86Context>();
     CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1);
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
     param.X_grad->template mutable_data<T>();
 
     T x_grad_size = static_cast<T>(param.X_grad->raw_tensor().numel());
     Eigen::DSizes<int, 1> bcast(static_cast<int>(x_grad_size));
     EigenVector<T>::Flatten(param.X_grad->raw_tensor())
-        .device(*(context.x86_device_context->eigen_device())) =
+        .device(*(context.x86_device_context()->eigen_device())) =
         (EigenVector<T>::From(param.Out_grad->raw_tensor()) / x_grad_size)
             .broadcast(bcast);
   }
paddle/fluid/lite/kernels/x86/mul_compute.cc
@@ -32,7 +32,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<operators::MulParam>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
 
     param.output->template mutable_data<T>();
@@ -53,7 +53,7 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     }
 
     auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
-        *context.x86_device_context);
+        *context.x86_device_context());
 
     blas.MatMul(x_matrix, y_matrix, z);
     if (z_dim.size() != 2) {
@@ -70,7 +70,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<operators::MulGradParam>();
-    CHECK(context.x86_device_context);
+    CHECK(context.x86_device_context());
 
     auto* x = &param.x->raw_tensor();
     auto* y = &param.y->raw_tensor();
@@ -99,7 +99,7 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     }
 
     auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
-        *context.x86_device_context);
+        *context.x86_device_context());
 
     if (dx) {
       // dx->mutable_data<T>(context.x86_device_context->GetPlace());
       param.x_grad->template mutable_data<T>();
paddle/fluid/lite/operators/CMakeLists.txt
@@ -32,5 +32,8 @@ set(ops_lite
     dropout_op_lite
     PARENT_SCOPE)
 
-lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite X86_DEPS fc_compute_x86)
+lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
+    DEPS fc_op_lite memory_lite
+    X86_DEPS fc_compute_x86
+    ARM_DEPS fc_compute_arm)
 lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
paddle/fluid/lite/operators/fc_op_test.cc
@@ -20,7 +20,7 @@ namespace paddle {
 namespace lite {
 namespace operators {
 
-TEST(fc_op_lite, test) {
+TEST(fc_op_lite, TestX86) {
   // prepare variables
   Scope scope;
   auto* x = scope.Var("x")->GetMutable<Tensor>();
@@ -57,9 +57,11 @@ TEST(fc_op_lite, test) {
   FcOpLite fc("fc");
 
-  fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
+  fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)},
+                     Place{TARGET(kARM), PRECISION(kFloat)}});
   fc.Attach(desc, &scope);
-  auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)}});
+  auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)},
+                                   Place{TARGET(kARM), PRECISION(kFloat)}});
   ASSERT_FALSE(kernels.empty());
 }