Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
92849d43
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92849d43
编写于
7月 05, 2018
作者:
R
Ruilong Liu
提交者:
GitHub
7月 05, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #521 from codeWorm2015/metal
add relu kernel
上级
52d5f1f7
46a96221
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
347 addition
and
93 deletion
+347
-93
metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
...liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
+1
-1
metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift
...addle-mobile-demo/paddle-mobile-demo/ViewController.swift
+3
-3
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
+16
-0
metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
...liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
+1
-1
metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
...l/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
+7
-1
metal/paddle-mobile/paddle-mobile/Common/Types.swift
metal/paddle-mobile/paddle-mobile/Common/Types.swift
+11
-3
metal/paddle-mobile/paddle-mobile/Executor.swift
metal/paddle-mobile/paddle-mobile/Executor.swift
+26
-4
metal/paddle-mobile/paddle-mobile/Loader.swift
metal/paddle-mobile/paddle-mobile/Loader.swift
+21
-8
metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
...addle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
+3
-3
metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
...paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
+24
-13
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
...l/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
+4
-4
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
+4
-4
metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
...dle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
+4
-4
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
+14
-7
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
+3
-3
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
...ile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
+19
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
...e-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
+20
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift
...addle-mobile/Operators/Kernels/ElementwiseAddKernel.swift
+20
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift
...addle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift
+6
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
...ddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
+73
-16
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift
...e-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift
+32
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift
...mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift
+16
-11
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
+9
-5
metal/paddle-mobile/paddle-mobile/framework/Texture.swift
metal/paddle-mobile/paddle-mobile/framework/Texture.swift
+10
-2
未找到文件。
metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
浏览文件 @
92849d43
...
...
@@ -7,7 +7,7 @@
<key>
paddle-mobile-demo.xcscheme
</key>
<dict>
<key>
orderHint
</key>
<integer>
4
</integer>
<integer>
3
</integer>
</dict>
</dict>
</dict>
...
...
metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift
浏览文件 @
92849d43
...
...
@@ -36,13 +36,13 @@ class ViewController: UIViewController {
fatalError
(
" texture is nil !"
)
}
let
loader
=
Loader
<
Float
>.
init
()
let
loader
=
Loader
<
Float
16
>.
init
()
do
{
let
modelPath
=
Bundle
.
main
.
path
(
forResource
:
"model"
,
ofType
:
nil
)
?
!
"model null"
let
paraPath
=
Bundle
.
main
.
path
(
forResource
:
"params"
,
ofType
:
nil
)
?
!
"para null"
let
program
=
try
loader
.
load
(
device
:
device
,
modelPath
:
modelPath
,
paraPath
:
paraPath
)
let
executor
=
try
Executor
<
Float
>.
init
(
inProgram
:
program
)
let
output
=
try
executor
.
predict
(
input
:
inTexture
,
expect
:
[
1
,
22
4
,
224
,
3
])
let
executor
=
try
Executor
<
Float
16
>.
init
(
inDevice
:
device
,
inQueue
:
queue
!
,
inProgram
:
program
)
let
output
=
try
executor
.
predict
(
input
:
inTexture
,
expect
:
[
1
,
22
7
,
227
,
3
])
print
(
output
)
}
catch
let
error
{
print
(
error
)
...
...
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
浏览文件 @
92849d43
...
...
@@ -30,6 +30,10 @@
FC039BBE20E11CC20081E9F8
/* OpDesc.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC039BB520E11CC20081E9F8
/* OpDesc.swift */
;
};
FC039BBF20E11CC20081E9F8
/* Attribute.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC039BB620E11CC20081E9F8
/* Attribute.swift */
;
};
FC039BC020E11CC20081E9F8
/* BlockDesc.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC039BB720E11CC20081E9F8
/* BlockDesc.swift */
;
};
FC0E2DBA20EE3B8D009C1FAC
/* ReluKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2DB920EE3B8D009C1FAC
/* ReluKernel.swift */
;
};
FC0E2DBC20EE45FE009C1FAC
/* ConvKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2DBB20EE45FE009C1FAC
/* ConvKernel.swift */
;
};
FC0E2DBE20EE460D009C1FAC
/* BatchNormKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2DBD20EE460D009C1FAC
/* BatchNormKernel.swift */
;
};
FC0E2DC020EE461F009C1FAC
/* ElementwiseAddKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC0E2DBF20EE461F009C1FAC
/* ElementwiseAddKernel.swift */
;
};
FC1B16B320EC9A4F00678B91
/* Kernels.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC1B16B220EC9A4F00678B91
/* Kernels.metal */
;
};
FC1B186620ECF1C600678B91
/* ResizeKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC1B186520ECF1C600678B91
/* ResizeKernel.swift */
;
};
FC60DB8920E9AAA500FF203F
/* MetalExtension.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
FC60DB8820E9AAA500FF203F
/* MetalExtension.swift */
;
};
...
...
@@ -69,6 +73,10 @@
FC039BB520E11CC20081E9F8
/* OpDesc.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
OpDesc.swift
;
sourceTree
=
"<group>"
;
};
FC039BB620E11CC20081E9F8
/* Attribute.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
Attribute.swift
;
sourceTree
=
"<group>"
;
};
FC039BB720E11CC20081E9F8
/* BlockDesc.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
BlockDesc.swift
;
sourceTree
=
"<group>"
;
};
FC0E2DB920EE3B8D009C1FAC
/* ReluKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ReluKernel.swift
;
sourceTree
=
"<group>"
;
};
FC0E2DBB20EE45FE009C1FAC
/* ConvKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ConvKernel.swift
;
sourceTree
=
"<group>"
;
};
FC0E2DBD20EE460D009C1FAC
/* BatchNormKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
BatchNormKernel.swift
;
sourceTree
=
"<group>"
;
};
FC0E2DBF20EE461F009C1FAC
/* ElementwiseAddKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ElementwiseAddKernel.swift
;
sourceTree
=
"<group>"
;
};
FC1B16B220EC9A4F00678B91
/* Kernels.metal */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.metal
;
path
=
Kernels.metal
;
sourceTree
=
"<group>"
;
};
FC1B186520ECF1C600678B91
/* ResizeKernel.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ResizeKernel.swift
;
sourceTree
=
"<group>"
;
};
FC60DB8820E9AAA500FF203F
/* MetalExtension.swift */
=
{
isa
=
PBXFileReference
;
lastKnownFileType
=
sourcecode.swift
;
path
=
MetalExtension.swift
;
sourceTree
=
"<group>"
;
};
...
...
@@ -197,9 +205,13 @@
FC086BA520E67E8500D85EF7
/* Kernels */
=
{
isa
=
PBXGroup
;
children
=
(
FC0E2DBB20EE45FE009C1FAC
/* ConvKernel.swift */
,
FCF2D73720E64E70007AC5F5
/* Kernel.swift */
,
FC1B16B220EC9A4F00678B91
/* Kernels.metal */
,
FC1B186520ECF1C600678B91
/* ResizeKernel.swift */
,
FC0E2DB920EE3B8D009C1FAC
/* ReluKernel.swift */
,
FC0E2DBD20EE460D009C1FAC
/* BatchNormKernel.swift */
,
FC0E2DBF20EE461F009C1FAC
/* ElementwiseAddKernel.swift */
,
);
path
=
Kernels
;
sourceTree
=
"<group>"
;
...
...
@@ -316,12 +328,14 @@
files
=
(
FC9D038020E22FBB000F735A
/* FeedOp.swift in Sources */
,
FC039B9F20E11CB20081E9F8
/* Tensor.swift in Sources */
,
FC0E2DBC20EE45FE009C1FAC
/* ConvKernel.swift in Sources */
,
FC039BAA20E11CBC0081E9F8
/* ElementwiseAddOp.swift in Sources */
,
FC039B9B20E11CA00081E9F8
/* Executor.swift in Sources */
,
FC039BBB20E11CC20081E9F8
/* ProgramDesc.swift in Sources */
,
FC9D037920E229E4000F735A
/* OpParam.swift in Sources */
,
FC1B186620ECF1C600678B91
/* ResizeKernel.swift in Sources */
,
FCF2D73820E64E70007AC5F5
/* Kernel.swift in Sources */
,
FC0E2DC020EE461F009C1FAC
/* ElementwiseAddKernel.swift in Sources */
,
FC60DB8920E9AAA500FF203F
/* MetalExtension.swift in Sources */
,
FC1B16B320EC9A4F00678B91
/* Kernels.metal in Sources */
,
FC039BBA20E11CC20081E9F8
/* TensorDesc.swift in Sources */
,
...
...
@@ -335,7 +349,9 @@
FC039BB920E11CC20081E9F8
/* Scope.swift in Sources */
,
FC039BAC20E11CBC0081E9F8
/* BatchNormOp.swift in Sources */
,
FC039BBC20E11CC20081E9F8
/* VarDesc.swift in Sources */
,
FC0E2DBA20EE3B8D009C1FAC
/* ReluKernel.swift in Sources */
,
FC82735920E3C04200BE430A
/* OpCreator.swift in Sources */
,
FC0E2DBE20EE460D009C1FAC
/* BatchNormKernel.swift in Sources */
,
FC039BAB20E11CBC0081E9F8
/* Operator.swift in Sources */
,
FC9D038220E2312E000F735A
/* FetchOp.swift in Sources */
,
FC039BBD20E11CC20081E9F8
/* Program.swift in Sources */
,
...
...
metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist
浏览文件 @
92849d43
...
...
@@ -7,7 +7,7 @@
<key>
paddle-mobile.xcscheme
</key>
<dict>
<key>
orderHint
</key>
<integer>
3
</integer>
<integer>
4
</integer>
</dict>
</dict>
</dict>
...
...
metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift
浏览文件 @
92849d43
...
...
@@ -29,11 +29,11 @@ extension MTLDevice {
fatalError
(
"Counld't find paddle mobile library"
)
}
do
{
print
(
path
)
paddleMobileMetalLibrary
=
try
makeLibrary
(
filepath
:
path
)
}
catch
_
{
fatalError
(
"Counld't load paddle mobile library"
)
}
paddleMobileMetalLibrary
=
makeDefaultLibrary
()
}
if
let
inPaddleMobileLib
=
paddleMobileMetalLibrary
{
...
...
@@ -67,11 +67,17 @@ extension MTLComputeCommandEncoder {
let
height
=
computePipline
.
maxTotalThreadsPerThreadgroup
/
width
let
threadsPerGroup
=
MTLSize
.
init
(
width
:
width
,
height
:
height
,
depth
:
1
)
print
(
" threads per group:
\(
threadsPerGroup
)
"
)
print
(
" out texture width:
\(
outTexture
.
width
)
, out texture height:
\(
outTexture
.
height
)
"
)
let
groupWidth
=
(
outTexture
.
width
+
width
-
1
)
/
width
let
groupHeight
=
(
outTexture
.
height
+
height
-
1
)
/
height
let
groupDepth
=
slices
let
groups
=
MTLSize
.
init
(
width
:
groupWidth
,
height
:
groupHeight
,
depth
:
groupDepth
)
print
(
"groups:
\(
groups
)
"
)
setComputePipelineState
(
computePipline
)
dispatchThreadgroups
(
groups
,
threadsPerThreadgroup
:
threadsPerGroup
)
}
...
...
metal/paddle-mobile/paddle-mobile/Common/Types.swift
浏览文件 @
92849d43
...
...
@@ -14,14 +14,22 @@
import
Foundation
//typealias Float16 = Int16
//extension Float16: PrecisionType {
//}
public
typealias
Float16
=
Int16
extension
Float16
:
PrecisionType
{
public
init
(
inFloat
:
Float32
)
{
self
=
Int16
(
inFloat
)
}
}
public
protocol
PrecisionType
{
init
(
inFloat
:
Float32
)
}
extension
Float32
:
PrecisionType
{
public
init
(
inFloat
:
Float32
)
{
self
=
inFloat
}
}
public
enum
DataLayout
{
...
...
metal/paddle-mobile/paddle-mobile/Executor.swift
浏览文件 @
92849d43
...
...
@@ -48,13 +48,16 @@ extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible {
public
class
Executor
<
P
:
PrecisionType
>
{
var
ops
:
[
Runable
&
InferShaperable
]
=
[]
let
program
:
Program
public
init
(
inProgram
:
Program
)
throws
{
let
device
:
MTLDevice
let
queue
:
MTLCommandQueue
public
init
(
inDevice
:
MTLDevice
,
inQueue
:
MTLCommandQueue
,
inProgram
:
Program
)
throws
{
program
=
inProgram
device
=
inDevice
queue
=
inQueue
for
block
in
inProgram
.
programDesc
.
blocks
{
for
op
in
block
.
ops
{
do
{
let
op
=
try
OpCreator
<
P
>.
shared
.
creat
(
opDesc
:
op
,
scope
:
inProgram
.
scope
)
let
op
=
try
OpCreator
<
P
>.
shared
.
creat
(
device
:
inDevice
,
opDesc
:
op
,
scope
:
inProgram
.
scope
)
op
.
inferShape
()
ops
.
append
(
op
)
}
catch
let
error
{
...
...
@@ -65,12 +68,29 @@ public class Executor<P: PrecisionType> {
}
public
func
predict
(
input
:
MTLTexture
,
expect
:
[
Int
])
throws
->
ResultHolder
<
P
>
{
let
beforeDate
=
Date
.
init
()
let
inputTexture
=
InputTexture
.
init
(
inMTLTexture
:
input
,
inExpectDim
:
Dim
.
init
(
inDim
:
expect
))
program
.
scope
.
setInput
(
input
:
inputTexture
)
guard
let
buffer
=
queue
.
makeCommandBuffer
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
"CommandBuffer is nil"
)
}
for
op
in
ops
{
op
.
run
()
do
{
try
op
.
run
(
device
:
device
,
buffer
:
buffer
)
}
catch
let
error
{
throw
error
}
}
buffer
.
addCompletedHandler
{
(
commandbuffer
)
in
let
afterDate
=
Date
.
init
()
print
(
afterDate
.
timeIntervalSince
(
beforeDate
))
print
(
" encoder end ! "
)
}
buffer
.
commit
()
guard
let
outputVar
=
program
.
scope
.
output
()
else
{
throw
PaddleMobileError
.
netError
(
message
:
"output nil"
)
}
...
...
@@ -78,6 +98,8 @@ public class Executor<P: PrecisionType> {
guard
let
output
=
outputVar
as?
ResultHolder
<
P
>
else
{
throw
PaddleMobileError
.
netError
(
message
:
"output var type error"
)
}
return
output
}
...
...
metal/paddle-mobile/paddle-mobile/Loader.swift
浏览文件 @
92849d43
...
...
@@ -68,11 +68,24 @@ public class Loader<P: PrecisionType> {
/*
这里没有根据 Data Type 去判断, 而是从外部泛型直接指定了精度
*/
let
bytesRead
=
fread
(
tensor
.
data
.
pointer
,
1
,
tensor
.
data
.
size
,
file
)
guard
bytesRead
==
tensor
.
data
.
size
else
{
throw
PaddleMobileError
.
loaderError
(
message
:
"param read size error"
)
//现在模型传入模型为 Float 类型, 这块应该根据模型来
let
tmpCapacity
=
MemoryLayout
<
Float
>.
size
*
tensor
.
numel
()
let
tmpPointer
=
UnsafeMutablePointer
<
Float
>.
allocate
(
capacity
:
tmpCapacity
);
// let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
// guard bytesRead == tensor.data.size else {
// throw PaddleMobileError.loaderError(message: "param read size error")
// }
// TODO: use script to convert
let
bytesRead
=
fread
(
tmpPointer
,
1
,
tmpCapacity
,
file
)
for
i
in
0
..<
tensor
.
numel
()
{
tensor
.
data
[
i
]
=
P
.
init
(
inFloat
:
tmpPointer
[
i
])
}
tmpPointer
.
deinitialize
(
count
:
tmpCapacity
)
tmpPointer
.
deallocate
()
nowIndex
+=
bytesRead
}
...
...
@@ -125,9 +138,9 @@ public class Loader<P: PrecisionType> {
throw
PaddleMobileError
.
loaderError
(
message
:
"get tensor desc failed"
)
}
guard
(
try
?
tensorDesc
.
dataType
.
dataTypeSize
())
==
MemoryLayout
<
P
>.
size
else
{
throw
PaddleMobileError
.
memoryError
(
message
:
"PrecisionType not support"
)
}
//
guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
//
throw PaddleMobileError.memoryError(message: "PrecisionType not support")
//
}
if
(
varDesc
.
persistable
&&
varDesc
.
type
!=
.
FeedMiniBatch
...
...
@@ -149,7 +162,7 @@ public class Loader<P: PrecisionType> {
scope
[
varDesc
.
name
]
=
tensor
}
else
{
let
dim
=
Dim
.
init
(
inDim
:
tensorDesc
.
NHWCDim
)
scope
[
varDesc
.
name
]
=
Texture
.
init
(
device
:
device
,
inDim
:
dim
)
scope
[
varDesc
.
name
]
=
Texture
<
P
>
.
init
(
device
:
device
,
inDim
:
dim
)
}
}
else
{
if
varDesc
.
name
==
fetchKey
{
...
...
metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift
浏览文件 @
92849d43
...
...
@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> {
}
}
func
creat
(
opDesc
:
OpDesc
,
scope
:
Scope
)
throws
->
Runable
&
InferShaperable
{
func
creat
(
device
:
MTLDevice
,
opDesc
:
OpDesc
,
scope
:
Scope
)
throws
->
Runable
&
InferShaperable
{
guard
let
opCreator
=
opCreators
[
opDesc
.
type
]
else
{
throw
PaddleMobileError
.
opError
(
message
:
"there is no "
+
opDesc
.
type
+
" yet"
)
}
do
{
return
try
opCreator
(
opDesc
,
scope
)
return
try
opCreator
(
device
,
opDesc
,
scope
)
}
catch
let
error
{
throw
error
}
}
let
opCreators
:
[
String
:
(
OpDesc
,
Scope
)
throws
->
Runable
&
InferShaperable
]
=
let
opCreators
:
[
String
:
(
MTLDevice
,
OpDesc
,
Scope
)
throws
->
Runable
&
InferShaperable
]
=
[
gConvType
:
ConvOp
<
P
>.
creat
,
gBatchNormType
:
BatchNormOp
<
P
>.
creat
,
gReluType
:
ReluOp
<
P
>.
creat
,
...
...
metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
浏览文件 @
92849d43
...
...
@@ -12,29 +12,35 @@
See the License for the specific language governing permissions and
limitations under the License. */
import
Metal
import
Foundation
protocol
Runable
{
func
run
(
)
func
runImpl
(
)
func
run
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
}
extension
Runable
where
Self
:
OperatorProtocol
{
func
run
()
{
runImpl
()
func
run
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
do
{
try
runImpl
(
device
:
device
,
buffer
:
buffer
)
}
catch
let
error
{
throw
error
}
print
(
type
+
": "
+
para
.
outputDesc
())
}
}
protocol
Creator
where
Self
:
OperatorProtocol
{
associatedtype
OpType
:
OperatorProtocol
&
Runable
&
InferShaperable
static
func
creat
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
->
OpType
static
func
creat
(
device
:
MTLDevice
,
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
->
OpType
}
extension
Creator
where
Self
:
OperatorProtocol
{
static
func
creat
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
->
OpType
{
static
func
creat
(
device
:
MTLDevice
,
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
->
OpType
{
do
{
return
try
OpType
.
provide
(
opDesc
:
opDesc
,
inScope
:
inScope
)
return
try
OpType
.
provide
(
device
:
device
,
opDesc
:
opDesc
,
inScope
:
inScope
)
}
catch
let
error
{
throw
error
}
...
...
@@ -47,19 +53,21 @@ protocol InferShaperable {
protocol
OperatorProtocol
{
associatedtype
ParamType
:
OpParam
associatedtype
KerType
:
Computable
var
type
:
String
{
get
}
var
inputs
:
[
String
:
[
String
]]
{
get
}
var
paraInputs
:
[
String
:
[
String
]]
{
get
}
var
outpus
:
[
String
:
[
String
]]
{
get
}
var
attrs
:
[
String
:
Attr
]
{
get
}
var
para
:
ParamType
{
get
}
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
var
kernel
:
KerType
{
get
}
init
(
device
:
MTLDevice
,
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
}
extension
OperatorProtocol
{
static
func
provide
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
->
Self
{
static
func
provide
(
device
:
MTLDevice
,
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
->
Self
{
do
{
return
try
Self
.
init
(
opDesc
:
opDesc
,
inScope
:
inScope
)
return
try
Self
.
init
(
device
:
device
,
opDesc
:
opDesc
,
inScope
:
inScope
)
}
catch
let
error
{
throw
error
}
...
...
@@ -67,20 +75,23 @@ extension OperatorProtocol {
}
class
Operator
<
ParameterType
:
OpParam
>
:
OperatorProtocol
{
typealias
ParamType
=
ParameterType
class
Operator
<
ParameterType
:
OpParam
,
KernelType
:
Computable
>
:
OperatorProtocol
{
typealias
ParamType
=
ParameterType
typealias
KerType
=
KernelType
let
type
:
String
let
inputs
:
[
String
:
[
String
]]
let
paraInputs
:
[
String
:
[
String
]]
let
outpus
:
[
String
:
[
String
]]
let
attrs
:
[
String
:
Attr
]
let
para
:
ParamType
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
var
kernel
:
KerType
required
init
(
device
:
MTLDevice
,
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
type
=
opDesc
.
type
inputs
=
opDesc
.
inputs
outpus
=
opDesc
.
outputs
attrs
=
opDesc
.
attrs
paraInputs
=
opDesc
.
paraInputs
kernel
=
KerType
.
init
(
device
:
device
)
do
{
para
=
try
ParamType
.
init
(
opDesc
:
opDesc
,
inScope
:
inScope
)
}
catch
let
error
{
...
...
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
浏览文件 @
92849d43
...
...
@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
throw
error
}
}
let
input
:
Texture
var
output
:
Texture
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
let
inputBias
:
Tensor
<
ParamPrecisionType
>
let
inputMean
:
Tensor
<
ParamPrecisionType
>
let
inputScale
:
Tensor
<
ParamPrecisionType
>
...
...
@@ -42,12 +42,12 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
let
is_test
:
Bool
}
class
BatchNormOp
<
P
:
PrecisionType
>
:
Operator
<
BatchNormParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
BatchNormOp
<
P
:
PrecisionType
>
:
Operator
<
BatchNormParam
<
P
>
,
BatchNormKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
}
typealias
OpType
=
BatchNormOp
<
P
>
func
runImpl
(
)
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"this is BatchNormOp"
)
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
浏览文件 @
92849d43
...
...
@@ -30,8 +30,8 @@ struct ConvParam<P: PrecisionType>: OpParam {
}
}
let
input
:
Texture
var
output
:
Texture
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
let
filter
:
Tensor
<
ParamPrecisionType
>
let
stride
:
[
Int32
]
let
paddings
:
[
Int32
]
...
...
@@ -39,7 +39,7 @@ struct ConvParam<P: PrecisionType>: OpParam {
let
groups
:
Int
}
class
ConvOp
<
P
:
PrecisionType
>
:
Operator
<
ConvParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
ConvOp
<
P
:
PrecisionType
>
:
Operator
<
ConvParam
<
P
>
,
ConvKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
let
inDims
=
para
.
input
.
dim
let
filterDim
=
para
.
filter
.
dim
...
...
@@ -63,7 +63,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferS
}
typealias
OpType
=
ConvOp
<
P
>
func
runImpl
(
)
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"this is conv"
)
}
}
metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift
浏览文件 @
92849d43
...
...
@@ -26,20 +26,20 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
throw
error
}
}
let
input
:
Texture
let
input
:
Texture
<
P
>
let
inputY
:
Tensor
<
P
>
var
output
:
Texture
var
output
:
Texture
<
P
>
let
axis
:
Int
}
class
ElementwiseAddOp
<
P
:
PrecisionType
>
:
Operator
<
ElementwiseAddParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
ElementwiseAddOp
<
P
:
PrecisionType
>
:
Operator
<
ElementwiseAddParam
<
P
>
,
ElementwiseAddKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
}
typealias
OpType
=
ElementwiseAddOp
<
P
>
func
runImpl
(
)
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"this is ElementwiseAddOp"
)
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
浏览文件 @
92849d43
...
...
@@ -15,7 +15,7 @@
import
Foundation
struct
FeedParam
<
P
:
PrecisionType
>
:
OpParam
{
var
output
:
Texture
var
output
:
Texture
<
P
>
var
input
:
InputTexture
{
return
scope
.
input
()
as!
InputTexture
}
...
...
@@ -33,19 +33,26 @@ struct FeedParam<P: PrecisionType>: OpParam{
typealias
ParamPrecisionType
=
P
}
class
FeedOp
<
P
:
PrecisionType
>
:
Operator
<
FeedParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
FeedOp
<
P
:
PrecisionType
>
:
Operator
<
FeedParam
<
P
>
,
ResizeKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
typealias
OpType
=
FeedOp
<
P
>
func
inferShape
()
{
// print("feed input: \(para.input.expectDim)")
print
(
"feed output:
\(
para
.
output
.
dim
)
"
)
// para.ou
/
tput.dim = para.input.expectDim
// para.output.dim =
// para.output.dim = para.input.expectDim
}
func
runImpl
()
{
print
(
"feed op"
)
// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>)
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
let
resizeKernel
=
ResizeKernel
<
P
>.
init
(
device
:
device
)
let
resizeParam
=
ResizeParam
.
init
(
input
:
para
.
input
.
mtlTexture
,
output
:
para
.
output
.
metalTexture
,
expectDim
:
para
.
input
.
expectDim
)
do
{
print
(
"feed op to compute "
)
try
resizeKernel
.
compute
(
commandBuffer
:
buffer
,
param
:
resizeParam
)
print
(
"feed op end compute "
)
}
catch
let
error
{
throw
error
}
}
}
metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
浏览文件 @
92849d43
...
...
@@ -16,7 +16,7 @@ import Foundation
struct
FetchParam
<
P
:
PrecisionType
>
:
OpParam
{
var
output
:
ResultHolder
<
P
>
=
ResultHolder
.
init
(
inDim
:
[],
inResult
:
[])
let
input
:
Texture
let
input
:
Texture
<
P
>
let
scope
:
Scope
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
scope
=
inScope
...
...
@@ -30,14 +30,14 @@ struct FetchParam<P: PrecisionType>: OpParam{
typealias
ParamPrecisionType
=
P
}
class
FetchOp
<
P
:
PrecisionType
>
:
Operator
<
FetchParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
FetchOp
<
P
:
PrecisionType
>
:
Operator
<
FetchParam
<
P
>
,
ResizeKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
print
(
para
.
input
.
dim
)
}
typealias
OpType
=
FetchOp
<
P
>
func
runImpl
(
)
{
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
print
(
"fetch op"
)
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
0 → 100644
浏览文件 @
92849d43
//
// BatchNormKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
BatchNormKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
required
init
(
device
:
MTLDevice
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"batchnorm"
)
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
BatchNormParam
<
P
>
)
throws
{
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
0 → 100644
浏览文件 @
92849d43
//
// ConvKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
ConvKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ConvParam
<
P
>
)
throws
{
}
required
init
(
device
:
MTLDevice
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"conv"
)
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift
0 → 100644
浏览文件 @
92849d43
//
// ElementwiseAddKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import
Foundation
class
ElementwiseAddKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
required
init
(
device
:
MTLDevice
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"conv"
)
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ElementwiseAddParam
<
P
>
)
throws
{
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift
浏览文件 @
92849d43
...
...
@@ -18,6 +18,12 @@ import Foundation
protocol
Computable
{
associatedtype
ParamType
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ParamType
)
throws
init
(
device
:
MTLDevice
)
}
protocol
KernelProtocol
{
var
pipline
:
MTLComputePipelineState
{
get
set
}
var
functionName
:
String
{
get
set
}
}
class
Kernel
{
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
浏览文件 @
92849d43
//
// Kernels.metal
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
...
...
@@ -16,19 +22,70 @@ struct OutputDim {
ushort strideY;
};
kernel void resize(
texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::write> outTexture [[texture(1)]],
constant OutputDim ¶ms [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]) {
kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant OutputDim ¶ms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
return;
}
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z,
0.0h), gid
);
outTexture.write(half4(input.x, input.y, input.z,
input.w), gid.xy, gid.z
);
}
kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float4 relu = fmax((float4)input, 0.0);
outTexture.write(half4(relu), gid.xy, gid.z);
}
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift
0 → 100644
浏览文件 @
92849d43
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
class
ReluKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ReluParam
<
P
>
)
throws
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
" encode is nil"
)
}
print
(
" the usage of input of relu
\(
param
.
input
.
metalTexture
.
usage
)
"
)
encoder
.
setTexture
(
param
.
input
.
metalTexture
,
index
:
0
)
encoder
.
setTexture
(
param
.
output
.
metalTexture
,
index
:
1
)
encoder
.
dispatch
(
computePipline
:
pipline
,
outTexture
:
param
.
output
.
metalTexture
)
encoder
.
endEncoding
()
}
required
init
(
device
:
MTLDevice
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"relu"
)
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift
浏览文件 @
92849d43
//
// ResizeKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
...
...
@@ -22,15 +28,14 @@ struct OutputDim {
let
strideY
:
UInt16
}
class
ResizeKernel
:
Kernel
,
Computable
{
class
ResizeKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
ResizeParam
)
throws
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
" encode is nil"
)
}
encoder
.
setTexture
(
param
.
input
,
index
:
0
)
encoder
.
setTexture
(
param
.
output
,
index
:
1
)
encoder
.
setTexture
(
param
.
output
,
index
:
1
)
let
strideX
=
param
.
input
.
width
/
param
.
expectDim
[
2
]
let
strideY
=
param
.
input
.
height
/
param
.
expectDim
[
1
]
var
outputDim
=
OutputDim
.
init
(
width
:
UInt16
(
param
.
expectDim
[
1
]),
height
:
UInt16
(
param
.
expectDim
[
2
]),
strideX
:
UInt16
(
strideX
),
strideY
:
UInt16
(
strideY
))
...
...
@@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{
encoder
.
endEncoding
()
}
init
(
device
:
MTLDevice
)
{
required
init
(
device
:
MTLDevice
)
{
super
.
init
(
device
:
device
,
inFunctionName
:
"resize"
)
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
浏览文件 @
92849d43
...
...
@@ -24,19 +24,23 @@ struct ReluParam<P: PrecisionType>: OpParam {
throw
error
}
}
let
input
:
Texture
var
output
:
Texture
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
}
class
ReluOp
<
P
:
PrecisionType
>
:
Operator
<
ReluParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
ReluOp
<
P
:
PrecisionType
>
:
Operator
<
ReluParam
<
P
>
,
ReluKernel
<
P
>
>
,
Runable
,
Creator
,
InferShaperable
{
func
inferShape
()
{
para
.
output
.
dim
=
para
.
input
.
dim
}
typealias
OpType
=
ReluOp
<
P
>
func
runImpl
()
{
print
(
"this is ReluOp"
)
func
runImpl
(
device
:
MTLDevice
,
buffer
:
MTLCommandBuffer
)
throws
{
do
{
try
kernel
.
compute
(
commandBuffer
:
buffer
,
param
:
para
)
}
catch
let
error
{
throw
error
}
}
}
...
...
metal/paddle-mobile/paddle-mobile/framework/Texture.swift
浏览文件 @
92849d43
...
...
@@ -38,7 +38,7 @@ extension InputTexture {
}
}
public
class
Texture
:
Tensorial
{
public
class
Texture
<
P
:
PrecisionType
>
:
Tensorial
{
var
dim
:
Dim
let
textureDesc
:
MTLTextureDescriptor
var
metalTexture
:
MTLTexture
...
...
@@ -61,7 +61,15 @@ public class Texture: Tensorial {
}
else
{
fatalError
(
" didn't support yet"
)
}
tmpTextureDes
.
pixelFormat
=
.
r32Float
if
MemoryLayout
<
P
>.
size
==
1
{
tmpTextureDes
.
pixelFormat
=
.
r8Sint
}
else
if
MemoryLayout
<
P
>.
size
==
2
{
tmpTextureDes
.
pixelFormat
=
.
r16Float
}
else
if
MemoryLayout
<
P
>.
size
==
4
{
tmpTextureDes
.
pixelFormat
=
.
r32Float
}
tmpTextureDes
.
usage
=
.
unknown
tmpTextureDes
.
storageMode
=
.
shared
textureDesc
=
tmpTextureDes
metalTexture
=
device
.
makeTexture
(
descriptor
:
tmpTextureDes
)
?
!
" texture nil "
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录