Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
d1d1e932
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d1d1e932
编写于
7月 10, 2018
作者:
L
liuruilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add conv add batch norm relu metal
上级
4a63c7bb
变更
29
隐藏空白更改
内联
并排
Showing
29 changed file
with
650 addition
and
101 deletion
+650
-101
metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj
...-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj
+11
-19
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
+32
-0
metal/paddle-mobile/paddle-mobile/Common/Extensions.swift
metal/paddle-mobile/paddle-mobile/Common/Extensions.swift
+16
-0
metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
...l/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
+1
-2
metal/paddle-mobile/paddle-mobile/Common/Types.swift
metal/paddle-mobile/paddle-mobile/Common/Types.swift
+54
-5
metal/paddle-mobile/paddle-mobile/Executor.swift
metal/paddle-mobile/paddle-mobile/Executor.swift
+11
-1
metal/paddle-mobile/paddle-mobile/Loader.swift
metal/paddle-mobile/paddle-mobile/Loader.swift
+2
-5
metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
...addle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
+11
-7
metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
...paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
+12
-12
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
...l/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
+2
-2
metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift
...bile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift
+56
-12
metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift
+87
-0
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
+4
-3
metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
...dle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
+4
-3
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
+2
-2
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
+2
-2
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift
...mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift
+51
-3
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift
...obile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift
+19
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal
...e-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal
+45
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift
...e-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift
+25
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift
...obile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift
+26
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift
...obile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift
+25
-0
metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift
metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift
+39
-0
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
+2
-2
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift
+39
-0
metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift
metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift
+39
-0
metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift
metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift
+1
-1
metal/paddle-mobile/paddle-mobile/Program/Program.swift
metal/paddle-mobile/paddle-mobile/Program/Program.swift
+2
-2
metal/paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift
...paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift
+30
-18
未找到文件。
metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj
浏览文件 @
d1d1e932
...
@@ -214,10 +214,8 @@
...
@@ -214,10 +214,8 @@
FC0E2DB420EDC03C009C1FAC
/* conv2d_27.w_0 in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2CEA20EDC03B009C1FAC
/* conv2d_27.w_0 */
;
};
FC0E2DB420EDC03C009C1FAC
/* conv2d_27.w_0 in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2CEA20EDC03B009C1FAC
/* conv2d_27.w_0 */
;
};
FC0E2DB520EDC03C009C1FAC
/* conv2d_33.w_0 in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2CEB20EDC03B009C1FAC
/* conv2d_33.w_0 */
;
};
FC0E2DB520EDC03C009C1FAC
/* conv2d_33.w_0 in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2CEB20EDC03B009C1FAC
/* conv2d_33.w_0 */
;
};
FC0E2DB620EDC03C009C1FAC
/* depthwise_conv2d_7.w_0 in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2CEC20EDC03B009C1FAC
/* depthwise_conv2d_7.w_0 */
;
};
FC0E2DB620EDC03C009C1FAC
/* depthwise_conv2d_7.w_0 in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2CEC20EDC03B009C1FAC
/* depthwise_conv2d_7.w_0 */
;
};
FCEBC0FC20F227C60099DBAF
/* mobilenet in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0F820F227C60099DBAF
/* mobilenet */
;
};
FCD04E6320F3146B0007374F
/* params in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6120F3146A0007374F
/* params */
;
};
FCEBC0FD20F227C60099DBAF
/* params in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0F920F227C60099DBAF
/* params */
;
};
FCD04E6420F3146B0007374F
/* model in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6220F3146A0007374F
/* model */
;
};
FCEBC0FE20F227C60099DBAF
/* model in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0FA20F227C60099DBAF
/* model */
;
};
FCEBC0FF20F227C60099DBAF
/* yolo in Resources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0FB20F227C60099DBAF
/* yolo */
;
};
FCEBEC2C20E1391F00C0B14D
/* paddle_mobile.framework in Frameworks */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBEC2B20E1391F00C0B14D
/* paddle_mobile.framework */
;
};
FCEBEC2C20E1391F00C0B14D
/* paddle_mobile.framework in Frameworks */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBEC2B20E1391F00C0B14D
/* paddle_mobile.framework */
;
};
FCEBEC2D20E1391F00C0B14D
/* paddle_mobile.framework in Embed Frameworks */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBEC2B20E1391F00C0B14D
/* paddle_mobile.framework */
;
settings
=
{
ATTRIBUTES
=
(
CodeSignOnCopy
,
RemoveHeadersOnCopy
,
);
};
};
FCEBEC2D20E1391F00C0B14D
/* paddle_mobile.framework in Embed Frameworks */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBEC2B20E1391F00C0B14D
/* paddle_mobile.framework */
;
settings
=
{
ATTRIBUTES
=
(
CodeSignOnCopy
,
RemoveHeadersOnCopy
,
);
};
};
/* End PBXBuildFile section */
/* End PBXBuildFile section */
...
@@ -448,10 +446,8 @@
...
@@ -448,10 +446,8 @@
FC0E2CEA20EDC03B009C1FAC
/* conv2d_27.w_0 */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
conv2d_27.w_0
;
sourceTree
=
"<group>"
;
};
FC0E2CEA20EDC03B009C1FAC
/* conv2d_27.w_0 */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
conv2d_27.w_0
;
sourceTree
=
"<group>"
;
};
FC0E2CEB20EDC03B009C1FAC
/* conv2d_33.w_0 */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
conv2d_33.w_0
;
sourceTree
=
"<group>"
;
};
FC0E2CEB20EDC03B009C1FAC
/* conv2d_33.w_0 */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
conv2d_33.w_0
;
sourceTree
=
"<group>"
;
};
FC0E2CEC20EDC03B009C1FAC
/* depthwise_conv2d_7.w_0 */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
depthwise_conv2d_7.w_0
;
sourceTree
=
"<group>"
;
};
FC0E2CEC20EDC03B009C1FAC
/* depthwise_conv2d_7.w_0 */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
depthwise_conv2d_7.w_0
;
sourceTree
=
"<group>"
;
};
FCEBC0F820F227C60099DBAF
/* mobilenet */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
text
;
path
=
mobilenet
;
sourceTree
=
"<group>"
;
};
FCD04E6120F3146A0007374F
/* params */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
params
;
sourceTree
=
"<group>"
;
};
FCEBC0F920F227C60099DBAF
/* params */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
params
;
sourceTree
=
"<group>"
;
};
FCD04E6220F3146A0007374F
/* model */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
model
;
sourceTree
=
"<group>"
;
};
FCEBC0FA20F227C60099DBAF
/* model */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
file
;
path
=
model
;
sourceTree
=
"<group>"
;
};
FCEBC0FB20F227C60099DBAF
/* yolo */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
text
;
path
=
yolo
;
sourceTree
=
"<group>"
;
};
FCEBEC2B20E1391F00C0B14D
/* paddle_mobile.framework */
=
{
isa
=
PBXFileReference
;
explicitFileType
=
wrapper.framework
;
path
=
paddle_mobile.framework
;
sourceTree
=
BUILT_PRODUCTS_DIR
;
};
FCEBEC2B20E1391F00C0B14D
/* paddle_mobile.framework */
=
{
isa
=
PBXFileReference
;
explicitFileType
=
wrapper.framework
;
path
=
paddle_mobile.framework
;
sourceTree
=
BUILT_PRODUCTS_DIR
;
};
/* End PBXFileReference section */
/* End PBXFileReference section */
...
@@ -531,7 +527,7 @@
...
@@ -531,7 +527,7 @@
FC0E2C2020EDC03B009C1FAC
/* models */
=
{
FC0E2C2020EDC03B009C1FAC
/* models */
=
{
isa
=
PBXGroup
;
isa
=
PBXGroup
;
children
=
(
children
=
(
FC
EBC0F720F227C60099DBAF
/* yolo
*/
,
FC
D04E6020F3146A0007374F
/* mobilenet
*/
,
FC0E2C2420EDC03B009C1FAC
/* mobilenetssd */
,
FC0E2C2420EDC03B009C1FAC
/* mobilenetssd */
,
);
);
name
=
models
;
name
=
models
;
...
@@ -745,15 +741,13 @@
...
@@ -745,15 +741,13 @@
path
=
mobilenetssd
;
path
=
mobilenetssd
;
sourceTree
=
"<group>"
;
sourceTree
=
"<group>"
;
};
};
FC
EBC0F720F227C60099DBAF
/* yolo
*/
=
{
FC
D04E6020F3146A0007374F
/* mobilenet
*/
=
{
isa
=
PBXGroup
;
isa
=
PBXGroup
;
children
=
(
children
=
(
FCEBC0F820F227C60099DBAF
/* mobilenet */
,
FCD04E6120F3146A0007374F
/* params */
,
FCEBC0F920F227C60099DBAF
/* params */
,
FCD04E6220F3146A0007374F
/* model */
,
FCEBC0FA20F227C60099DBAF
/* model */
,
FCEBC0FB20F227C60099DBAF
/* yolo */
,
);
);
path
=
yolo
;
path
=
mobilenet
;
sourceTree
=
"<group>"
;
sourceTree
=
"<group>"
;
};
};
/* End PBXGroup section */
/* End PBXGroup section */
...
@@ -828,7 +822,6 @@
...
@@ -828,7 +822,6 @@
FC0E2D2920EDC03B009C1FAC
/* batch_norm_2.b_0 in Resources */
,
FC0E2D2920EDC03B009C1FAC
/* batch_norm_2.b_0 in Resources */
,
FC0E2DA920EDC03C009C1FAC
/* conv2d_26.w_0 in Resources */
,
FC0E2DA920EDC03C009C1FAC
/* conv2d_26.w_0 in Resources */
,
FC0E2D0420EDC03B009C1FAC
/* batch_norm_16.w_2 in Resources */
,
FC0E2D0420EDC03B009C1FAC
/* batch_norm_16.w_2 in Resources */
,
FCEBC0FE20F227C60099DBAF
/* model in Resources */
,
FC0E2D0720EDC03B009C1FAC
/* batch_norm_6.w_1 in Resources */
,
FC0E2D0720EDC03B009C1FAC
/* batch_norm_6.w_1 in Resources */
,
FC0E2DB020EDC03C009C1FAC
/* batch_norm_30.w_2 in Resources */
,
FC0E2DB020EDC03C009C1FAC
/* batch_norm_30.w_2 in Resources */
,
FC0E2D9720EDC03C009C1FAC
/* conv2d_25.w_0 in Resources */
,
FC0E2D9720EDC03C009C1FAC
/* conv2d_25.w_0 in Resources */
,
...
@@ -844,11 +837,9 @@
...
@@ -844,11 +837,9 @@
FC0E2DA620EDC03C009C1FAC
/* depthwise_conv2d_4.w_0 in Resources */
,
FC0E2DA620EDC03C009C1FAC
/* depthwise_conv2d_4.w_0 in Resources */
,
FC0E2D6920EDC03C009C1FAC
/* conv2d_6.w_0 in Resources */
,
FC0E2D6920EDC03C009C1FAC
/* conv2d_6.w_0 in Resources */
,
FC0E2D6520EDC03C009C1FAC
/* conv2d_7.w_0 in Resources */
,
FC0E2D6520EDC03C009C1FAC
/* conv2d_7.w_0 in Resources */
,
FCEBC0FD20F227C60099DBAF
/* params in Resources */
,
FC0E2DAB20EDC03C009C1FAC
/* batch_norm_19.w_2 in Resources */
,
FC0E2DAB20EDC03C009C1FAC
/* batch_norm_19.w_2 in Resources */
,
FC0E2D9920EDC03C009C1FAC
/* conv2d_31.w_0 in Resources */
,
FC0E2D9920EDC03C009C1FAC
/* conv2d_31.w_0 in Resources */
,
FC0E2D3020EDC03B009C1FAC
/* batch_norm_34.w_0 in Resources */
,
FC0E2D3020EDC03B009C1FAC
/* batch_norm_34.w_0 in Resources */
,
FCEBC0FC20F227C60099DBAF
/* mobilenet in Resources */
,
FC0E2D1220EDC03B009C1FAC
/* batch_norm_34.b_0 in Resources */
,
FC0E2D1220EDC03B009C1FAC
/* batch_norm_34.b_0 in Resources */
,
FC0E2D4D20EDC03C009C1FAC
/* batch_norm_7.b_0 in Resources */
,
FC0E2D4D20EDC03C009C1FAC
/* batch_norm_7.b_0 in Resources */
,
FC0E2D2520EDC03B009C1FAC
/* batch_norm_21.w_1 in Resources */
,
FC0E2D2520EDC03B009C1FAC
/* batch_norm_21.w_1 in Resources */
,
...
@@ -857,6 +848,7 @@
...
@@ -857,6 +848,7 @@
FC0E2D8620EDC03C009C1FAC
/* conv2d_23.w_0 in Resources */
,
FC0E2D8620EDC03C009C1FAC
/* conv2d_23.w_0 in Resources */
,
FC0E2CFE20EDC03B009C1FAC
/* depthwise_conv2d_9.w_0 in Resources */
,
FC0E2CFE20EDC03B009C1FAC
/* depthwise_conv2d_9.w_0 in Resources */
,
FC0E2D4C20EDC03C009C1FAC
/* batch_norm_8.w_2 in Resources */
,
FC0E2D4C20EDC03C009C1FAC
/* batch_norm_8.w_2 in Resources */
,
FCD04E6320F3146B0007374F
/* params in Resources */
,
FC0E2D5820EDC03C009C1FAC
/* conv2d_5.w_0 in Resources */
,
FC0E2D5820EDC03C009C1FAC
/* conv2d_5.w_0 in Resources */
,
FC0E2D1620EDC03B009C1FAC
/* batch_norm_3.w_1 in Resources */
,
FC0E2D1620EDC03B009C1FAC
/* batch_norm_3.w_1 in Resources */
,
FC0E2DB120EDC03C009C1FAC
/* batch_norm_24.w_2 in Resources */
,
FC0E2DB120EDC03C009C1FAC
/* batch_norm_24.w_2 in Resources */
,
...
@@ -949,7 +941,6 @@
...
@@ -949,7 +941,6 @@
FC0E2D0F20EDC03B009C1FAC
/* batch_norm_5.w_0 in Resources */
,
FC0E2D0F20EDC03B009C1FAC
/* batch_norm_5.w_0 in Resources */
,
FC0E2D4520EDC03C009C1FAC
/* batch_norm_9.w_2 in Resources */
,
FC0E2D4520EDC03C009C1FAC
/* batch_norm_9.w_2 in Resources */
,
FC0E2D9020EDC03C009C1FAC
/* batch_norm_23.w_2 in Resources */
,
FC0E2D9020EDC03C009C1FAC
/* batch_norm_23.w_2 in Resources */
,
FCEBC0FF20F227C60099DBAF
/* yolo in Resources */
,
FC0E2D6720EDC03C009C1FAC
/* conv2d_31.b_0 in Resources */
,
FC0E2D6720EDC03C009C1FAC
/* conv2d_31.b_0 in Resources */
,
FC0E2DA020EDC03C009C1FAC
/* conv2d_18.w_0 in Resources */
,
FC0E2DA020EDC03C009C1FAC
/* conv2d_18.w_0 in Resources */
,
FC0E2D1C20EDC03B009C1FAC
/* conv2d_13.w_0 in Resources */
,
FC0E2D1C20EDC03B009C1FAC
/* conv2d_13.w_0 in Resources */
,
...
@@ -972,6 +963,7 @@
...
@@ -972,6 +963,7 @@
FC0E2D0920EDC03B009C1FAC
/* conv2d_14.w_0 in Resources */
,
FC0E2D0920EDC03B009C1FAC
/* conv2d_14.w_0 in Resources */
,
FC0E2CF720EDC03B009C1FAC
/* batch_norm_28.w_2 in Resources */
,
FC0E2CF720EDC03B009C1FAC
/* batch_norm_28.w_2 in Resources */
,
FC0E2D9520EDC03C009C1FAC
/* depthwise_conv2d_5.w_0 in Resources */
,
FC0E2D9520EDC03C009C1FAC
/* depthwise_conv2d_5.w_0 in Resources */
,
FCD04E6420F3146B0007374F
/* model in Resources */
,
FC0E2D4A20EDC03C009C1FAC
/* conv2d_9.w_0 in Resources */
,
FC0E2D4A20EDC03C009C1FAC
/* conv2d_9.w_0 in Resources */
,
FC0E2D4E20EDC03C009C1FAC
/* batch_norm_19.w_1 in Resources */
,
FC0E2D4E20EDC03C009C1FAC
/* batch_norm_19.w_1 in Resources */
,
FC0E2D3620EDC03C009C1FAC
/* batch_norm_18.w_0 in Resources */
,
FC0E2D3620EDC03C009C1FAC
/* batch_norm_18.w_0 in Resources */
,
...
...
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
浏览文件 @
d1d1e932
...
@@ -45,6 +45,14 @@
...
@@ -45,6 +45,14 @@
FC9D038020E22FBB000F735A
/* FeedOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC9D037F20E22FBB000F735A
/* FeedOp.swift */
;
};
FC9D038020E22FBB000F735A
/* FeedOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC9D037F20E22FBB000F735A
/* FeedOp.swift */
;
};
FC9D038220E2312E000F735A
/* FetchOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC9D038120E2312E000F735A
/* FetchOp.swift */
;
};
FC9D038220E2312E000F735A
/* FetchOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC9D038120E2312E000F735A
/* FetchOp.swift */
;
};
FC9D038420E23B01000F735A
/* Texture.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC9D038320E23B01000F735A
/* Texture.swift */
;
};
FC9D038420E23B01000F735A
/* Texture.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC9D038320E23B01000F735A
/* Texture.swift */
;
};
FCD04E6620F314C50007374F
/* PoolOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6520F314C50007374F
/* PoolOp.swift */
;
};
FCD04E6820F315020007374F
/* PoolKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6720F315020007374F
/* PoolKernel.swift */
;
};
FCD04E6A20F319EC0007374F
/* SoftmaxOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6920F319EC0007374F
/* SoftmaxOp.swift */
;
};
FCD04E6C20F31A280007374F
/* SoftmaxKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6B20F31A280007374F
/* SoftmaxKernel.swift */
;
};
FCD04E6E20F31B4B0007374F
/* ReshapeOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6D20F31B4B0007374F
/* ReshapeOp.swift */
;
};
FCD04E7020F31B720007374F
/* ReshapeKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E6F20F31B720007374F
/* ReshapeKernel.swift */
;
};
FCD04E7220F343420007374F
/* ConvAddOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E7120F343420007374F
/* ConvAddOp.swift */
;
};
FCD04E7420F3437E0007374F
/* ConvAddKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCD04E7320F3437E0007374F
/* ConvAddKernel.swift */
;
};
FCEBC0F420F1FDD90099DBAF
/* ConvAddBatchNormReluOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0F320F1FDD90099DBAF
/* ConvAddBatchNormReluOp.swift */
;
};
FCEBC0F420F1FDD90099DBAF
/* ConvAddBatchNormReluOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0F320F1FDD90099DBAF
/* ConvAddBatchNormReluOp.swift */
;
};
FCEBC0F620F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0F520F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift */
;
};
FCEBC0F620F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCEBC0F520F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift */
;
};
FCF2D73820E64E70007AC5F5
/* Kernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCF2D73720E64E70007AC5F5
/* Kernel.swift */
;
};
FCF2D73820E64E70007AC5F5
/* Kernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FCF2D73720E64E70007AC5F5
/* Kernel.swift */
;
};
...
@@ -93,6 +101,14 @@
...
@@ -93,6 +101,14 @@
FC9D037F20E22FBB000F735A
/* FeedOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FeedOp.swift
;
sourceTree
=
"<group>"
;
};
FC9D037F20E22FBB000F735A
/* FeedOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FeedOp.swift
;
sourceTree
=
"<group>"
;
};
FC9D038120E2312E000F735A
/* FetchOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FetchOp.swift
;
sourceTree
=
"<group>"
;
};
FC9D038120E2312E000F735A
/* FetchOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FetchOp.swift
;
sourceTree
=
"<group>"
;
};
FC9D038320E23B01000F735A
/* Texture.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
Texture.swift
;
sourceTree
=
"<group>"
;
};
FC9D038320E23B01000F735A
/* Texture.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
Texture.swift
;
sourceTree
=
"<group>"
;
};
FCD04E6520F314C50007374F
/* PoolOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
PoolOp.swift
;
sourceTree
=
"<group>"
;
};
FCD04E6720F315020007374F
/* PoolKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
PoolKernel.swift
;
sourceTree
=
"<group>"
;
};
FCD04E6920F319EC0007374F
/* SoftmaxOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
SoftmaxOp.swift
;
sourceTree
=
"<group>"
;
};
FCD04E6B20F31A280007374F
/* SoftmaxKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
SoftmaxKernel.swift
;
sourceTree
=
"<group>"
;
};
FCD04E6D20F31B4B0007374F
/* ReshapeOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ReshapeOp.swift
;
sourceTree
=
"<group>"
;
};
FCD04E6F20F31B720007374F
/* ReshapeKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ReshapeKernel.swift
;
sourceTree
=
"<group>"
;
};
FCD04E7120F343420007374F
/* ConvAddOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ConvAddOp.swift
;
sourceTree
=
"<group>"
;
};
FCD04E7320F3437E0007374F
/* ConvAddKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ConvAddKernel.swift
;
sourceTree
=
"<group>"
;
};
FCEBC0F320F1FDD90099DBAF
/* ConvAddBatchNormReluOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
name
=
ConvAddBatchNormReluOp.swift
;
path
=
"paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"
;
sourceTree
=
SOURCE_ROOT
;
};
FCEBC0F320F1FDD90099DBAF
/* ConvAddBatchNormReluOp.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
name
=
ConvAddBatchNormReluOp.swift
;
path
=
"paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"
;
sourceTree
=
SOURCE_ROOT
;
};
FCEBC0F520F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ConvAddBatchNormReluKernel.swift
;
sourceTree
=
"<group>"
;
};
FCEBC0F520F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ConvAddBatchNormReluKernel.swift
;
sourceTree
=
"<group>"
;
};
FCF2D73720E64E70007AC5F5
/* Kernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
name
=
Kernel.swift
;
path
=
"paddle-mobile/Operators/Kernels/Kernel.swift"
;
sourceTree
=
SOURCE_ROOT
;
};
FCF2D73720E64E70007AC5F5
/* Kernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
name
=
Kernel.swift
;
path
=
"paddle-mobile/Operators/Kernels/Kernel.swift"
;
sourceTree
=
SOURCE_ROOT
;
};
...
@@ -193,6 +209,10 @@
...
@@ -193,6 +209,10 @@
FC039BA820E11CBC0081E9F8
/* ReluOp.swift */
,
FC039BA820E11CBC0081E9F8
/* ReluOp.swift */
,
FC9D037F20E22FBB000F735A
/* FeedOp.swift */
,
FC9D037F20E22FBB000F735A
/* FeedOp.swift */
,
FC9D038120E2312E000F735A
/* FetchOp.swift */
,
FC9D038120E2312E000F735A
/* FetchOp.swift */
,
FCD04E6520F314C50007374F
/* PoolOp.swift */
,
FCD04E6920F319EC0007374F
/* SoftmaxOp.swift */
,
FCD04E6D20F31B4B0007374F
/* ReshapeOp.swift */
,
FCD04E7120F343420007374F
/* ConvAddOp.swift */
,
);
);
path
=
Operators
;
path
=
Operators
;
sourceTree
=
"<group>"
;
sourceTree
=
"<group>"
;
...
@@ -227,6 +247,10 @@
...
@@ -227,6 +247,10 @@
FC5163F520EF556E00636C28
/* Texture2DTo2DArrayKernel.swift */
,
FC5163F520EF556E00636C28
/* Texture2DTo2DArrayKernel.swift */
,
FC4CB74820F0B954007C0C6D
/* ConvKernel.metal */
,
FC4CB74820F0B954007C0C6D
/* ConvKernel.metal */
,
FCEBC0F520F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift */
,
FCEBC0F520F1FE120099DBAF
/* ConvAddBatchNormReluKernel.swift */
,
FCD04E6720F315020007374F
/* PoolKernel.swift */
,
FCD04E6B20F31A280007374F
/* SoftmaxKernel.swift */
,
FCD04E6F20F31B720007374F
/* ReshapeKernel.swift */
,
FCD04E7320F3437E0007374F
/* ConvAddKernel.swift */
,
);
);
path
=
Kernels
;
path
=
Kernels
;
sourceTree
=
"<group>"
;
sourceTree
=
"<group>"
;
...
@@ -346,6 +370,8 @@
...
@@ -346,6 +370,8 @@
FC0E2DBC20EE45FE009C1FAC
/* ConvKernel.swift in Sources */
,
FC0E2DBC20EE45FE009C1FAC
/* ConvKernel.swift in Sources */
,
FC039BAA20E11CBC0081E9F8
/* ElementwiseAddOp.swift in Sources */
,
FC039BAA20E11CBC0081E9F8
/* ElementwiseAddOp.swift in Sources */
,
FC039B9B20E11CA00081E9F8
/* Executor.swift in Sources */
,
FC039B9B20E11CA00081E9F8
/* Executor.swift in Sources */
,
FCD04E7020F31B720007374F
/* ReshapeKernel.swift in Sources */
,
FCD04E7220F343420007374F
/* ConvAddOp.swift in Sources */
,
FC039BBB20E11CC20081E9F8
/* ProgramDesc.swift in Sources */
,
FC039BBB20E11CC20081E9F8
/* ProgramDesc.swift in Sources */
,
FC9D037920E229E4000F735A
/* OpParam.swift in Sources */
,
FC9D037920E229E4000F735A
/* OpParam.swift in Sources */
,
FC1B186620ECF1C600678B91
/* ResizeKernel.swift in Sources */
,
FC1B186620ECF1C600678B91
/* ResizeKernel.swift in Sources */
,
...
@@ -362,21 +388,27 @@
...
@@ -362,21 +388,27 @@
FC4CB74920F0B954007C0C6D
/* ConvKernel.metal in Sources */
,
FC4CB74920F0B954007C0C6D
/* ConvKernel.metal in Sources */
,
FC039BA920E11CBC0081E9F8
/* ConvOp.swift in Sources */
,
FC039BA920E11CBC0081E9F8
/* ConvOp.swift in Sources */
,
FC9D038420E23B01000F735A
/* Texture.swift in Sources */
,
FC9D038420E23B01000F735A
/* Texture.swift in Sources */
,
FCD04E6E20F31B4B0007374F
/* ReshapeOp.swift in Sources */
,
FC039B9820E11C9A0081E9F8
/* Errors.swift in Sources */
,
FC039B9820E11C9A0081E9F8
/* Errors.swift in Sources */
,
FC039BBF20E11CC20081E9F8
/* Attribute.swift in Sources */
,
FC039BBF20E11CC20081E9F8
/* Attribute.swift in Sources */
,
FCD04E7420F3437E0007374F
/* ConvAddKernel.swift in Sources */
,
FC039BB920E11CC20081E9F8
/* Scope.swift in Sources */
,
FC039BB920E11CC20081E9F8
/* Scope.swift in Sources */
,
FCD04E6620F314C50007374F
/* PoolOp.swift in Sources */
,
FC039BAC20E11CBC0081E9F8
/* BatchNormOp.swift in Sources */
,
FC039BAC20E11CBC0081E9F8
/* BatchNormOp.swift in Sources */
,
FC039BBC20E11CC20081E9F8
/* VarDesc.swift in Sources */
,
FC039BBC20E11CC20081E9F8
/* VarDesc.swift in Sources */
,
FC0E2DBA20EE3B8D009C1FAC
/* ReluKernel.swift in Sources */
,
FC0E2DBA20EE3B8D009C1FAC
/* ReluKernel.swift in Sources */
,
FC82735920E3C04200BE430A
/* OpCreator.swift in Sources */
,
FC82735920E3C04200BE430A
/* OpCreator.swift in Sources */
,
FC0E2DBE20EE460D009C1FAC
/* BatchNormKernel.swift in Sources */
,
FC0E2DBE20EE460D009C1FAC
/* BatchNormKernel.swift in Sources */
,
FC039BAB20E11CBC0081E9F8
/* Operator.swift in Sources */
,
FC039BAB20E11CBC0081E9F8
/* Operator.swift in Sources */
,
FCD04E6A20F319EC0007374F
/* SoftmaxOp.swift in Sources */
,
FC9D038220E2312E000F735A
/* FetchOp.swift in Sources */
,
FC9D038220E2312E000F735A
/* FetchOp.swift in Sources */
,
FC039BBD20E11CC20081E9F8
/* Program.swift in Sources */
,
FC039BBD20E11CC20081E9F8
/* Program.swift in Sources */
,
FC039BA220E11CB70081E9F8
/* Loader.swift in Sources */
,
FC039BA220E11CB70081E9F8
/* Loader.swift in Sources */
,
FCD04E6C20F31A280007374F
/* SoftmaxKernel.swift in Sources */
,
FC4CB74B20F12C30007C0C6D
/* ProgramOptimize.swift in Sources */
,
FC4CB74B20F12C30007C0C6D
/* ProgramOptimize.swift in Sources */
,
FC5163F620EF556E00636C28
/* Texture2DTo2DArrayKernel.swift in Sources */
,
FC5163F620EF556E00636C28
/* Texture2DTo2DArrayKernel.swift in Sources */
,
FC039BC020E11CC20081E9F8
/* BlockDesc.swift in Sources */
,
FC039BC020E11CC20081E9F8
/* BlockDesc.swift in Sources */
,
FCD04E6820F315020007374F
/* PoolKernel.swift in Sources */
,
FC039BAD20E11CBC0081E9F8
/* ReluOp.swift in Sources */
,
FC039BAD20E11CBC0081E9F8
/* ReluOp.swift in Sources */
,
FC039BBE20E11CC20081E9F8
/* OpDesc.swift in Sources */
,
FC039BBE20E11CC20081E9F8
/* OpDesc.swift in Sources */
,
FC039B9720E11C9A0081E9F8
/* Extensions.swift in Sources */
,
FC039B9720E11C9A0081E9F8
/* Extensions.swift in Sources */
,
...
...
metal/paddle-mobile/paddle-mobile/Common/Extensions.swift
浏览文件 @
d1d1e932
...
@@ -72,6 +72,16 @@ extension Array: CIntIndex{
...
@@ -72,6 +72,16 @@ extension Array: CIntIndex{
}
}
}
}
extension
Array
where
Element
:
AnyObject
{
mutating
func
remove
(
element
:
Element
)
{
if
let
index
=
index
(
where
:
{
(
node
)
->
Bool
in
return
unsafeBitCast
(
element
,
to
:
Int
.
self
)
==
unsafeBitCast
(
node
,
to
:
Int
.
self
)
})
{
remove
(
at
:
index
)
}
}
}
//MARK: Array extension
//MARK: Array extension
extension
Array
where
Element
:
Comparable
{
extension
Array
where
Element
:
Comparable
{
...
@@ -92,4 +102,10 @@ extension String{
...
@@ -92,4 +102,10 @@ extension String{
}
}
}
}
func
address
<
T
:
AnyObject
>
(
o
:
T
)
->
String
{
return
String
.
init
(
format
:
"%018p"
,
unsafeBitCast
(
o
,
to
:
Int
.
self
))
}
metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
浏览文件 @
d1d1e932
...
@@ -42,7 +42,6 @@ extension MTLDevice {
...
@@ -42,7 +42,6 @@ extension MTLDevice {
}
}
}
}
func
pipeLine
(
funcName
:
String
,
inPaddleMobileLib
:
Bool
=
true
)
->
MTLComputePipelineState
{
func
pipeLine
(
funcName
:
String
,
inPaddleMobileLib
:
Bool
=
true
)
->
MTLComputePipelineState
{
let
useLib
=
inPaddleMobileLib
?
paddleMobileLibrary
()
:
defaultLibrary
()
let
useLib
=
inPaddleMobileLib
?
paddleMobileLibrary
()
:
defaultLibrary
()
guard
let
function
=
useLib
.
makeFunction
(
name
:
funcName
)
else
{
guard
let
function
=
useLib
.
makeFunction
(
name
:
funcName
)
else
{
...
@@ -65,7 +64,7 @@ extension MTLComputeCommandEncoder {
...
@@ -65,7 +64,7 @@ extension MTLComputeCommandEncoder {
let
width
=
computePipline
.
threadExecutionWidth
let
width
=
computePipline
.
threadExecutionWidth
let
height
=
computePipline
.
maxTotalThreadsPerThreadgroup
/
width
let
height
=
computePipline
.
maxTotalThreadsPerThreadgroup
/
width
let
threadsPerGroup
=
MTLSize
.
init
(
width
:
width
,
height
:
height
,
depth
:
1
)
let
threadsPerGroup
=
MTLSize
.
init
(
width
:
width
,
height
:
height
,
depth
:
1
)
// print(" thread: threads per group: \(threadsPerGroup) ")
// print(" thread: threads per group: \(threadsPerGroup) ")
// print(" thread: out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
// print(" thread: out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
...
...
metal/paddle-mobile/paddle-mobile/Common/Types.swift
浏览文件 @
d1d1e932
...
@@ -14,22 +14,71 @@
...
@@ -14,22 +14,71 @@
import
Foundation
import
Foundation
public
protocol
SummableMultipliable
:
Equatable
{
static
func
+
(
lhs
:
Self
,
rhs
:
Self
)
->
Self
static
func
*
(
lhs
:
Self
,
rhs
:
Self
)
->
Self
static
func
-
(
lhs
:
Self
,
rhs
:
Self
)
->
Self
}
public
protocol
PrecisionType
:
SummableMultipliable
{
init
(
inFloat
:
Float32
)
init
(
inFloat16
:
Float16
)
init
<
P
:
PrecisionType
>
(
_
inP
:
P
)
static
var
bitSize
:
UInt
{
get
}
}
public
typealias
Float16
=
Int16
public
typealias
Float16
=
Int16
extension
Float16
:
PrecisionType
{
extension
Float16
:
PrecisionType
{
public
static
func
*
(
prefix
:
Float16
,
postfix
:
Float16
)
{
return
prefix
*
postfix
}
public
init
<
P
>
(
_
inP
:
P
)
where
P
:
PrecisionType
{
if
P
.
bitSize
==
Float32
.
bitSize
{
self
=
Float16
(
inFloat
:
inP
as!
Float32
)
}
else
if
P
.
bitSize
==
Float16
.
bitSize
{
self
=
inP
as!
Float16
}
else
{
fatalError
()
}
}
public
static
var
bitSize
:
UInt
{
return
16
}
public
init
(
inFloat16
:
Float16
)
{
self
=
inFloat16
}
public
init
(
inFloat
:
Float32
)
{
public
init
(
inFloat
:
Float32
)
{
self
=
Int16
(
inFloat
)
self
=
Int16
(
inFloat
)
}
}
}
}
public
protocol
PrecisionType
{
init
(
inFloat
:
Float32
)
}
extension
Float32
:
PrecisionType
{
extension
Float32
:
PrecisionType
{
public
init
<
P
>
(
_
inP
:
P
)
where
P
:
PrecisionType
{
if
P
.
bitSize
==
Float32
.
bitSize
{
self
=
inP
as!
Float32
}
else
if
P
.
bitSize
==
Float16
.
bitSize
{
self
=
Float32
.
init
(
inP
as!
Float16
)
}
else
{
fatalError
()
}
}
public
init
(
inFloat
:
Float32
)
{
public
init
(
inFloat
:
Float32
)
{
self
=
inFloat
self
=
inFloat
}
}
public
init
(
inFloat16
:
Float16
)
{
self
=
Float32
.
init
(
inFloat16
)
}
public
static
var
bitSize
:
UInt
{
return
32
}
}
}
public
enum
DataLayout
{
public
enum
DataLayout
{
...
...
metal/paddle-mobile/paddle-mobile/Executor.swift
浏览文件 @
d1d1e932
...
@@ -55,7 +55,8 @@ public class Executor<P: PrecisionType> {
...
@@ -55,7 +55,8 @@ public class Executor<P: PrecisionType> {
device
=
inDevice
device
=
inDevice
queue
=
inQueue
queue
=
inQueue
for
block
in
inProgram
.
programDesc
.
blocks
{
for
block
in
inProgram
.
programDesc
.
blocks
{
for
op
in
block
.
ops
{
for
i
in
0
..<
7
{
let
op
=
block
.
ops
[
i
]
do
{
do
{
let
op
=
try
OpCreator
<
P
>.
shared
.
creat
(
device
:
inDevice
,
opDesc
:
op
,
scope
:
inProgram
.
scope
)
let
op
=
try
OpCreator
<
P
>.
shared
.
creat
(
device
:
inDevice
,
opDesc
:
op
,
scope
:
inProgram
.
scope
)
op
.
inferShape
()
op
.
inferShape
()
...
@@ -64,6 +65,15 @@ public class Executor<P: PrecisionType> {
...
@@ -64,6 +65,15 @@ public class Executor<P: PrecisionType> {
throw
error
throw
error
}
}
}
}
// for op in block.ops {
// do {
// let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
// op.inferShape()
// ops.append(op)
// } catch let error {
// throw error
// }
// }
}
}
}
}
...
...
metal/paddle-mobile/paddle-mobile/Loader.swift
浏览文件 @
d1d1e932
...
@@ -104,12 +104,9 @@ public class Loader<P: PrecisionType> {
...
@@ -104,12 +104,9 @@ public class Loader<P: PrecisionType> {
serializedData
:
modelData
)
serializedData
:
modelData
)
let
originProgramDesc
=
ProgramDesc
.
init
(
protoProgram
:
protoProgram
)
let
originProgramDesc
=
ProgramDesc
.
init
(
protoProgram
:
protoProgram
)
let
programDesc
=
ProgramOptimize
<
P
>.
init
()
.
optimize
(
originProgramDesc
:
originProgramDesc
)
let
programDesc
=
ProgramOptimize
<
P
>.
init
()
.
optimize
(
originProgramDesc
:
originProgramDesc
)
print
(
programDesc
)
print
(
programDesc
)
fatalError
()
guard
let
paraLoader
=
try
?
ParaLoader
.
init
(
paramPath
:
paraPath
)
else
{
guard
let
paraLoader
=
try
?
ParaLoader
.
init
(
paramPath
:
paraPath
)
else
{
throw
PaddleMobileError
.
loaderError
(
message
:
"load para error"
)
throw
PaddleMobileError
.
loaderError
(
message
:
"load para error"
)
}
}
...
@@ -180,7 +177,7 @@ public class Loader<P: PrecisionType> {
...
@@ -180,7 +177,7 @@ public class Loader<P: PrecisionType> {
}
}
}
}
let
program
=
Program
.
init
(
protoProgramDesc
:
protoProgram
,
inParamPath
:
paraPath
,
inScope
:
scope
)
let
program
=
Program
.
init
(
inProgramDesc
:
programDesc
,
inParamPath
:
paraPath
,
inScope
:
scope
)
return
program
return
program
}
catch
_
{
}
catch
_
{
...
...
metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
浏览文件 @
d1d1e932
...
@@ -40,13 +40,17 @@ class OpCreator<P: PrecisionType> {
...
@@ -40,13 +40,17 @@ class OpCreator<P: PrecisionType> {
}
}
let
opCreators
:
[
String
:
(
MTLDevice
,
OpDesc
,
Scope
)
throws
->
Runable
&
InferShaperable
]
=
let
opCreators
:
[
String
:
(
MTLDevice
,
OpDesc
,
Scope
)
throws
->
Runable
&
InferShaperable
]
=
[
gConvType
:
ConvOp
<
P
>.
creat
,
[
gConvType
:
ConvOp
<
P
>.
creat
,
gBatchNormType
:
BatchNormOp
<
P
>.
creat
,
gBatchNormType
:
BatchNormOp
<
P
>.
creat
,
gReluType
:
ReluOp
<
P
>.
creat
,
gReluType
:
ReluOp
<
P
>.
creat
,
gElementwiseAdd
:
ElementwiseAddOp
<
P
>.
creat
,
gElementwiseAdd
:
ElementwiseAddOp
<
P
>.
creat
,
gFeedType
:
FeedOp
<
P
>.
creat
,
gFeedType
:
FeedOp
<
P
>.
creat
,
gFetchType
:
FetchOp
<
P
>.
creat
,
gFetchType
:
FetchOp
<
P
>.
creat
,
gConvAddBatchNormReluType
:
ConvAddBatchNormReluOp
<
P
>.
creat
]
gConvAddBatchNormReluType
:
ConvAddBatchNormReluOp
<
P
>.
creat
,
gPooType
:
PoolOp
<
P
>.
creat
,
gSoftmaxType
:
SoftmaxOp
<
P
>.
creat
,
gReshapeType
:
ReshapeOp
<
P
>.
creat
,
gConvAddType
:
ConvAddOp
<
P
>.
creat
]
private
init
(){}
private
init
(){}
}
}
metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
浏览文件 @
d1d1e932
...
@@ -18,9 +18,9 @@ import Foundation
...
@@ -18,9 +18,9 @@ import Foundation
protocol
Fusion
{
protocol
Fusion
{
static
func
fusionNode
()
->
Node
static
func
fusionNode
()
->
Node
static
func
change
()
->
[
String
:
[(
from
:
String
,
to
:
String
)]]
static
func
change
()
->
[
String
:
[(
from
:
String
,
to
:
String
)]]
static
func
fusionType
()
->
String
}
}
protocol
Runable
{
protocol
Runable
{
func
run
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
func
run
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
...
@@ -117,20 +117,20 @@ let gBatchNormType = "batch_norm"
...
@@ -117,20 +117,20 @@ let gBatchNormType = "batch_norm"
let
gReluType
=
"relu"
let
gReluType
=
"relu"
let
gElementwiseAdd
=
"elementwise_add"
let
gElementwiseAdd
=
"elementwise_add"
let
gConvAddBatchNormReluType
=
"conv_add_batchnorm_relu"
let
gConvAddBatchNormReluType
=
"conv_add_batchnorm_relu"
let
gPooType
=
"pool2d"
let
gSoftmaxType
=
"softmax"
let
gReshapeType
=
"reshape"
let
gConvAddType
=
"conv_add"
let
opInfos
=
[
gConvType
:
(
inputs
:
[
"Input"
],
outputs
:
[
"Output"
]),
let
opInfos
=
[
gConvType
:
(
inputs
:
[
"Input"
],
outputs
:
[
"Output"
]),
gBatchNormType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Y"
]),
gBatchNormType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Y"
]),
gReluType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gReluType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gElementwiseAdd
:
(
inputs
:
[
"X"
,
"Y"
],
outputs
:
[
"Out"
]),
gElementwiseAdd
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gFeedType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gFeedType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gFetchType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gFetchType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gConvAddBatchNormReluType
:
(
inputs
:
[
"Input"
],
outputs
:
[
"Out"
])]
gConvAddBatchNormReluType
:
(
inputs
:
[
"Input"
],
outputs
:
[
"Out"
]),
gPooType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gSoftmaxType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gReshapeType
:
(
inputs
:
[
"X"
],
outputs
:
[
"Out"
]),
gConvAddType
:
(
inputs
:
[
"Input"
],
outputs
:
[
"Out"
])]
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
浏览文件 @
d1d1e932
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
import
Foundation
import
Foundation
struct
BatchNormParam
<
P
:
PrecisionType
>
:
OpParam
{
class
BatchNormParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
input
=
try
BatchNormParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
input
=
try
BatchNormParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
BatchNormParam
.
outputY
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
output
=
try
BatchNormParam
.
outputY
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
...
...
metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift
浏览文件 @
d1d1e932
...
@@ -8,21 +8,49 @@
...
@@ -8,21 +8,49 @@
import
Foundation
import
Foundation
class
ConvAddBatchNormReluParam
<
P
:
PrecisionType
>
:
OpParam
{
class
ConvAddBatchNormReluOp
<
P
:
PrecisionType
>
:
Operator
<
ConvAddBatchNormReluKernel
<
P
>
,
ConvParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
,
Fusion
{
typealias
ParamPrecisionType
=
P
static
func
fusionNode
()
->
Node
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
let
beginNode
=
Node
.
init
(
inType
:
gConvType
)
do
{
_
=
beginNode
filter
=
try
ConvAddBatchNormReluParam
.
inputFilter
(
paraInputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
-->
Node
.
init
(
inType
:
gElementwiseAdd
)
input
=
try
ConvAddBatchNormReluParam
.
input
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
-->
Node
.
init
(
inType
:
gBatchNormType
)
output
=
try
ConvAddBatchNormReluParam
.
output
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
-->
Node
.
init
(
inType
:
gReluType
)
stride
=
try
ConvAddBatchNormReluParam
.
getAttr
(
key
:
"strides"
,
attrs
:
opDesc
.
attrs
)
return
beginNode
paddings
=
try
ConvAddBatchNormReluParam
.
getAttr
(
key
:
"paddings"
,
attrs
:
opDesc
.
attrs
)
dilations
=
try
ConvAddBatchNormReluParam
.
getAttr
(
key
:
"dilations"
,
attrs
:
opDesc
.
attrs
)
epsilon
=
try
ConvAddBatchNormReluParam
.
getAttr
(
key
:
"epsilon"
,
attrs
:
opDesc
.
attrs
)
groups
=
try
ConvAddBatchNormReluParam
.
getAttr
(
key
:
"groups"
,
attrs
:
opDesc
.
attrs
)
variance
=
try
ConvAddBatchNormReluParam
.
inputVariance
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
bias
=
try
ConvAddBatchNormReluParam
.
inputBiase
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
scale
=
try
ConvAddBatchNormReluParam
.
inputScale
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
mean
=
try
ConvAddBatchNormReluParam
.
inputMean
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
y
=
try
ConvAddBatchNormReluParam
.
inputY
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
}
catch
let
error
{
throw
error
}
}
}
static
func
change
()
->
[
String
:
[(
from
:
String
,
to
:
String
)]]
{
let
input
:
Texture
<
P
>
return
[:]
}
let
variance
:
Tensor
<
ParamPrecisionType
>
let
bias
:
Tensor
<
ParamPrecisionType
>
let
mean
:
Tensor
<
ParamPrecisionType
>
let
scale
:
Tensor
<
ParamPrecisionType
>
let
y
:
Tensor
<
ParamPrecisionType
>
let
filter
:
Tensor
<
ParamPrecisionType
>
let
epsilon
:
Float32
var
newScale
:
MTLBuffer
?
var
newBiase
:
MTLBuffer
?
var
output
:
Texture
<
P
>
let
stride
:
[
Int32
]
let
paddings
:
[
Int32
]
let
dilations
:
[
Int32
]
let
groups
:
Int
}
class
ConvAddBatchNormReluOp
<
P
:
PrecisionType
>
:
Operator
<
ConvAddBatchNormReluKernel
<
P
>
,
ConvAddBatchNormReluParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
,
Fusion
{
typealias
OpType
=
ConvAddBatchNormReluOp
<
P
>
typealias
OpType
=
ConvAddBatchNormReluOp
<
P
>
func
inferShape
()
{
func
inferShape
()
{
...
@@ -55,4 +83,20 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer
...
@@ -55,4 +83,20 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer
}
}
}
}
static
func
fusionNode
()
->
Node
{
let
beginNode
=
Node
.
init
(
inType
:
gConvType
)
_
=
beginNode
-->
Node
.
init
(
inType
:
gElementwiseAdd
)
-->
Node
.
init
(
inType
:
gBatchNormType
)
-->
Node
.
init
(
inType
:
gReluType
)
return
beginNode
}
static
func
change
()
->
[
String
:
[(
from
:
String
,
to
:
String
)]]
{
return
[:]
}
static
func
fusionType
()
->
String
{
return
gConvAddBatchNormReluType
}
}
}
metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift
0 → 100644
浏览文件 @
d1d1e932
//
// ConvAddBatchNormReluOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/8.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
ConvAddParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
filter
=
try
ConvAddParam
.
inputFilter
(
paraInputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
input
=
try
ConvAddParam
.
input
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
ConvAddParam
.
output
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
stride
=
try
ConvAddParam
.
getAttr
(
key
:
"strides"
,
attrs
:
opDesc
.
attrs
)
paddings
=
try
ConvAddParam
.
getAttr
(
key
:
"paddings"
,
attrs
:
opDesc
.
attrs
)
dilations
=
try
ConvAddParam
.
getAttr
(
key
:
"dilations"
,
attrs
:
opDesc
.
attrs
)
groups
=
try
ConvAddParam
.
getAttr
(
key
:
"groups"
,
attrs
:
opDesc
.
attrs
)
y
=
try
ConvAddParam
.
inputY
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
}
catch
let
error
{
throw
error
}
}
let
input
:
Texture
<
P
>
let
y
:
Tensor
<
ParamPrecisionType
>
let
filter
:
Tensor
<
ParamPrecisionType
>
var
output
:
Texture
<
P
>
let
stride
:
[
Int32
]
let
paddings
:
[
Int32
]
let
dilations
:
[
Int32
]
let
groups
:
Int
}
class
ConvAddOp
<
P
:
PrecisionType
>
:
Operator
<
ConvAddKernel
<
P
>
,
ConvAddParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
,
Fusion
{
static
func
fusionNode
()
->
Node
{
let
beginNode
=
Node
.
init
(
inType
:
gConvType
)
_
=
beginNode
-->
Node
.
init
(
inType
:
gElementwiseAdd
)
return
beginNode
}
static
func
change
()
->
[
String
:
[(
from
:
String
,
to
:
String
)]]
{
return
[:]
}
static
func
fusionType
()
->
String
{
return
gConvAddType
}
typealias
OpType
=
ConvAddOp
<
P
>
func
inferShape
()
{
let
inDims
=
para
.
input
.
dim
let
filterDim
=
para
.
filter
.
dim
let
strides
=
para
.
stride
let
paddings
=
para
.
paddings
let
dilations
=
para
.
dilations
var
outDim
=
[
inDims
[
0
]]
for
i
in
0
..<
strides
.
count
{
let
dilation
:
Int
=
Int
(
dilations
[
i
])
let
filterSize
:
Int
=
filterDim
[
i
+
1
]
let
inputSize
:
Int
=
inDims
[
i
+
1
]
let
padding
:
Int
=
Int
(
paddings
[
i
])
let
stride
:
Int
=
Int
(
strides
[
i
])
let
dKernel
=
dilation
*
(
filterSize
-
1
)
+
1
let
outputSize
=
(
inputSize
+
2
*
padding
-
dKernel
)
/
stride
+
1
outDim
.
append
(
outputSize
)
}
outDim
.
append
(
filterDim
[
0
])
para
.
output
.
dim
=
Dim
.
init
(
inDim
:
outDim
)
}
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
do
{
try
kernel
.
compute
(
commandBuffer
:
buffer
,
param
:
para
)
}
catch
let
error
{
throw
error
}
}
}
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
浏览文件 @
d1d1e932
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
import
Foundation
import
Foundation
struct
ConvParam
<
P
:
PrecisionType
>
:
OpParam
{
class
ConvParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
filter
=
try
ConvParam
.
inputFilter
(
paraInputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
filter
=
try
ConvParam
.
inputFilter
(
paraInputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
input
=
try
ConvParam
.
input
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
input
=
try
ConvParam
.
input
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
...
@@ -25,14 +25,15 @@ struct ConvParam<P: PrecisionType>: OpParam {
...
@@ -25,14 +25,15 @@ struct ConvParam<P: PrecisionType>: OpParam {
paddings
=
try
ConvParam
.
getAttr
(
key
:
"paddings"
,
attrs
:
opDesc
.
attrs
)
paddings
=
try
ConvParam
.
getAttr
(
key
:
"paddings"
,
attrs
:
opDesc
.
attrs
)
dilations
=
try
ConvParam
.
getAttr
(
key
:
"dilations"
,
attrs
:
opDesc
.
attrs
)
dilations
=
try
ConvParam
.
getAttr
(
key
:
"dilations"
,
attrs
:
opDesc
.
attrs
)
groups
=
try
ConvParam
.
getAttr
(
key
:
"groups"
,
attrs
:
opDesc
.
attrs
)
groups
=
try
ConvParam
.
getAttr
(
key
:
"groups"
,
attrs
:
opDesc
.
attrs
)
}
catch
let
error
{
}
catch
let
error
{
throw
error
throw
error
}
}
}
}
let
input
:
Texture
<
P
>
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
let
filter
:
Tensor
<
ParamPrecisionType
>
let
filter
:
Tensor
<
ParamPrecisionType
>
var
output
:
Texture
<
P
>
let
stride
:
[
Int32
]
let
stride
:
[
Int32
]
let
paddings
:
[
Int32
]
let
paddings
:
[
Int32
]
let
dilations
:
[
Int32
]
let
dilations
:
[
Int32
]
...
...
metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
浏览文件 @
d1d1e932
...
@@ -14,12 +14,13 @@
...
@@ -14,12 +14,13 @@
import
Foundation
import
Foundation
struct
ElementwiseAddParam
<
P
:
PrecisionType
>
:
OpParam
{
class
ElementwiseAddParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
input
=
try
ElementwiseAddParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
input
=
try
ElementwiseAddParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
inputY
=
try
ElementwiseAddParam
.
inputY
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
inputY
=
try
ElementwiseAddParam
.
inputY
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
output
=
try
ElementwiseAddParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
output
=
try
ElementwiseAddParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
axis
=
try
ElementwiseAddParam
.
getAttr
(
key
:
"axis"
,
attrs
:
opDesc
.
attrs
)
axis
=
try
ElementwiseAddParam
.
getAttr
(
key
:
"axis"
,
attrs
:
opDesc
.
attrs
)
}
catch
let
error
{
}
catch
let
error
{
...
...
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
浏览文件 @
d1d1e932
...
@@ -14,14 +14,14 @@
...
@@ -14,14 +14,14 @@
import
Foundation
import
Foundation
struct
FeedParam
<
P
:
PrecisionType
>
:
OpParam
{
class
FeedParam
<
P
:
PrecisionType
>
:
OpParam
{
var
output
:
Texture
<
P
>
var
output
:
Texture
<
P
>
var
input
:
InputTexture
{
var
input
:
InputTexture
{
return
scope
.
input
()
as!
InputTexture
return
scope
.
input
()
as!
InputTexture
}
}
let
scope
:
Scope
let
scope
:
Scope
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
scope
=
inScope
scope
=
inScope
do
{
do
{
output
=
try
FeedParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
output
=
try
FeedParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
...
...
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
浏览文件 @
d1d1e932
...
@@ -14,11 +14,11 @@
...
@@ -14,11 +14,11 @@
import
Foundation
import
Foundation
struct
FetchParam
<
P
:
PrecisionType
>
:
OpParam
{
class
FetchParam
<
P
:
PrecisionType
>
:
OpParam
{
var
output
:
ResultHolder
<
P
>
=
ResultHolder
.
init
(
inDim
:
[],
inResult
:
[])
var
output
:
ResultHolder
<
P
>
=
ResultHolder
.
init
(
inDim
:
[],
inResult
:
[])
let
input
:
Texture
<
P
>
let
input
:
Texture
<
P
>
let
scope
:
Scope
let
scope
:
Scope
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
scope
=
inScope
scope
=
inScope
do
{
do
{
input
=
try
FetchParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
input
=
try
FetchParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift
浏览文件 @
d1d1e932
...
@@ -9,11 +9,59 @@
...
@@ -9,11 +9,59 @@
import
Foundation
import
Foundation
class
ConvAddBatchNormReluKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
class
ConvAddBatchNormReluKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
required
init
(
device
:
MTLDevice
,
param
:
ConvParam
<
P
>
)
{
var
metalParam
:
MetalConvParam
!
super
.
init
(
device
:
device
,
inFunctionName
:
"conv3x3"
)
required
init
(
device
:
MTLDevice
,
param
:
ConvAddBatchNormReluParam
<
P
>
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"conv_add_batch_norm_relu_3x3"
)
let
offsetX
=
param
.
filter
.
dim
[
2
]
/
2
-
Int
(
param
.
paddings
[
0
])
let
offsetY
=
param
.
filter
.
dim
[
1
]
/
2
-
Int
(
param
.
paddings
[
1
])
let
offsetZ
=
0.0
metalParam
=
MetalConvParam
.
init
(
offsetX
:
Int16
(
offsetX
),
offsetY
:
Int16
(
offsetY
),
offsetZ
:
Int16
(
offsetZ
),
strideX
:
UInt16
(
param
.
stride
[
0
]),
strideY
:
UInt16
(
param
.
stride
[
1
]),
paddedZ
:
UInt16
(
param
.
input
.
metalTexture
.
arrayLength
*
4
-
param
.
input
.
dim
[
3
]))
var
invs
:
[
P
]
=
[]
let
varianceContents
=
param
.
variance
.
buffer
.
contents
()
.
assumingMemoryBound
(
to
:
P
.
self
)
for
i
in
0
..<
param
.
variance
.
buffer
.
length
/
MemoryLayout
<
P
>.
stride
{
let
inv
=
pow
(
Float32
.
init
(
varianceContents
[
i
])
+
param
.
epsilon
,
0.5
)
invs
.
append
(
P
(
inv
))
}
let
newScale
:
UnsafeMutablePointer
<
P
>
=
UnsafeMutablePointer
<
P
>.
allocate
(
capacity
:
param
.
scale
.
buffer
.
length
)
let
newBiase
:
UnsafeMutablePointer
<
P
>
=
UnsafeMutablePointer
<
P
>.
allocate
(
capacity
:
param
.
bias
.
buffer
.
length
)
let
scaleContents
=
param
.
variance
.
buffer
.
contents
()
.
assumingMemoryBound
(
to
:
P
.
self
)
let
biaseContents
=
param
.
bias
.
buffer
.
contents
()
.
assumingMemoryBound
(
to
:
P
.
self
)
let
meanContents
=
param
.
mean
.
buffer
.
contents
()
.
assumingMemoryBound
(
to
:
P
.
self
)
for
i
in
0
..<
param
.
scale
.
buffer
.
length
/
MemoryLayout
<
P
>.
stride
{
newScale
[
i
]
=
invs
[
i
]
*
scaleContents
[
i
]
newBiase
[
i
]
=
biaseContents
[
i
]
-
meanContents
[
i
]
*
invs
[
i
]
*
scaleContents
[
i
]
}
param
.
newBiase
=
device
.
makeBuffer
(
bytes
:
newBiase
,
length
:
param
.
bias
.
buffer
.
length
)
param
.
newScale
=
device
.
makeBuffer
(
bytes
:
newScale
,
length
:
param
.
scale
.
buffer
.
length
)
newScale
.
deinitialize
(
count
:
param
.
scale
.
buffer
.
length
)
newScale
.
deallocate
()
newBiase
.
deinitialize
(
count
:
param
.
bias
.
buffer
.
length
)
newBiase
.
deallocate
()
}
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ConvParam
<
P
>
)
throws
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ConvAddBatchNormReluParam
<
P
>
)
throws
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
" encode is nil"
)
}
print
(
"ConvAddBatchNormReluKernel compute"
)
encoder
.
setTexture
(
param
.
input
.
metalTexture
,
index
:
0
)
encoder
.
setTexture
(
param
.
output
.
metalTexture
,
index
:
1
)
encoder
.
setBytes
(
&
metalParam
,
length
:
MemoryLayout
<
MetalConvParam
>.
size
,
index
:
0
)
encoder
.
setBuffer
(
param
.
filter
.
buffer
,
offset
:
0
,
index
:
1
)
encoder
.
setBuffer
(
param
.
bias
.
buffer
,
offset
:
0
,
index
:
2
)
encoder
.
setBuffer
(
param
.
newScale
!
,
offset
:
0
,
index
:
3
)
encoder
.
setBuffer
(
param
.
newBiase
!
,
offset
:
0
,
index
:
4
)
encoder
.
dispatch
(
computePipline
:
pipline
,
outTexture
:
param
.
output
.
metalTexture
)
encoder
.
endEncoding
()
}
}
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift
0 → 100644
浏览文件 @
d1d1e932
//
// ConvKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
ConvAddKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
required
init
(
device
:
MTLDevice
,
param
:
ConvAddParam
<
P
>
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"conv3x3"
)
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ConvAddParam
<
P
>
)
throws
{
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal
浏览文件 @
d1d1e932
...
@@ -52,3 +52,48 @@ kernel void conv3x3(texture2d_array<half, access::sample> inTexture [[texture(0)
...
@@ -52,3 +52,48 @@ kernel void conv3x3(texture2d_array<half, access::sample> inTexture [[texture(0)
}
}
outTexture.write(output, gid.xy, gid.z);
outTexture.write(output, gid.xy, gid.z);
}
}
kernel void conv_add_batch_norm_relu_3x3(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam ¶m [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
const device half4 *new_scale [[buffer(3)]],
const device half4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint wightSliceCount = 36;
uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
half4 output = 0.0;
for (uint i = 0; i < inTexture.get_array_size(); ++i) {
half4 input[9];
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
half4 weight = weights[weithTo + wightSliceCount * i + j * 4];
output += dot(input[j], weight);
}
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0h);
outTexture.write(output, gid.xy, gid.z);
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift
0 → 100644
浏览文件 @
d1d1e932
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
class
PoolKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
PoolParam
<
P
>
)
throws
{
}
required
init
(
device
:
MTLDevice
,
param
:
PoolParam
<
P
>
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"relu"
)
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift
0 → 100644
浏览文件 @
d1d1e932
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
class
ReshapeKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
required
init
(
device
:
MTLDevice
,
param
:
ReshapeParam
<
P
>
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"relu"
)
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ReshapeParam
<
P
>
)
throws
{
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift
0 → 100644
浏览文件 @
d1d1e932
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
class
SoftmaxKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
SoftmaxParam
<
P
>
)
throws
{
}
required
init
(
device
:
MTLDevice
,
param
:
SoftmaxParam
<
P
>
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"relu"
)
}
}
metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift
0 → 100644
浏览文件 @
d1d1e932
//
// PoolOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/9.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
PoolParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
input
=
try
PoolParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
PoolParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
}
catch
let
error
{
throw
error
}
}
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
}
class
PoolOp
<
P
:
PrecisionType
>
:
Operator
<
PoolKernel
<
P
>
,
PoolParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
}
typealias
OpType
=
PoolOp
<
P
>
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
do
{
try
kernel
.
compute
(
commandBuffer
:
buffer
,
param
:
para
)
}
catch
let
error
{
throw
error
}
}
}
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
浏览文件 @
d1d1e932
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
import
Foundation
import
Foundation
struct
ReluParam
<
P
:
PrecisionType
>
:
OpParam
{
class
ReluParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
input
=
try
ReluParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
input
=
try
ReluParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
ReluParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
output
=
try
ReluParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
...
...
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift
0 → 100644
浏览文件 @
d1d1e932
//
// PoolOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/9.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
ReshapeParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
input
=
try
ReshapeParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
ReshapeParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
}
catch
let
error
{
throw
error
}
}
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
}
class
ReshapeOp
<
P
:
PrecisionType
>
:
Operator
<
ReshapeKernel
<
P
>
,
ReshapeParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
}
typealias
OpType
=
ReshapeOp
<
P
>
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
do
{
try
kernel
.
compute
(
commandBuffer
:
buffer
,
param
:
para
)
}
catch
let
error
{
throw
error
}
}
}
metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift
0 → 100644
浏览文件 @
d1d1e932
//
// PoolOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/9.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
SoftmaxParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
input
=
try
SoftmaxParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
SoftmaxParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
}
catch
let
error
{
throw
error
}
}
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
}
class
SoftmaxOp
<
P
:
PrecisionType
>
:
Operator
<
SoftmaxKernel
<
P
>
,
SoftmaxParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
}
typealias
OpType
=
SoftmaxOp
<
P
>
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
do
{
try
kernel
.
compute
(
commandBuffer
:
buffer
,
param
:
para
)
}
catch
let
error
{
throw
error
}
}
}
metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift
浏览文件 @
d1d1e932
...
@@ -20,7 +20,7 @@ struct OpDesc {
...
@@ -20,7 +20,7 @@ struct OpDesc {
let
outputs
:
[
String
:
[
String
]]
let
outputs
:
[
String
:
[
String
]]
let
unusedOutputs
:
[
String
:
[
String
]]
let
unusedOutputs
:
[
String
:
[
String
]]
var
attrs
:
[
String
:
Attr
]
=
[:]
var
attrs
:
[
String
:
Attr
]
=
[:]
let
type
:
String
var
type
:
String
init
(
protoOpDesc
:
PaddleMobile_Framework_Proto_OpDesc
)
{
init
(
protoOpDesc
:
PaddleMobile_Framework_Proto_OpDesc
)
{
type
=
protoOpDesc
.
type
type
=
protoOpDesc
.
type
let
creator
=
{
(
vars
:
[
PaddleMobile_Framework_Proto_OpDesc
.
Var
],
canAdd
:
(
String
)
->
Bool
)
->
[
String
:
[
String
]]
in
let
creator
=
{
(
vars
:
[
PaddleMobile_Framework_Proto_OpDesc
.
Var
],
canAdd
:
(
String
)
->
Bool
)
->
[
String
:
[
String
]]
in
...
...
metal/paddle-mobile/paddle-mobile/Program/Program.swift
浏览文件 @
d1d1e932
...
@@ -18,8 +18,8 @@ public struct Program {
...
@@ -18,8 +18,8 @@ public struct Program {
let
paramPath
:
String
let
paramPath
:
String
let
programDesc
:
ProgramDesc
let
programDesc
:
ProgramDesc
let
scope
:
Scope
let
scope
:
Scope
init
(
protoProgramDesc
:
PaddleMobile_Framework_Proto_
ProgramDesc
,
inParamPath
:
String
,
inScope
:
Scope
)
{
init
(
inProgramDesc
:
ProgramDesc
,
inParamPath
:
String
,
inScope
:
Scope
)
{
programDesc
=
ProgramDesc
.
init
(
protoProgram
:
protoProgramDesc
)
programDesc
=
inProgramDesc
paramPath
=
inParamPath
paramPath
=
inParamPath
scope
=
inScope
scope
=
inScope
}
}
...
...
metal/paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift
浏览文件 @
d1d1e932
...
@@ -18,7 +18,7 @@ infix operator --> : ChainNode
...
@@ -18,7 +18,7 @@ infix operator --> : ChainNode
class
Node
{
class
Node
{
var
inputs
:
[
Node
]
=
[]
var
inputs
:
[
Node
]
=
[]
var
outputs
:
[
Node
]
=
[]
var
outputs
:
[
Node
]
=
[]
let
type
:
String
var
type
:
String
var
opDesc
:
OpDesc
?
var
opDesc
:
OpDesc
?
init
(
inOpDesc
:
OpDesc
)
{
init
(
inOpDesc
:
OpDesc
)
{
type
=
inOpDesc
.
type
type
=
inOpDesc
.
type
...
@@ -36,11 +36,12 @@ class Node {
...
@@ -36,11 +36,12 @@ class Node {
}
}
func
depth
(
begin
:
UInt
=
1
)
->
UInt
{
func
depth
(
begin
:
UInt
=
1
)
->
UInt
{
var
beginMax
:
UInt
=
0
var
beginMax
:
UInt
=
1
for
output
in
outputs
{
for
output
in
outputs
{
let
subDepth
=
output
.
depth
(
begin
:
begin
+
1
)
let
subDepth
=
output
.
depth
(
begin
:
begin
+
1
)
beginMax
=
max
(
begin
,
subDepth
)
beginMax
=
max
(
begin
,
subDepth
)
}
}
beginMax
=
max
(
begin
,
beginMax
)
return
beginMax
return
beginMax
}
}
...
@@ -50,23 +51,26 @@ class Node {
...
@@ -50,23 +51,26 @@ class Node {
return
beginNode
return
beginNode
}
}
func
folderWith
(
fusion
:
Fusion
.
Type
)
{
func
folderWith
(
fusion
:
Fusion
.
Type
,
removedNodes
:
inout
[
Node
]
)
{
let
fusionNode
=
fusion
.
fusionNode
()
let
fusionNode
=
fusion
.
fusionNode
()
let
change
=
fusion
.
change
()
let
change
=
fusion
.
change
()
let
inOutputs
=
outputs
let
inOutputs
=
outputs
outputs
.
removeAll
()
outputs
.
removeAll
()
for
i
in
0
..<
inOutputs
.
count
{
for
i
in
0
..<
inOutputs
.
count
{
inOutputs
[
i
]
.
folderWith
(
beginNode
:
self
,
matchNode
:
fusionNode
.
outputs
[
i
],
change
:
change
)
inOutputs
[
i
]
.
folderWith
(
beginNode
:
self
,
matchNode
:
fusionNode
.
outputs
[
i
],
change
:
change
,
removedNodes
:
&
removedNodes
)
}
}
opDesc
?
.
type
=
fusion
.
fusionType
()
type
=
fusion
.
fusionType
()
}
}
private
func
folderWith
(
beginNode
:
Node
,
matchNode
:
Node
,
change
:
[
String
:
[(
from
:
String
,
to
:
String
)]])
{
private
func
folderWith
(
beginNode
:
Node
,
matchNode
:
Node
,
change
:
[
String
:
[(
from
:
String
,
to
:
String
)]]
,
removedNodes
:
inout
[
Node
]
)
{
guard
let
inOpdesc
=
opDesc
else
{
guard
let
inOpdesc
=
opDesc
else
{
fatalError
()
fatalError
()
}
}
for
attr
in
inOpdesc
.
attrs
{
for
attr
in
inOpdesc
.
attrs
{
beginNode
.
opDesc
?
.
attrs
[
attr
.
key
]
=
attr
.
value
beginNode
.
opDesc
?
.
attrs
[
attr
.
key
]
=
attr
.
value
// print(beginNode.opDesc?.attrs)
}
}
for
paraInput
in
inOpdesc
.
paraInputs
{
for
paraInput
in
inOpdesc
.
paraInputs
{
...
@@ -86,6 +90,11 @@ class Node {
...
@@ -86,6 +90,11 @@ class Node {
if
matchNode
.
outputs
.
count
==
0
{
if
matchNode
.
outputs
.
count
==
0
{
beginNode
.
outputs
.
append
(
contentsOf
:
outputs
)
beginNode
.
outputs
.
append
(
contentsOf
:
outputs
)
}
}
removedNodes
.
append
(
self
)
for
i
in
0
..<
matchNode
.
outputs
.
count
{
outputs
[
i
]
.
folderWith
(
beginNode
:
beginNode
,
matchNode
:
matchNode
.
outputs
[
i
],
change
:
change
,
removedNodes
:
&
removedNodes
)
}
}
}
...
@@ -122,11 +131,10 @@ extension Node: Equatable {
...
@@ -122,11 +131,10 @@ extension Node: Equatable {
return
true
return
true
}
}
}
}
class
ProgramOptimize
<
P
:
PrecisionType
>
{
class
ProgramOptimize
<
P
:
PrecisionType
>
{
let
fusionOps
:
[
Fusion
.
Type
]
=
[
ConvAddBatchNormReluOp
<
P
>.
self
]
let
fusionOps
:
[
Fusion
.
Type
]
=
[
ConvAddBatchNormReluOp
<
P
>.
self
,
ConvAddOp
<
P
>.
self
]
func
optimize
(
originProgramDesc
:
ProgramDesc
)
->
ProgramDesc
{
func
optimize
(
originProgramDesc
:
ProgramDesc
)
->
ProgramDesc
{
guard
originProgramDesc
.
blocks
.
count
==
1
else
{
guard
originProgramDesc
.
blocks
.
count
==
1
else
{
...
@@ -141,7 +149,7 @@ class ProgramOptimize<P: PrecisionType> {
...
@@ -141,7 +149,7 @@ class ProgramOptimize<P: PrecisionType> {
guard
let
opInputKeys
=
opInfos
[
opDesc
.
type
]?
.
inputs
,
let
outputKeys
=
opInfos
[
opDesc
.
type
]?
.
outputs
else
{
guard
let
opInputKeys
=
opInfos
[
opDesc
.
type
]?
.
inputs
,
let
outputKeys
=
opInfos
[
opDesc
.
type
]?
.
outputs
else
{
fatalError
()
fatalError
()
}
}
let
node
=
Node
.
init
(
inOpDesc
:
opDesc
)
let
node
=
Node
.
init
(
inOpDesc
:
opDesc
)
for
inputKey
in
opInputKeys
{
for
inputKey
in
opInputKeys
{
if
let
inputs
=
opDesc
.
inputs
[
inputKey
]
{
if
let
inputs
=
opDesc
.
inputs
[
inputKey
]
{
...
@@ -164,28 +172,32 @@ class ProgramOptimize<P: PrecisionType> {
...
@@ -164,28 +172,32 @@ class ProgramOptimize<P: PrecisionType> {
nodes
.
append
(
node
)
nodes
.
append
(
node
)
if
var
n
odes
=
typeMapNodes
[
opDesc
.
type
]
{
if
var
inN
odes
=
typeMapNodes
[
opDesc
.
type
]
{
n
odes
.
append
(
node
)
inN
odes
.
append
(
node
)
typeMapNodes
[
opDesc
.
type
]
=
n
odes
typeMapNodes
[
opDesc
.
type
]
=
inN
odes
}
else
{
}
else
{
typeMapNodes
[
opDesc
.
type
]
=
[]
typeMapNodes
[
opDesc
.
type
]
=
[
node
]
}
}
}
}
for
fusion
in
fusionOps
{
for
fusion
in
fusionOps
{
let
fusionNode
=
fusion
.
fusionNode
()
let
fusionNode
=
fusion
.
fusionNode
()
let
depth
=
fusionNode
.
depth
()
let
depth
=
fusionNode
.
depth
()
print
(
depth
)
if
let
toMatchNodes
=
typeMapNodes
[
fusionNode
.
type
]
{
if
let
nodes
=
typeMapNodes
[
fusionNode
.
type
]
{
for
node
in
toMatchNodes
{
for
node
in
nodes
{
let
toNode
=
node
.
to
(
depth
:
depth
)
let
toNode
=
node
.
to
(
depth
:
4
)
if
toNode
==
fusionNode
{
// match
if
toNode
==
fusionNode
{
// match
node
.
folderWith
(
fusion
:
fusion
)
var
removeNodes
:
[
Node
]
=
[]
node
.
folderWith
(
fusion
:
fusion
,
removedNodes
:
&
removeNodes
)
for
removeNode
in
removeNodes
{
nodes
.
remove
(
element
:
removeNode
)
}
}
}
}
}
}
}
}
}
var
ops
:
[
OpDesc
]
=
[]
var
ops
:
[
OpDesc
]
=
[]
for
node
in
nodes
{
for
node
in
nodes
{
ops
.
append
(
node
.
opDesc
!
)
ops
.
append
(
node
.
opDesc
!
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录