PaddlePaddle / Paddle-Lite

Commit e71320da, authored Sep 12, 2018 by dolphin8

batchnorm

Parent: 2bcb1135

Showing 16 changed files with 334 additions and 243 deletions (+334 -243)
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj                        +12 -8
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift                      +11 -8
metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift                        +18 -1
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift          +17 -38
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift             +3 -2
metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift            +71 -0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift            +4 -2
metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift              +4 -1
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal    +13 -13
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal   +9 -14
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal       +24 -24
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal  +9 -10
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal      +113 -114
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift                        +0 -3
metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift                          +5 -3
metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift                          +21 -2
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj

@@ -16,8 +16,9 @@
   4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; };
   4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; };
   4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; };
-  4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */; };
-  4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */; };
+  4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; };
+  4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; };
+  4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
   4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
   4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
   4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };

@@ -126,8 +127,9 @@
   4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; };
   4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; };
   4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; };
-  4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.metal.inc; sourceTree = "<group>"; };
-  4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.metal.inc; sourceTree = "<group>"; };
+  4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = "<group>"; };
+  4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = "<group>"; };
+  4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
   4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
   4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
   4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };

@@ -395,6 +397,7 @@
   FCD04E6720F315020007374F /* PoolKernel.swift */,
   FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */,
   FCD04E6F20F31B720007374F /* ReshapeKernel.swift */,
+  4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */,
   FCD04E7320F3437E0007374F /* ConvAddKernel.swift */,
   FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */,
   FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */,

@@ -442,7 +445,7 @@
   children = (
   FC27990D21341016000B6BAD /* BoxCoder.metal */,
   4AF928812135673D005B6C3A /* ConcatKernel.metal */,
-  4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */,
+  4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */,
   4AF9288321357BE3005B6C3A /* Elementwise.metal */,
   FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
   FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,

@@ -455,7 +458,7 @@
   FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */,
   FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */,
   FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */,
-  4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */,
+  4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */,
   FCA3A1642132A5EB00084FE5 /* Common.metal */,
   FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */,
   FCA67CD42138272900BD58AA /* ConvAddMetal.metal */,

@@ -477,7 +480,7 @@
   FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */,
   FC292C85214257CB00CF622F /* CPUCompute.h in Headers */,
   FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */,
-  4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */,
+  4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */,
   FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */,
   );
   runOnlyForDeploymentPostprocessing = 0;

@@ -617,6 +620,7 @@
   FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
   FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
   FC9D038420E23B01000F735A /* Texture.swift in Sources */,
+  4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */,
   4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */,
   FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */,
   FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,

@@ -657,7 +661,7 @@
   FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
   FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
   FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */,
-  4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */,
+  4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */,
   FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */,
   FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */,
   FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift

@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam {
   required init(opDesc: OpDesc, inScope: Scope) throws {
     do {
       input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
+      if input.transpose != [0, 2, 3, 1] {
+        fatalError("batch norm only accepts NHWC")
+      }
       output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
-      inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope)
-      inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
-      inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
-      inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
+      bias = try BatchNormParam.getFirstTensor(key: "Bias", map: opDesc.paraInputs, from: inScope)
+      mean = try BatchNormParam.getFirstTensor(key: "Mean", map: opDesc.paraInputs, from: inScope)
+      scale = try BatchNormParam.getFirstTensor(key: "Scale", map: opDesc.paraInputs, from: inScope)
+      variance = try BatchNormParam.getFirstTensor(key: "Variance", map: opDesc.paraInputs, from: inScope)
       epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
       momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
     } catch let error {

@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam {
   }
   let input: Texture<P>
   var output: Texture<P>
-  let inputBias: Tensor<ParamPrecisionType>
-  let inputMean: Tensor<ParamPrecisionType>
-  let inputScale: Tensor<ParamPrecisionType>
-  let inputVariance: Tensor<ParamPrecisionType>
+  let bias: Tensor<P>
+  let mean: Tensor<P>
+  let scale: Tensor<P>
+  let variance: Tensor<P>
   let epsilon: Float
   let momentum: Float
 }
metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift

@@ -14,7 +14,24 @@
 import Foundation

-class FlattenOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable {
+class FlattenParam<P: PrecisionType>: OpParam {
+  typealias ParamPrecisionType = P
+  required init(opDesc: OpDesc, inScope: Scope) throws {
+    do {
+      input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope)
+      output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope)
+      axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs)
+    } catch let error {
+      throw error
+    }
+  }
+  let input: Texture<P>
+  var output: Texture<P>
+  let axis: Int
+}
+
+class FlattenOp<P: PrecisionType>: Operator<FlattenKernel<P>, FlattenParam<P>>, Runable, Creator, InferShaperable {
   typealias OpType = FlattenOp<P>
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift

@@ -15,20 +15,20 @@
 import Foundation

 class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
-  // var newScale: MTLBuffer
-  // var newBias: MTLBuffer
-  //
   required init(device: MTLDevice, param: BatchNormParam<P>) {
-    // guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
-    //   fatalError()
-    // }
-    //
-    // guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
-    //   fatalError()
-    // }
-    // self.newScale = newScale
-    // self.newBias = newBias
-    //
+    let count = param.variance.dim.numel()
+    let varianceP = param.variance.data.pointer
+    let meanP = param.mean.data.pointer
+    let scaleP = param.scale.data.pointer
+    let biasP = param.scale.data.pointer
+    for i in 0..<count {
+      let invStd = P(1 / (Float32(varianceP[i]) + param.epsilon).squareRoot())
+      biasP[i] = biasP[i] - meanP[i] * invStd * scaleP[i]
+      scaleP[i] = invStd * scaleP[i]
+    }
+    param.bias.initBuffer(device: device, precision: computePrecision)
+    param.scale.initBuffer(device: device, precision: computePrecision)
     param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
     if computePrecision == .Float32 {
       super.init(device: device, inFunctionName: "batchnorm")
     } else if computePrecision == .Float16 {

@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
     } else {
       fatalError()
     }
-    //
-    // let varianceBuffer : MTLBuffer = param.inputVariance.buffer
-    //
-    // var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
-    // let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
-    // for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
-    //   invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
-    // }
-    //
-    // let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
-    // let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
-    // let scale : MTLBuffer = param.inputScale.buffer
-    // let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
-    // let bias : MTLBuffer = param.inputBias.buffer
-    // let biasContents = bias.contents().assumingMemoryBound(to: P.self)
-    // let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
-    //
-    // for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
-    //   newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
-    //   newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
-    // }
   }

   func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
     guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
       throw PaddleMobileError.predictError(message: " encoder is nil")
     }
-    // encoder.setTexture(param.input.metalTexture, index: 0)
-    // encoder.setTexture(param.output.metalTexture, index: 1)
-    // encoder.setBuffer(newScale, offset: 0, index: 0)
-    // encoder.setBuffer(newBias, offset: 0, index: 1)
+    encoder.setTexture(param.input.metalTexture, index: 0)
+    encoder.setTexture(param.output.metalTexture, index: 1)
+    encoder.setBuffer(param.scale.buffer, offset: 0, index: 0)
+    encoder.setBuffer(param.bias.buffer, offset: 0, index: 1)
     encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
     encoder.endEncoding()
   }
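Reviewer note (not part of the commit): the rewritten initializer folds the running mean and variance into the scale and bias on the CPU, so the Metal kernel only has to apply a per-channel affine transform (y = x * newScale + newBias). A minimal standalone sketch of the same fold on plain Float arrays, with hypothetical names:

    // Standalone sketch of the fold performed in BatchNormKernel.init above:
    //   newScale = scale / sqrt(variance + epsilon)
    //   newBias  = bias - mean * scale / sqrt(variance + epsilon)
    func foldBatchNorm(scale: [Float], bias: [Float], mean: [Float],
                       variance: [Float], epsilon: Float) -> (newScale: [Float], newBias: [Float]) {
        var newScale = [Float](repeating: 0, count: scale.count)
        var newBias = [Float](repeating: 0, count: scale.count)
        for i in 0..<scale.count {
            let invStd = 1 / (variance[i] + epsilon).squareRoot()
            newScale[i] = invStd * scale[i]
            newBias[i] = bias[i] - mean[i] * invStd * scale[i]
        }
        return (newScale, newBias)
    }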
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift

@@ -122,10 +122,11 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
   required init(device: MTLDevice, param: ConcatParam<P>) {
     param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
+    let orank = param.output.tensorDim.cout()
     if computePrecision == .Float32 {
-      super.init(device: device, inFunctionName: "concat")
+      super.init(device: device, inFunctionName: "concat_\(orank)_float")
     } else if computePrecision == .Float16 {
-      super.init(device: device, inFunctionName: "concat_half")
+      super.init(device: device, inFunctionName: "concat_\(orank)_half")
     } else {
       fatalError()
     }
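Reviewer note (not part of the commit): the concat kernel is now chosen by the output tensor's rank, so the Swift-side name has to match one of the rank-specialized functions generated in ConcatKernel.metal (concat_1_float through concat_4_half). A hypothetical helper that mirrors the selection:

    // Hypothetical helper; in the real initializer orank comes from param.output.tensorDim.cout().
    func concatKernelName(orank: Int, halfPrecision: Bool) -> String {
        return "concat_\(orank)_" + (halfPrecision ? "half" : "float")
    }
    // concatKernelName(orank: 3, halfPrecision: true) == "concat_3_half"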
metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift  0 → 100644

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

import Foundation

struct FlattenMetalParam {
  var idim: (Int32, Int32, Int32, Int32)
  var itrans: (Int32, Int32, Int32, Int32)
  var odim: (Int32, Int32, Int32, Int32)
  var otrans: (Int32, Int32, Int32, Int32)
}

class FlattenKernel<P: PrecisionType>: Kernel, Computable {

  var metalParam: FlattenMetalParam

  required init(device: MTLDevice, param: FlattenParam<P>) {
    param.output.initTexture(device: device, computePrecision: computePrecision)
    var id: [Int32] = [1, 1, 1, 1]
    for i in 0..<param.input.tensorDim.cout() {
      id[4 - param.input.tensorDim.cout() + i] = Int32(param.input.tensorDim[i])
    }
    let it: [Int32] = param.input.transpose.map { Int32($0) }
    var od: [Int32] = [1, 1, 1, 1]
    for i in 0..<param.output.tensorDim.cout() {
      od[4 - param.output.tensorDim.cout() + i] = Int32(param.output.tensorDim[i])
    }
    let ot: [Int32] = param.output.transpose.map { Int32($0) }
    metalParam = FlattenMetalParam.init(
      idim: (id[0], id[1], id[2], id[3]),
      itrans: (it[0], it[1], it[2], it[3]),
      odim: (od[0], od[1], od[2], od[3]),
      otrans: (ot[0], ot[1], ot[2], ot[3])
    )
    let irank = param.input.tensorDim.cout()
    let orank = param.output.tensorDim.cout()
    assert(orank == 2)
    if computePrecision == .Float32 {
      super.init(device: device, inFunctionName: "reshape_\(irank)_2_float")
    } else if computePrecision == .Float16 {
      super.init(device: device, inFunctionName: "reshape_\(irank)_2_half")
    } else {
      fatalError()
    }
  }

  func compute(commandBuffer: MTLCommandBuffer, param: FlattenParam<P>) throws {
    guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
      throw PaddleMobileError.predictError(message: " encoder is nil")
    }
    encoder.setTexture(param.input.metalTexture, index: 0)
    encoder.setTexture(param.output.metalTexture, index: 1)
    encoder.setBytes(&metalParam, length: MemoryLayout<ReshapeMetalParam>.size, index: 0)
    encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
    encoder.endEncoding()
  }
}
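Reviewer note (not part of the commit): FlattenKernel.compute passes MemoryLayout<ReshapeMetalParam>.size while binding a FlattenMetalParam value. That only works because FlattenMetalParam is four Int32 4-tuples (64 bytes) and ReshapeMetalParam is assumed to have the same layout, since both feed the same Metal ReshapeParam argument. A standalone sketch of that assumption, using a stand-in type rather than the library's own:

    // Stand-in mirroring the FlattenMetalParam fields declared above (not the library type).
    struct FourTuplesSketch {
        var idim: (Int32, Int32, Int32, Int32)
        var itrans: (Int32, Int32, Int32, Int32)
        var odim: (Int32, Int32, Int32, Int32)
        var otrans: (Int32, Int32, Int32, Int32)
    }
    // 16 Int32 fields with no padding: 64 bytes, the amount setBytes needs to copy.
    print(MemoryLayout<FourTuplesSketch>.size)  // 64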
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift

@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
       odim: (od[0], od[1], od[2], od[3]),
       otrans: (ot[0], ot[1], ot[2], ot[3])
     )
+    let irank = param.input.tensorDim.cout()
+    let orank = param.output.tensorDim.cout()
     if computePrecision == .Float32 {
-      super.init(device: device, inFunctionName: "reshape")
+      super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float")
     } else if computePrecision == .Float16 {
-      super.init(device: device, inFunctionName: "reshape_half")
+      super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half")
     } else {
       fatalError()
     }
metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift

@@ -27,7 +27,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
   }
   required init(device: MTLDevice, param: SplitParam<P>) {
-    param.output.initTexture(device: device, computePrecision: computePrecision)
+    // param.output.initTexture(device: device, computePrecision: computePrecision)
+    for output in param.outputList {
+      output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
+    }
     if computePrecision == .Float32 {
       super.init(device: device, inFunctionName: "split")
     } else if computePrecision == .Float16 {
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal

@@ -15,28 +15,28 @@
 #include <metal_stdlib>
 using namespace metal;

-kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
-                           texture2d_array<half, access::write> outTexture [[texture(1)]],
-                           const device half4 * newScale [[buffer(0)]],
-                           const device half4 * newBias [[buffer(1)]],
+kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
+                      texture2d_array<float, access::write> outTexture [[texture(1)]],
+                      const device float4 * newScale [[buffer(0)]],
+                      const device float4 * newBias [[buffer(1)]],
                       uint3 gid [[thread_position_in_grid]]) {
   if (gid.x >= outTexture.get_width() ||
       gid.y >= outTexture.get_height() ||
       gid.z >= outTexture.get_array_size()) return;
-  const half4 input = inTexture.read(gid.xy, gid.z);
-  half4 output = input * newScale[gid.z] + newBias[gid.z];
+  const float4 input = inTexture.read(gid.xy, gid.z);
+  float4 output = input * newScale[gid.z] + newBias[gid.z];
   outTexture.write(output, gid.xy, gid.z);
 }

-kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
-                      texture2d_array<float, access::write> outTexture [[texture(1)]],
-                      const device float4 * newScale [[buffer(0)]],
-                      const device float4 * newBias [[buffer(1)]],
+kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
+                           texture2d_array<half, access::write> outTexture [[texture(1)]],
+                           const device half4 * newScale [[buffer(0)]],
+                           const device half4 * newBias [[buffer(1)]],
                            uint3 gid [[thread_position_in_grid]]) {
   if (gid.x >= outTexture.get_width() ||
       gid.y >= outTexture.get_height() ||
       gid.z >= outTexture.get_array_size()) return;
-  const float4 input = inTexture.read(gid.xy, gid.z);
-  float4 output = input * newScale[gid.z] + newBias[gid.z];
+  const half4 input = inTexture.read(gid.xy, gid.z);
+  half4 output = input * newScale[gid.z] + newBias[gid.z];
   outTexture.write(output, gid.xy, gid.z);
 }
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc → metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal

-#ifndef D
-#define D 4
-#endif
-#ifndef P
-#define P float
-#endif
+#ifdef P

 #define CONCAT2(a, b) a ## b
 #define CONCAT2_(a, b) a ## _ ## b
 #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
-#define FUNC(f, d, p) CONCAT3_(f, d, p)
+#define FUNC(f, r, p) CONCAT3_(f, r, p)
 #define VECTOR(p, n) CONCAT2(p, n)
-#define FUNC_D(f, d) CONCAT2_(f, d)
+#define FUNC_R(f, r) CONCAT2_(f, r)

-kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
+kernel void FUNC(concat, R, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
                                texture2d_array<P, access::read> in1 [[texture(1)]],
                                texture2d_array<P, access::read> in2 [[texture(2)]],
                                texture2d_array<P, access::read> in3 [[texture(3)]],

@@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
   VECTOR(P, 4) r;
   for (int i = 0; i < 4; i++) {
     xyzn[3] = i;
-#if D == 4
+#if R == 4
     xyzn2abcd_4(cp.odim[3], xyzn, abcd);
 #else
-    FUNC_D(xyzn2abcd, D)(xyzn, abcd);
+    FUNC_R(xyzn2abcd, R)(xyzn, abcd);
 #endif
     int k = abcd[cp.axis] - cp.offset;
     int j = 0;

@@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
     int ta = cp.odim[cp.axis];
     abcd[cp.axis] = k;
     cp.odim[cp.axis] = cp.vdim[j];
-#if D == 4
+#if R == 4
     abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
 #else
-    FUNC_D(abcd2xyzn, D)(abcd, oxyzn);
+    FUNC_R(abcd2xyzn, R)(abcd, oxyzn);
 #endif
     cp.odim[cp.axis] = ta;
     switch (j) {

@@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
   }
   out.write(r, gid.xy, gid.z);
 }
+#endif
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal

@@ -26,31 +26,31 @@ struct ConcatParam {
 };

 #define P float
-#define D 4
-#include "ConcatKernel.metal.inc"
-#undef D
-#define D 3
-#include "ConcatKernel.metal.inc"
-#undef D
-#define D 2
-#include "ConcatKernel.metal.inc"
-#undef D
-#define D 1
-#include "ConcatKernel.metal.inc"
-#undef D
+#define R 4
+#include "ConcatKernel.inc.metal"
+#undef R
+#define R 3
+#include "ConcatKernel.inc.metal"
+#undef R
+#define R 2
+#include "ConcatKernel.inc.metal"
+#undef R
+#define R 1
+#include "ConcatKernel.inc.metal"
+#undef R
 #undef P

 #define P half
-#define D 4
-#include "ConcatKernel.metal.inc"
-#undef D
-#define D 3
-#include "ConcatKernel.metal.inc"
-#undef D
-#define D 2
-#include "ConcatKernel.metal.inc"
-#undef D
-#define D 1
-#include "ConcatKernel.metal.inc"
-#undef D
+#define R 4
+#include "ConcatKernel.inc.metal"
+#undef R
+#define R 3
+#include "ConcatKernel.inc.metal"
+#undef R
+#define R 2
+#include "ConcatKernel.inc.metal"
+#undef R
+#define R 1
+#include "ConcatKernel.inc.metal"
+#undef R
 #undef P
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc → metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal

-#ifndef P
-#define P float
-#endif
+#ifdef P

 #define CONCAT2(a, b) a ## b
 #define CONCAT2_(a, b) a ## _ ## b
 #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
 #define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
-#define FUNC(f, d1, d2, p) CONCAT4_(f, d1, d2, p)
+#define FUNC(f, r1, r2, p) CONCAT4_(f, r1, r2, p)
 #define VECTOR(p, n) CONCAT2(p, n)
-#define FUNC_D(f, d) CONCAT2_(f, d)
+#define FUNC_R(f, r) CONCAT2_(f, r)

-kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
+kernel void FUNC(reshape, RIN, ROUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
                                         texture2d_array<P, access::write> outTexture [[texture(1)]],
                                         constant ReshapeParam &rp [[buffer(0)]],
                                         uint3 gid [[thread_position_in_grid]]) {

@@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
   VECTOR(P, 4) r;
   for (int n = 0; n < 4; n++) {
     oxyzn[3] = n;
-#if DOUT == 4
+#if ROUT == 4
     xyzn2abcd_4(oC, oxyzn, oabcd);
 #else
-    FUNC_D(xyzn2abcd, DOUT)(oxyzn, oabcd);
+    FUNC_R(xyzn2abcd, ROUT)(oxyzn, oabcd);
 #endif
     int tabcd[4];
     invtrans(lrp.otrans, oabcd, tabcd);

@@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
       index2abcd(lrp.idim, index, tabcd);
       trans(lrp.itrans, tabcd, iabcd);
       abcd2xyzn(iC, iabcd, ixyzn);
-#if DIN == 4
+#if RIN == 4
       abcd2xyzn_4(iC, iabcd, ixyzn);
 #else
-      FUNC_D(abcd2xyzn, DIN)(iabcd, ixyzn);
+      FUNC_R(abcd2xyzn, RIN)(iabcd, ixyzn);
 #endif
       r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
     } else {

@@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
   outTexture.write(r, gid.xy, gid.z);
 }
+#endif
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal

@@ -8,7 +8,7 @@
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+WITHOUT WARRANTIES OR CONRITIONS OF ANY KINR, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

@@ -25,127 +25,126 @@ struct ReshapeParam {
 };

 #define P float
-#define DIN 4
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 4
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
-#define DIN 3
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 3
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
-#define DIN 2
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 2
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
-#define DIN 1
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 1
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
 #undef P

 #define P half
-#define DIN 4
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 4
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
-#define DIN 3
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 3
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
-#define DIN 2
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 2
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
-#define DIN 1
-#define DOUT 4
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 3
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 2
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#define DOUT 1
-#include "ReshapeKernel.metal.inc"
-#undef DOUT
-#undef DIN
+#define RIN 1
+#define ROUT 4
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 3
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 2
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#define ROUT 1
+#include "ReshapeKernel.inc.metal"
+#undef ROUT
+#undef RIN
 #undef P
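Reviewer note (not part of the commit): the instantiation matrix above now generates one reshape specialization per (input rank, output rank, precision) combination, and ReshapeKernel.swift plus the new FlattenKernel.swift select one by name. A small sketch that enumerates the 32 generated names:

    // Enumerates the kernel names produced by the RIN/ROUT/P matrix above.
    let reshapeNames = (1...4).flatMap { rin in
        (1...4).flatMap { rout in
            ["float", "half"].map { "reshape_\(rin)_\(rout)_\($0)" }
        }
    }
    print(reshapeNames.count)                          // 32
    print(reshapeNames.contains("reshape_4_2_float"))  // true; FlattenKernel requests this one for a rank-4 input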
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift

@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam {
       }
       output.padToFourDim = Dim.init(inDim: dim)
       output.dim = output.padToFourDim
-      // inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
     } catch let error {
       throw error
     }
   }
   let input: Texture<P>
   let shape: [Int32]
-  // let inplace: Bool
   var output: Texture<P>
 }
metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift

@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam {
   typealias ParamPrecisionType = P
   required init(opDesc: OpDesc, inScope: Scope) throws {
     do {
-      output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope)
       input = try ShapeParam.input(inputs: opDesc.inputs, from: inScope)
+      output = try ShapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
     } catch let error {
       throw error
     }
   }
   var output: Texture<P>
   let input: Texture<P>
 }

-class ShapeOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable {
-  typealias OpType = SplitOp<P>
+class ShapeOp<P: PrecisionType>: Operator<ShapeKernel<P>, ShapeParam<P>>, Runable, Creator, InferShaperable {
+  typealias OpType = ShapeOp<P>

   func inferShape() {
     // para.output.dim = para.input.dim
metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift

@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam {
  typealias ParamPrecisionType = P
  required init(opDesc: OpDesc, inScope: Scope) throws {
    do {
      // output = try SplitParam.output(outputs: opDesc.outputs, from: inScope)
      output = try SplitParam.outputOut(outputs: opDesc.outputs, from: inScope)
      input = try SplitParam.inputX(inputs: opDesc.inputs, from: inScope)
      output = Texture<P>.init(device: input.metalTexture!.device, inDim: input.dim)
      axis = try SplitParam.getAttr(key: "axis", attrs: opDesc.attrs)
      sections = try SplitParam.getAttr(key: "sections", attrs: opDesc.attrs)
      if axis < 0 {
        axis = input.tensorDim.cout() + axis
      }
      guard let outlist = opDesc.outputs["Out"] else {
        fatalError()
      }
      for out in outlist {
        guard let variant = inScope[out], let v = variant as? Texture<P> else {
          fatalError()
        }
        outputList.append(v)
        sections.append(Int32(v.tensorDim.dims[axis]))
      }
    } catch let error {
      throw error
    }
  }
  var axis: Int
  let input: Texture<P>
  var output: Texture<P>
  var outputList: [Texture<P>] = []
  var sections: [Int32] = []
 }

 class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable {
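Reviewer note (not part of the commit): SplitParam now gathers the per-output textures from the scope and derives sections from each output's extent along the split axis, after normalizing a negative axis against the input rank. A standalone sketch of that bookkeeping on plain dimension arrays, with hypothetical names:

    // Standalone sketch; outputDims stands in for the shapes of the "Out" textures found in the scope.
    func splitSections(outputDims: [[Int]], axis rawAxis: Int, inputRank: Int) -> (axis: Int, sections: [Int32]) {
        let axis = rawAxis < 0 ? inputRank + rawAxis : rawAxis   // e.g. axis -1 with rank 4 becomes 3
        let sections = outputDims.map { Int32($0[axis]) }
        return (axis, sections)
    }
    // splitSections(outputDims: [[1, 2, 2, 8], [1, 2, 2, 8]], axis: -1, inputRank: 4).sections == [8, 8]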