PaddlePaddle
Paddle-Lite
Commit 64a171ca
Authored on Jul 05, 2018 by liuruilong
add relu kernel
Parent: b06a4e75
Showing 24 changed files with 347 additions and 93 deletions (+347 -93)
metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist (+1 -1)
metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift (+3 -3)
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj (+16 -0)
metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist (+1 -1)
metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift (+7 -1)
metal/paddle-mobile/paddle-mobile/Common/Types.swift (+11 -3)
metal/paddle-mobile/paddle-mobile/Executor.swift (+26 -4)
metal/paddle-mobile/paddle-mobile/Loader.swift (+21 -8)
metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift (+3 -3)
metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift (+24 -13)
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift (+4 -4)
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift (+4 -4)
metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift (+4 -4)
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift (+14 -7)
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift (+3 -3)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift (+19 -0)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift (+20 -0)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift (+20 -0)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift (+6 -0)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal (+73 -16)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift (+32 -0)
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift (+16 -11)
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift (+9 -5)
metal/paddle-mobile/paddle-mobile/framework/Texture.swift (+10 -2)
metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
@@ -7,7 +7,7 @@
 		<key>paddle-mobile-demo.xcscheme</key>
 		<dict>
 			<key>orderHint</key>
-			<integer>4</integer>
+			<integer>3</integer>
 		</dict>
 	</dict>
 </dict>
metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift
@@ -36,13 +36,13 @@ class ViewController: UIViewController {
             fatalError(" texture is nil !")
         }
-        let loader = Loader<Float>.init()
+        let loader = Loader<Float16>.init()
         do {
             let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
             let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
             let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
-            let executor = try Executor<Float>.init(inProgram: program)
+            let executor = try Executor<Float16>.init(inDevice: device, inQueue: queue!, inProgram: program)
-            let output = try executor.predict(input: inTexture, expect: [1, 224, 224, 3])
+            let output = try executor.predict(input: inTexture, expect: [1, 227, 227, 3])
             print(output)
         } catch let error {
             print(error)
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
@@ -30,6 +30,10 @@
 		FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
 		FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
 		FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
+		FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */; };
+		FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */; };
+		FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */; };
+		FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; };
 		FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
 		FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
 		FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
@@ -69,6 +73,10 @@
 		FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
 		FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
 		FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
+		FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReluKernel.swift; sourceTree = "<group>"; };
+		FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvKernel.swift; sourceTree = "<group>"; };
+		FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BatchNormKernel.swift; sourceTree = "<group>"; };
+		FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; };
 		FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
 		FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
 		FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
@@ -197,9 +205,13 @@
 		FC086BA520E67E8500D85EF7 /* Kernels */ = {
 			isa = PBXGroup;
 			children = (
+				FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */,
 				FCF2D73720E64E70007AC5F5 /* Kernel.swift */,
 				FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
 				FC1B186520ECF1C600678B91 /* ResizeKernel.swift */,
+				FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */,
+				FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */,
+				FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */,
 			);
 			path = Kernels;
 			sourceTree = "<group>";
@@ -316,12 +328,14 @@
 			files = (
 				FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
 				FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
+				FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
 				FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
 				FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
 				FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
 				FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
 				FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
 				FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
+				FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
 				FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
 				FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
 				FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
@@ -335,7 +349,9 @@
 				FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
 				FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
 				FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
+				FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */,
 				FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
+				FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
 				FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
 				FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
 				FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
@@ -7,7 +7,7 @@
 		<key>paddle-mobile.xcscheme</key>
 		<dict>
 			<key>orderHint</key>
-			<integer>3</integer>
+			<integer>4</integer>
 		</dict>
 	</dict>
 </dict>
metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
@@ -29,11 +29,11 @@ extension MTLDevice {
fatalError
(
"Counld't find paddle mobile library"
)
fatalError
(
"Counld't find paddle mobile library"
)
}
}
do
{
do
{
print
(
path
)
paddleMobileMetalLibrary
=
try
makeLibrary
(
filepath
:
path
)
paddleMobileMetalLibrary
=
try
makeLibrary
(
filepath
:
path
)
}
catch
_
{
}
catch
_
{
fatalError
(
"Counld't load paddle mobile library"
)
fatalError
(
"Counld't load paddle mobile library"
)
}
}
paddleMobileMetalLibrary
=
makeDefaultLibrary
()
}
}
if
let
inPaddleMobileLib
=
paddleMobileMetalLibrary
{
if
let
inPaddleMobileLib
=
paddleMobileMetalLibrary
{
@@ -67,11 +67,17 @@ extension MTLComputeCommandEncoder {
         let height = computePipline.maxTotalThreadsPerThreadgroup/width
         let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
+        print(" threads per group: \(threadsPerGroup)")
+        print(" out texture width: \(outTexture.width), out texture height: \(outTexture.height)")
         let groupWidth = (outTexture.width + width - 1)/width
         let groupHeight = (outTexture.height + height - 1)/height
         let groupDepth = slices
         let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
+        print("groups: \(groups)")
         setComputePipelineState(computePipline)
         dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup)
     }
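The dispatch code in this hunk sizes the threadgroup grid by ceiling division, so that a texture whose dimensions are not a multiple of the threadgroup size is still fully covered. The following is a minimal sketch of that arithmetic in plain Swift, without Metal. `GridSize`, `threadgroupLayout`, and the parameter names are hypothetical stand-ins for `MTLSize` and the values read from the pipeline state; where `width` comes from is not shown in the diff, so it is simply a parameter here.

```swift
// Stand-in for MTLSize so the sketch runs without Metal (hypothetical type).
struct GridSize: Equatable {
    let width: Int
    let height: Int
    let depth: Int
}

// Mirrors the arithmetic in the hunk: the threadgroup height is derived from
// the per-group thread limit, and group counts are rounded up.
func threadgroupLayout(textureWidth: Int, textureHeight: Int, slices: Int,
                       threadWidth: Int, maxThreadsPerGroup: Int)
    -> (threadsPerGroup: GridSize, groups: GridSize) {
    let threadHeight = maxThreadsPerGroup / threadWidth
    // (n + d - 1) / d is integer ceiling division for positive n and d.
    let groupWidth = (textureWidth + threadWidth - 1) / threadWidth
    let groupHeight = (textureHeight + threadHeight - 1) / threadHeight
    return (GridSize(width: threadWidth, height: threadHeight, depth: 1),
            GridSize(width: groupWidth, height: groupHeight, depth: slices))
}

// Example values are assumptions: a 227 x 227 texture, 32-wide groups, 512 threads per group.
let layout = threadgroupLayout(textureWidth: 227, textureHeight: 227, slices: 1,
                               threadWidth: 32, maxThreadsPerGroup: 512)
print(layout.threadsPerGroup, layout.groups)
```

Because the grid is rounded up, some threads fall outside the texture, so a kernel dispatched this way would normally need a bounds check on its thread position.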
metal/paddle-mobile/paddle-mobile/Common/Types.swift
@@ -14,14 +14,22 @@
 import Foundation
-//typealias Float16 = Int16
+public typealias Float16 = Int16
-//extension Float16: PrecisionType {
+extension Float16: PrecisionType {
-//}
+    public init(inFloat: Float32) {
+        self = Int16(inFloat)
+    }
+}
 public protocol PrecisionType {
+    init(inFloat: Float32)
 }
 extension Float32: PrecisionType {
+    public init(inFloat: Float32) {
+        self = inFloat
+    }
 }
 public enum DataLayout {
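In this commit `Float16` is an alias for `Int16`, and `init(inFloat:)` calls `Int16(inFloat)`. That is a numeric conversion: 0.5 becomes 0, and a value outside the `Int16` range traps at run time. It is not a re-encoding into half-precision bits. For comparison only, the sketch below shows one way an IEEE 754 binary16 bit pattern could be produced from a `Float32`. The function name `halfBits(from:)` is hypothetical and not part of the commit. The sketch truncates the mantissa instead of rounding and flushes subnormal results to zero, so it is a simplification rather than a complete conversion.

```swift
// Hypothetical helper: encode a Float32 as IEEE 754 half-precision bits.
// Simplifications: mantissa is truncated (no rounding) and values too small
// for a normal half are flushed to signed zero.
func halfBits(from value: Float32) -> UInt16 {
    let bits = value.bitPattern
    let sign = UInt16((bits >> 16) & 0x8000)
    let rawExponent = Int((bits >> 23) & 0xFF)
    let mantissa = bits & 0x007F_FFFF
    // Re-bias the exponent from 127 (single) to 15 (half).
    let exponent = rawExponent - 127 + 15

    if exponent >= 31 {
        // Overflow, infinity, or NaN: keep NaN as NaN, everything else becomes infinity.
        let isNaN = rawExponent == 0xFF && mantissa != 0
        return sign | 0x7C00 | (isNaN ? 0x0200 : 0)
    }
    if exponent <= 0 {
        return sign
    }
    return sign | UInt16(exponent << 10) | UInt16(mantissa >> 13)
}

print(String(halfBits(from: 1.0), radix: 16))
```

On platforms where Swift's standard library provides a native half-precision type, that type would be the more natural choice; the manual version is shown only because it runs anywhere.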
metal/paddle-mobile/paddle-mobile/Executor.swift
@@ -48,13 +48,16 @@ extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible {
 public class Executor<P: PrecisionType> {
     var ops: [Runable & InferShaperable] = []
     let program: Program
+    let device: MTLDevice
-    public init(inProgram: Program) throws {
+    let queue: MTLCommandQueue
+    public init(inDevice: MTLDevice, inQueue: MTLCommandQueue, inProgram: Program) throws {
         program = inProgram
+        device = inDevice
+        queue = inQueue
         for block in inProgram.programDesc.blocks {
             for op in block.ops {
                 do {
-                    let op = try OpCreator<P>.shared.creat(opDesc: op, scope: inProgram.scope)
+                    let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
                     op.inferShape()
                     ops.append(op)
                 } catch let error {
@@ -65,12 +68,29 @@ public class Executor<P: PrecisionType> {
     }
     public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder<P> {
+        let beforeDate = Date.init()
         let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect))
         program.scope.setInput(input: inputTexture)
+        guard let buffer = queue.makeCommandBuffer() else {
+            throw PaddleMobileError.predictError(message: "CommandBuffer is nil")
+        }
         for op in ops {
-            op.run()
+            do {
+                try op.run(device: device, buffer: buffer)
+            } catch let error {
+                throw error
+            }
         }
+        buffer.addCompletedHandler { (commandbuffer) in
+            let afterDate = Date.init()
+            print(afterDate.timeIntervalSince(beforeDate))
+            print(" encoder end ! ")
+        }
+        buffer.commit()
         guard let outputVar = program.scope.output() else {
             throw PaddleMobileError.netError(message: "output nil")
         }
@@ -78,6 +98,8 @@ public class Executor<P: PrecisionType> {
         guard let output = outputVar as? ResultHolder<P> else {
             throw PaddleMobileError.netError(message: "output var type error")
         }
         return output
     }
metal/paddle-mobile/paddle-mobile/Loader.swift
@@ -68,11 +68,24 @@ public class Loader<P: PrecisionType> {
             /*
              The precision is not decided from the Data Type here; it is specified directly by the external generic parameter
              */
-            let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
+            // The model passed in currently uses Float; this part should depend on the model
-            guard bytesRead == tensor.data.size else {
+            let tmpCapacity = MemoryLayout<Float>.size * tensor.numel()
-                throw PaddleMobileError.loaderError(message: "param read size error")
+            let tmpPointer = UnsafeMutablePointer<Float>.allocate(capacity: tmpCapacity);
+            // let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
+            // guard bytesRead == tensor.data.size else {
+            //     throw PaddleMobileError.loaderError(message: "param read size error")
+            // }
+            // TODO: use script to convert
+            let bytesRead = fread(tmpPointer, 1, tmpCapacity, file)
+            for i in 0..<tensor.numel() {
+                tensor.data[i] = P.init(inFloat: tmpPointer[i])
             }
+            tmpPointer.deinitialize(count: tmpCapacity)
+            tmpPointer.deallocate()
             nowIndex += bytesRead
         }
@@ -125,9 +138,9 @@ public class Loader<P: PrecisionType> {
                 throw PaddleMobileError.loaderError(message: "get tensor desc failed")
             }
-            guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
+            // guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
-                throw PaddleMobileError.memoryError(message: "PrecisionType not support")
+            //     throw PaddleMobileError.memoryError(message: "PrecisionType not support")
-            }
+            // }
             if (varDesc.persistable
                 && varDesc.type != .FeedMiniBatch
@@ -149,7 +162,7 @@ public class Loader<P: PrecisionType> {
                     scope[varDesc.name] = tensor
                 } else {
                     let dim = Dim.init(inDim: tensorDesc.NHWCDim)
-                    scope[varDesc.name] = Texture.init(device: device, inDim: dim)
+                    scope[varDesc.name] = Texture<P>.init(device: device, inDim: dim)
                 }
             } else {
                 if varDesc.name == fetchKey {
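The loader change reads the parameters as `Float` into a temporary buffer and then converts each element to the generic precision `P` through `init(inFloat:)`. The sketch below reproduces that read-then-convert pattern on an in-memory byte array instead of a file, so it runs on its own. `convertParams` is a hypothetical name, the protocol is redeclared locally, and `Int16` is used directly where the commit uses its `Float16` alias. Note that `UnsafeMutablePointer.allocate(capacity:)` counts elements, not bytes, so the sketch passes the element count.

```swift
// Local redeclaration of the protocol from Types.swift so the sketch is self-contained.
protocol PrecisionType {
    init(inFloat: Float32)
}

extension Float32: PrecisionType {
    init(inFloat: Float32) { self = inFloat }
}

// The commit aliases Float16 to Int16; Int16 is used directly here.
extension Int16: PrecisionType {
    init(inFloat: Float32) { self = Int16(inFloat) }
}

// Hypothetical helper: interpret raw bytes as Float32 values, then convert to P.
func convertParams<P: PrecisionType>(_ raw: [UInt8], to _: P.Type) -> [P] {
    let count = raw.count / MemoryLayout<Float32>.size
    guard count > 0 else { return [] }
    // Capacity is the number of Float32 elements, not the number of bytes.
    let tmp = UnsafeMutablePointer<Float32>.allocate(capacity: count)
    defer { tmp.deallocate() }
    raw.withUnsafeBytes { src in
        UnsafeMutableRawPointer(tmp).copyMemory(from: src.baseAddress!,
                                                byteCount: count * MemoryLayout<Float32>.size)
    }
    return (0..<count).map { P(inFloat: tmp[$0]) }
}

let sample: [Float32] = [1.0, 2.5, -3.0]
let rawBytes = sample.withUnsafeBytes { Array($0) }
print(convertParams(rawBytes, to: Int16.self))
```

The diff's `// TODO: use script to convert` comment suggests the author intended to move this conversion offline; the in-process version above is only an illustration of the current approach.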
metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> {
         }
     }
-    func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
+    func creat(device: MTLDevice, opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
         guard let opCreator = opCreators[opDesc.type] else {
             throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
         }
         do {
-            return try opCreator(opDesc, scope)
+            return try opCreator(device, opDesc, scope)
         } catch let error {
             throw error
         }
     }
-    let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] =
+    let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] =
         [gConvType : ConvOp<P>.creat,
          gBatchNormType : BatchNormOp<P>.creat,
          gReluType : ReluOp<P>.creat,
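`OpCreator` is a lookup table from an operator type string to a factory function, and this commit threads the device through every factory. A minimal sketch of that registry pattern follows. `Device`, `Runnable`, `ReluStub`, `ConvStub`, `OpError`, and `createOp` are hypothetical stand-ins, not types from the repository, and the string keys are assumed values for what the source calls `gReluType` and `gConvType`.

```swift
// Hypothetical stand-ins so the sketch runs without Metal or the framework.
struct Device {}

protocol Runnable {
    var name: String { get }
}

struct ReluStub: Runnable { let name = "relu" }
struct ConvStub: Runnable { let name = "conv2d" }

enum OpError: Error {
    case unknownOp(String)
}

// Registry: operator type string -> factory that receives the device.
let creators: [String: (Device) throws -> Runnable] = [
    "relu": { _ in ReluStub() },
    "conv2d": { _ in ConvStub() },
]

// Look up the factory and fail with a descriptive error for unregistered types.
func createOp(device: Device, type: String) throws -> Runnable {
    guard let make = creators[type] else {
        throw OpError.unknownOp("there is no " + type + " yet")
    }
    return try make(device)
}

let relu = try? createOp(device: Device(), type: "relu")
let missing = try? createOp(device: Device(), type: "softmax")
print(relu?.name ?? "nil", missing == nil)
```

Since a throwing call inside a throwing function already propagates its error, the sketch calls `try make(device)` directly instead of wrapping it in a `do`/`catch` that only rethrows.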
metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
@@ -12,29 +12,35 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
+import Metal
 import Foundation
 protocol Runable {
-    func run()
+    func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
-    func runImpl()
+    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws
 }
 extension Runable where Self: OperatorProtocol {
-    func run() {
+    func run(device: MTLDevice, buffer: MTLCommandBuffer) throws {
-        runImpl()
+        do {
+            try runImpl(device: device, buffer: buffer)
+        } catch let error {
+            throw error
+        }
         print(type + ": " + para.outputDesc())
     }
 }
 protocol Creator where Self: OperatorProtocol {
     associatedtype OpType: OperatorProtocol & Runable & InferShaperable
-    static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
+    static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType
 }
 extension Creator where Self: OperatorProtocol {
-    static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType {
+    static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType {
         do {
-            return try OpType.provide(opDesc: opDesc, inScope: inScope)
+            return try OpType.provide(device: device, opDesc: opDesc, inScope: inScope)
         } catch let error {
             throw error
         }
@@ -47,19 +53,21 @@ protocol InferShaperable {
 protocol OperatorProtocol {
     associatedtype ParamType: OpParam
+    associatedtype KerType: Computable
     var type: String { get }
     var inputs: [String : [String]] { get }
     var paraInputs: [String : [String]] { get }
     var outpus: [String : [String]] { get }
     var attrs: [String : Attr] { get }
     var para: ParamType { get }
-    init(opDesc: OpDesc, inScope: Scope) throws
+    var kernel: KerType { get }
+    init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws
 }
 extension OperatorProtocol {
-    static func provide(opDesc: OpDesc, inScope: Scope) throws -> Self {
+    static func provide(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> Self {
         do {
-            return try Self.init(opDesc: opDesc, inScope: inScope)
+            return try Self.init(device: device, opDesc: opDesc, inScope: inScope)
         } catch let error {
             throw error
         }
@@ -67,20 +75,23 @@ extension OperatorProtocol {
 }
-class Operator <ParameterType: OpParam>: OperatorProtocol {
+class Operator <ParameterType: OpParam, KernelType: Computable>: OperatorProtocol {
     typealias ParamType = ParameterType
+    typealias KerType = KernelType
     let type: String
     let inputs: [String : [String]]
     let paraInputs: [String : [String]]
     let outpus: [String : [String]]
     let attrs: [String : Attr]
     let para: ParamType
-    required init(opDesc: OpDesc, inScope: Scope) throws {
+    var kernel: KerType
+    required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
         type = opDesc.type
         inputs = opDesc.inputs
         outpus = opDesc.outputs
         attrs = opDesc.attrs
         paraInputs = opDesc.paraInputs
+        kernel = KerType.init(device: device)
         do {
             para = try ParamType.init(opDesc: opDesc, inScope: inScope)
         } catch let error {
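The main structural change in `Operator.swift` is that `Operator` becomes generic over a kernel type as well as a parameter type, and builds its kernel from the device in the initializer. The sketch below shows that shape in isolation. It is an illustration under assumptions: `StubDevice` replaces `MTLDevice`, the `Computable` and `OpParam` protocols here are simplified local stand-ins (the real `Computable` is defined in a file not shown in this section), and `ReluParamStub`, `ReluKernelStub`, and `ReluOpStub` are hypothetical names.

```swift
// Stand-in for MTLDevice (hypothetical).
struct StubDevice {
    let name: String
}

// Simplified stand-ins for the framework protocols.
protocol Computable {
    init(device: StubDevice)
}

protocol OpParam {
    init(opName: String)
}

// An operator owns its parameters and the kernel that will compute it.
class Operator<ParameterType: OpParam, KernelType: Computable> {
    let para: ParameterType
    let kernel: KernelType

    required init(device: StubDevice, opName: String) {
        // The kernel is created from the device, as in the diff's
        // `kernel = KerType.init(device: device)`.
        kernel = KernelType(device: device)
        para = ParameterType(opName: opName)
    }
}

struct ReluParamStub: OpParam {
    let opName: String
    init(opName: String) { self.opName = opName }
}

struct ReluKernelStub: Computable {
    let deviceName: String
    init(device: StubDevice) { deviceName = device.name }
}

// A concrete operator only has to pick its parameter and kernel types.
final class ReluOpStub: Operator<ReluParamStub, ReluKernelStub> {}

let reluOp = ReluOpStub(device: StubDevice(name: "gpu0"), opName: "relu")
print(reluOp.para.opName, reluOp.kernel.deviceName)
```

Binding the kernel type at the class declaration, as `ConvOp`, `BatchNormOp`, `ElementwiseAddOp`, and `FeedOp` do in the hunks that follow, means each operator gets its kernel constructed without any per-operator initializer code.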
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
             throw error
         }
     }
-    let input: Texture
+    let input: Texture<P>
-    var output: Texture
+    var output: Texture<P>
     let inputBias: Tensor<ParamPrecisionType>
     let inputMean: Tensor<ParamPrecisionType>
     let inputScale: Tensor<ParamPrecisionType>
@@ -42,12 +42,12 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
     let is_test: Bool
 }
-class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator, InferShaperable {
+class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>, BatchNormKernel<P>>, Runable, Creator, InferShaperable {
     func inferShape() {
         para.output.dim = para.input.dim
     }
     typealias OpType = BatchNormOp<P>
-    func runImpl() {
+    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
         print("this is BatchNormOp")
     }
 }
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
@@ -30,8 +30,8 @@ struct ConvParam<P: PrecisionType>: OpParam {
         }
     }
-    let input: Texture
+    let input: Texture<P>
-    var output: Texture
+    var output: Texture<P>
     let filter: Tensor<ParamPrecisionType>
     let stride: [Int32]
     let paddings: [Int32]
@@ -39,7 +39,7 @@ struct ConvParam<P: PrecisionType>: OpParam {
     let groups: Int
 }
-class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferShaperable {
+class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable, Creator, InferShaperable {
     func inferShape() {
         let inDims = para.input.dim
         let filterDim = para.filter.dim
@@ -63,7 +63,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferS
     }
     typealias OpType = ConvOp<P>
-    func runImpl() {
+    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
         print("this is conv")
     }
 }
metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
@@ -26,20 +26,20 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
             throw error
         }
     }
-    let input: Texture
+    let input: Texture<P>
     let inputY: Tensor<P>
-    var output: Texture
+    var output: Texture<P>
     let axis: Int
 }
-class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator, InferShaperable {
+class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>, ElementwiseAddKernel<P>>, Runable, Creator, InferShaperable {
     func inferShape() {
         para.output.dim = para.input.dim
     }
     typealias OpType = ElementwiseAddOp<P>
-    func runImpl() {
+    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
         print("this is ElementwiseAddOp")
     }
 }
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
View file @
64a171ca
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
Foundation
import
Foundation
struct
FeedParam
<
P
:
PrecisionType
>
:
OpParam
{
struct
FeedParam
<
P
:
PrecisionType
>
:
OpParam
{
var
output
:
Texture
var
output
:
Texture
<
P
>
var
input
:
InputTexture
{
var
input
:
InputTexture
{
return
scope
.
input
()
as!
InputTexture
return
scope
.
input
()
as!
InputTexture
}
}
...
@@ -33,19 +33,26 @@ struct FeedParam<P: PrecisionType>: OpParam{
...
@@ -33,19 +33,26 @@ struct FeedParam<P: PrecisionType>: OpParam{
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
}
}
class
FeedOp
<
P
:
PrecisionType
>
:
Operator
<
FeedParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
FeedOp
<
P
:
PrecisionType
>
:
Operator
<
FeedParam
<
P
>
,
ResizeKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
typealias
OpType
=
FeedOp
<
P
>
typealias
OpType
=
FeedOp
<
P
>
func
inferShape
()
{
func
inferShape
()
{
// print("feed input: \(para.input.expectDim)")
// print("feed input: \(para.input.expectDim)")
print
(
"feed output:
\(
para
.
output
.
dim
)
"
)
print
(
"feed output:
\(
para
.
output
.
dim
)
"
)
// para.output.dim = para.input.expectDim
// para.output.dim = para.input.expectDim
}
}
func
runImpl
()
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"feed op"
)
let
resizeKernel
=
ResizeKernel
<
P
>.
init
(
device
:
device
)
// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>)
let
resizeParam
=
ResizeParam
.
init
(
input
:
para
.
input
.
mtlTexture
,
output
:
para
.
output
.
metalTexture
,
expectDim
:
para
.
input
.
expectDim
)
do
{
print
(
"feed op to compute "
)
try
resizeKernel
.
compute
(
commandBuffer
:
buffer
,
param
:
resizeParam
)
print
(
"feed op end compute "
)
}
catch
let
error
{
throw
error
}
}
}
}
}
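The new `FeedOp.runImpl` wraps its `try` in `do { ... } catch let error { throw error }`, which rethrows the error unchanged. Inside a function already marked `throws`, a bare `try` should propagate the same error. The following is a minimal sketch in plain Swift, not the project's code: `SketchError`, `compute`, and `capturedError` are hypothetical stand-ins for `PaddleMobileError` and `ResizeKernel.compute`.

```swift
// Hypothetical stand-ins; these are not the types from the commit.
enum SketchError: Error, Equatable {
    case predictError(message: String)
}

func compute(shouldFail: Bool) throws {
    if shouldFail {
        throw SketchError.predictError(message: "encode is nil")
    }
}

// Mirrors the shape used in the diff: catch, then rethrow unchanged.
func runWithDoCatch(shouldFail: Bool) throws {
    do {
        try compute(shouldFail: shouldFail)
    } catch let error {
        throw error
    }
}

// Same observable behaviour: a bare `try` hands the error to the caller.
func runWithBareTry(shouldFail: Bool) throws {
    try compute(shouldFail: shouldFail)
}

// Helper that records whichever SketchError a throwing closure produces.
func capturedError(_ body: () throws -> Void) -> SketchError? {
    do {
        try body()
        return nil
    } catch {
        return error as? SketchError
    }
}
```

The `do`/`catch` form is only needed when the error is transformed or logged before being rethrown.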
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
View file @
64a171ca
...
@@ -16,7 +16,7 @@ import Foundation
...
@@ -16,7 +16,7 @@ import Foundation
struct
FetchParam
<
P
:
PrecisionType
>
:
OpParam
{
struct
FetchParam
<
P
:
PrecisionType
>
:
OpParam
{
var
output
:
ResultHolder
<
P
>
=
ResultHolder
.
init
(
inDim
:
[],
inResult
:
[])
var
output
:
ResultHolder
<
P
>
=
ResultHolder
.
init
(
inDim
:
[],
inResult
:
[])
let
input
:
Texture
let
input
:
Texture
<
P
>
let
scope
:
Scope
let
scope
:
Scope
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
scope
=
inScope
scope
=
inScope
...
@@ -30,14 +30,14 @@ struct FetchParam<P: PrecisionType>: OpParam{
...
@@ -30,14 +30,14 @@ struct FetchParam<P: PrecisionType>: OpParam{
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
}
}
class
FetchOp
<
P
:
PrecisionType
>
:
Operator
<
FetchParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
FetchOp
<
P
:
PrecisionType
>
:
Operator
<
FetchParam
<
P
>
,
ResizeKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
func
inferShape
()
{
print
(
para
.
input
.
dim
)
print
(
para
.
input
.
dim
)
}
}
typealias
OpType
=
FetchOp
<
P
>
typealias
OpType
=
FetchOp
<
P
>
func
runImpl
(
)
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"fetch op"
)
print
(
"fetch op"
)
}
}
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
0 → 100644
View file @
64a171ca
//
// BatchNormKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018 orange. All rights reserved.
//
import Foundation

class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
    required init(device: MTLDevice) {
        super.init(device: device, inFunctionName: "batchnorm")
    }
    func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
    }
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
0 → 100644
View file @
64a171ca
//
// ConvKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018 orange. All rights reserved.
//
import Foundation

class ConvKernel<P: PrecisionType>: Kernel, Computable {
    func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
    }
    required init(device: MTLDevice) {
        super.init(device: device, inFunctionName: "conv")
    }
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift
0 → 100644
View file @
64a171ca
//
// ElementwiseAddKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018 orange. All rights reserved.
//
import Foundation

class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
    required init(device: MTLDevice) {
        super.init(device: device, inFunctionName: "conv")
    }
    func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
    }
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift
View file @
64a171ca
...
@@ -18,6 +18,12 @@ import Foundation
...
@@ -18,6 +18,12 @@ import Foundation
protocol
Computable
{
protocol
Computable
{
associatedtype
ParamType
associatedtype
ParamType
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ParamType
)
throws
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ParamType
)
throws
init
(
device
:
MTLDevice
)
}
protocol
KernelProtocol
{
var
pipline
:
MTLComputePipelineState
{
get
set
}
var
functionName
:
String
{
get
set
}
}
}
class
Kernel
{
class
Kernel
{
...
...
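Adding `init(device:)` to `Computable` is what lets a generic operator construct its own kernel from nothing but the kernel type. The sketch below illustrates that idea in plain Swift under simplifying assumptions: `FakeDevice` is a hypothetical stand-in for `MTLDevice`, `compute` returns a `String` instead of encoding GPU work, and the operator is generic over the kernel alone rather than over both parameter and kernel as in the commit. All names ending in `Sketch` are invented for illustration.

```swift
// Hypothetical stand-in for MTLDevice so the sketch runs without Metal.
struct FakeDevice {
    let name: String
}

protocol ComputableSketch {
    associatedtype ParamType
    // The initializer requirement lets generic code build a kernel.
    init(device: FakeDevice)
    func compute(param: ParamType) throws -> String
}

struct ReluParamSketch {
    let values: [Float]
}

final class ReluKernelSketch: ComputableSketch {
    let functionName: String

    init(device: FakeDevice) {
        functionName = "relu"
    }

    // ParamType is inferred as ReluParamSketch from this signature.
    func compute(param: ReluParamSketch) throws -> String {
        return "\(functionName) over \(param.values.count) values"
    }
}

// The operator only knows the kernel type; the parameter type follows
// from the kernel's associated type.
final class OperatorSketch<K: ComputableSketch> {
    let para: K.ParamType
    let kernel: K

    init(device: FakeDevice, para: K.ParamType) {
        self.para = para
        self.kernel = K(device: device)
    }

    func run() throws -> String {
        return try kernel.compute(param: para)
    }
}
```

Because `ReluKernelSketch` is `final`, its initializer does not need the `required` keyword; the non-final kernel classes in the commit do need it.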
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
View file @
64a171ca
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Kernels.metal
// paddle-mobile
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/7/4.
You may obtain a copy of the License at
// Copyright © 2018 orange. All rights reserved.
//
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include <metal_stdlib>
using namespace metal;
using namespace metal;
...
@@ -16,19 +22,70 @@ struct OutputDim {
...
@@ -16,19 +22,70 @@ struct OutputDim {
ushort strideY;
ushort strideY;
};
};
kernel void resize(
kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
texture2d<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]],
constant OutputDim &params [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
gid.y >= outTexture.get_height() ||
return;
gid.z >= outTexture.get_array_size()) return;
}
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos);
const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z,
0.0h), gid
);
outTexture.write(half4(input.x, input.y, input.z,
input.w), gid.xy, gid.z
);
}
}
kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float4 relu = fmax((float4)input, 0.0);
outTexture.write(half4(relu), gid.xy, gid.z);
}
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
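Of the kernels added to `Kernels.metal` here, only `relu` changes its input: it widens each texel to `float4`, takes `fmax` against zero, and writes the result back as `half4`. The `elementwise_add`, `conv`, and `batchnorm` kernels in this commit copy the input texel to the output unchanged. A CPU reference for the ReLU step can be useful when checking GPU output; the sketch below is such a reference in plain Swift, with `reluReference` a hypothetical name and a flat `[Float]` standing in for the texture array.

```swift
// CPU reference for the per-element operation in the `relu` kernel: max(x, 0).
// A flat array stands in for the texels of a texture2d_array.
func reluReference(_ input: [Float]) -> [Float] {
    return input.map { max($0, 0) }
}
```

Comparing against GPU results would need a tolerance, since the kernel stores `half` values while this reference keeps `Float` precision.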
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift
0 → 100644
View file @
64a171ca
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation

class ReluKernel<P: PrecisionType>: Kernel, Computable {
    func compute(commandBuffer: MTLCommandBuffer, param: ReluParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        print(" the usage of input of relu \(param.input.metalTexture.usage)")
        encoder.setTexture(param.input.metalTexture, index: 0)
        encoder.setTexture(param.output.metalTexture, index: 1)
        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }

    required init(device: MTLDevice) {
        super.init(device: device, inFunctionName: "relu")
    }
}
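`ReluKernel.compute` hands the output texture to `encoder.dispatch(computePipline:outTexture:)`, a project helper whose body is not part of this excerpt. A common way to size such a dispatch is to round the texture extent up to a whole number of threadgroups, which is why each kernel starts with a bounds check on `gid`. The sketch below shows that arithmetic in plain Swift; it is an assumption about the usual technique, not the helper's actual implementation, and `GridSize`, `threadgroupCount`, and `surplusThreads` are hypothetical names.

```swift
// Hypothetical value type; Metal's own equivalent is MTLSize.
struct GridSize: Equatable {
    let width: Int
    let height: Int
    let depth: Int
}

// Integer division rounded up.
func ceilDiv(_ a: Int, _ b: Int) -> Int {
    return (a + b - 1) / b
}

// Round up so every texel is covered even when the texture size is not
// a multiple of the threadgroup size.
func threadgroupCount(for texture: GridSize, threadsPerGroup: GridSize) -> GridSize {
    return GridSize(width: ceilDiv(texture.width, threadsPerGroup.width),
                    height: ceilDiv(texture.height, threadsPerGroup.height),
                    depth: ceilDiv(texture.depth, threadsPerGroup.depth))
}

// Threads launched past the texture edge; a kernel's early `return` skips them.
func surplusThreads(for texture: GridSize, threadsPerGroup: GridSize) -> Int {
    let groups = threadgroupCount(for: texture, threadsPerGroup: threadsPerGroup)
    let launchedX = groups.width * threadsPerGroup.width
    let launchedY = groups.height * threadsPerGroup.height
    let launchedZ = groups.depth * threadsPerGroup.depth
    let texels = texture.width * texture.height * texture.depth
    return launchedX * launchedY * launchedZ - texels
}
```

With a 225 x 224 texture and 32 x 16 threadgroups, the grid is 8 x 14 groups and some threads fall outside the texture, so the `gid.x >= outTexture.get_width()` style guard is what keeps them from writing out of bounds.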
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift
View file @
64a171ca
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// ResizeKernel.swift
// paddle-mobile
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/7/4.
You may obtain a copy of the License at
// Copyright © 2018 orange. All rights reserved.
//
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
import
Foundation
...
@@ -22,15 +28,14 @@ struct OutputDim {
...
@@ -22,15 +28,14 @@ struct OutputDim {
let
strideY
:
UInt16
let
strideY
:
UInt16
}
}
class
ResizeKernel
:
Kernel
,
Computable
{
class
ResizeKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ResizeParam
)
throws
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ResizeParam
)
throws
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
" encode is nil"
)
throw
PaddleMobileError
.
predictError
(
message
:
" encode is nil"
)
}
}
encoder
.
setTexture
(
param
.
input
,
index
:
0
)
encoder
.
setTexture
(
param
.
input
,
index
:
0
)
encoder
.
setTexture
(
param
.
output
,
index
:
1
)
encoder
.
setTexture
(
param
.
output
,
index
:
1
)
let
strideX
=
param
.
input
.
width
/
param
.
expectDim
[
2
]
let
strideX
=
param
.
input
.
width
/
param
.
expectDim
[
2
]
let
strideY
=
param
.
input
.
height
/
param
.
expectDim
[
1
]
let
strideY
=
param
.
input
.
height
/
param
.
expectDim
[
1
]
var
outputDim
=
OutputDim
.
init
(
width
:
UInt16
(
param
.
expectDim
[
1
]),
height
:
UInt16
(
param
.
expectDim
[
2
]),
strideX
:
UInt16
(
strideX
),
strideY
:
UInt16
(
strideY
))
var
outputDim
=
OutputDim
.
init
(
width
:
UInt16
(
param
.
expectDim
[
1
]),
height
:
UInt16
(
param
.
expectDim
[
2
]),
strideX
:
UInt16
(
strideX
),
strideY
:
UInt16
(
strideY
))
...
@@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{
...
@@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{
encoder
.
endEncoding
()
encoder
.
endEncoding
()
}
}
init
(
device
:
MTLDevice
)
{
required
init
(
device
:
MTLDevice
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"resize"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"resize"
)
}
}
}
}
...
...
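`ResizeKernel` derives integer strides by dividing the input size by the expected output size, and the `resize` shader then reads the input at `gid.xy * stride`. In effect this is a nearest-neighbour downsample that keeps every `stride`-th texel. The sketch below reproduces that indexing on the CPU in plain Swift; `stridedResize` is a hypothetical name, a 2D `[[Float]]` stands in for the texture, and the output size is passed explicitly rather than read from `expectDim`.

```swift
// CPU sketch of the strided read performed by the `resize` kernel.
// Assumes a non-empty, rectangular input and positive output sizes.
func stridedResize(_ input: [[Float]], outWidth: Int, outHeight: Int) -> [[Float]] {
    precondition(!input.isEmpty && outWidth > 0 && outHeight > 0)
    // Integer division, as in ResizeKernel: any remainder is dropped.
    let strideX = input[0].count / outWidth
    let strideY = input.count / outHeight
    var output = [[Float]]()
    for y in 0..<outHeight {
        var row = [Float]()
        for x in 0..<outWidth {
            row.append(input[y * strideY][x * strideX])
        }
        output.append(row)
    }
    return output
}
```

Because the stride is truncated, an input that is not an exact multiple of the output size leaves its trailing rows and columns unsampled.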
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
View file @
64a171ca
...
@@ -24,19 +24,23 @@ struct ReluParam<P: PrecisionType>: OpParam {
...
@@ -24,19 +24,23 @@ struct ReluParam<P: PrecisionType>: OpParam {
throw
error
throw
error
}
}
}
}
let
input
:
Texture
let
input
:
Texture
<
P
>
var
output
:
Texture
var
output
:
Texture
<
P
>
}
}
class
ReluOp
<
P
:
PrecisionType
>
:
Operator
<
ReluParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
ReluOp
<
P
:
PrecisionType
>
:
Operator
<
ReluParam
<
P
>
,
ReluKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
para
.
output
.
dim
=
para
.
input
.
dim
}
}
typealias
OpType
=
ReluOp
<
P
>
typealias
OpType
=
ReluOp
<
P
>
func
runImpl
()
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"this is ReluOp"
)
do
{
try
kernel
.
compute
(
commandBuffer
:
buffer
,
param
:
para
)
}
catch
let
error
{
throw
error
}
}
}
}
}
...
...
metal/paddle-mobile/paddle-mobile/framework/Texture.swift
View file @
64a171ca
...
@@ -38,7 +38,7 @@ extension InputTexture {
...
@@ -38,7 +38,7 @@ extension InputTexture {
}
}
}
}
public
class
Texture
:
Tensorial
{
public
class
Texture
<
P
:
PrecisionType
>
:
Tensorial
{
var
dim
:
Dim
var
dim
:
Dim
let
textureDesc
:
MTLTextureDescriptor
let
textureDesc
:
MTLTextureDescriptor
var
metalTexture
:
MTLTexture
var
metalTexture
:
MTLTexture
...
@@ -61,7 +61,15 @@ public class Texture: Tensorial {
...
@@ -61,7 +61,15 @@ public class Texture: Tensorial {
}
else
{
}
else
{
fatalError
(
" didn't support yet"
)
fatalError
(
" didn't support yet"
)
}
}
tmpTextureDes
.
pixelFormat
=
.
r32Float
if
MemoryLayout
<
P
>.
size
==
1
{
tmpTextureDes
.
pixelFormat
=
.
r8Sint
}
else
if
MemoryLayout
<
P
>.
size
==
2
{
tmpTextureDes
.
pixelFormat
=
.
r16Float
}
else
if
MemoryLayout
<
P
>.
size
==
4
{
tmpTextureDes
.
pixelFormat
=
.
r32Float
}
tmpTextureDes
.
usage
=
.
unknown
tmpTextureDes
.
storageMode
=
.
shared
tmpTextureDes
.
storageMode
=
.
shared
textureDesc
=
tmpTextureDes
textureDesc
=
tmpTextureDes
metalTexture
=
device
.
makeTexture
(
descriptor
:
tmpTextureDes
)
?
!
" texture nil "
metalTexture
=
device
.
makeTexture
(
descriptor
:
tmpTextureDes
)
?
!
" texture nil "
...
...
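The `Texture` change picks the pixel format from `MemoryLayout<P>.size`: 1 byte maps to `.r8Sint`, 2 bytes to `.r16Float`, and 4 bytes to `.r32Float`, with no branch for other sizes. The sketch below restates that mapping in plain Swift so it can be exercised without Metal. `PixelFormatSketch` is a hypothetical enum mirroring three `MTLPixelFormat` cases, and returning `nil` for unmatched sizes is a choice made for this sketch, not behaviour taken from the commit.

```swift
// Hypothetical mirror of the three MTLPixelFormat cases used in the diff.
enum PixelFormatSketch: Equatable {
    case r8Sint
    case r16Float
    case r32Float
}

// Selects a format purely from the element size in bytes.
// Note that size alone cannot tell Int32 from Float: both are 4 bytes.
func pixelFormat<T>(for type: T.Type) -> PixelFormatSketch? {
    switch MemoryLayout<T>.size {
    case 1: return .r8Sint
    case 2: return .r16Float
    case 4: return .r32Float
    default: return nil   // sketch-only choice for unmatched sizes
    }
}
```

For example, `pixelFormat(for: Float.self)` selects the 32-bit float format, while an 8-byte type such as `Double` matches no branch.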