Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
e71320da
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e71320da
编写于
9月 12, 2018
作者:
D
dolphin8
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
batchnorm
上级
2bcb1135
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
334 addition
and
243 deletion
+334
-243
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
+12
-8
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
...l/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
+11
-8
metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift
metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift
+18
-1
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
...ile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
+17
-38
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift
...mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift
+3
-2
metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift
...obile/paddle-mobile/Operators/Kernels/FlattenKernel.swift
+71
-0
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift
...obile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift
+4
-2
metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift
...-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift
+4
-1
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal
...ddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal
+13
-13
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal
...dle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal
+9
-14
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal
.../paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal
+24
-24
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal
...le-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal
+9
-10
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal
...paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal
+113
-114
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift
+0
-3
metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift
metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift
+5
-3
metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift
metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift
+21
-2
未找到文件。
metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj
浏览文件 @
e71320da
...
@@ -16,8 +16,9 @@
...
@@ -16,8 +16,9 @@
4AA1EA92214665D700D0F791
/* ShapeOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA91214665D700D0F791
/* ShapeOp.swift */
;
};
4AA1EA92214665D700D0F791
/* ShapeOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA91214665D700D0F791
/* ShapeOp.swift */
;
};
4AA1EA942146661500D0F791
/* ShapeKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA932146661500D0F791
/* ShapeKernel.swift */
;
};
4AA1EA942146661500D0F791
/* ShapeKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA932146661500D0F791
/* ShapeKernel.swift */
;
};
4AA1EA982146666500D0F791
/* FlattenOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA972146666500D0F791
/* FlattenOp.swift */
;
};
4AA1EA982146666500D0F791
/* FlattenOp.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA972146666500D0F791
/* FlattenOp.swift */
;
};
4AA1EA9E2148D6F900D0F791
/* ConcatKernel.metal.inc in Headers */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA9D2148D6F900D0F791
/* ConcatKernel.metal.inc */
;
};
4AA1EA9E2148D6F900D0F791
/* ConcatKernel.inc.metal in Headers */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA9D2148D6F900D0F791
/* ConcatKernel.inc.metal */
;
};
4AA1EAA02148DEEE00D0F791
/* ReshapeKernel.metal.inc in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA9F2148DEEE00D0F791
/* ReshapeKernel.metal.inc */
;
};
4AA1EAA02148DEEE00D0F791
/* ReshapeKernel.inc.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EA9F2148DEEE00D0F791
/* ReshapeKernel.inc.metal */
;
};
4AA1EAA2214912CD00D0F791
/* FlattenKernel.swift in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AA1EAA1214912CC00D0F791
/* FlattenKernel.swift */
;
};
4AF928772133F1DB005B6C3A
/* BoxCoder.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AF928762133F1DB005B6C3A
/* BoxCoder.metal */
;
};
4AF928772133F1DB005B6C3A
/* BoxCoder.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AF928762133F1DB005B6C3A
/* BoxCoder.metal */
;
};
4AF9287921341661005B6C3A
/* Softmax.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AF9287821341661005B6C3A
/* Softmax.metal */
;
};
4AF9287921341661005B6C3A
/* Softmax.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AF9287821341661005B6C3A
/* Softmax.metal */
;
};
4AF928822135673D005B6C3A
/* ConcatKernel.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AF928812135673D005B6C3A
/* ConcatKernel.metal */
;
};
4AF928822135673D005B6C3A
/* ConcatKernel.metal in Sources */
=
{
isa
=
PBXBuildFile
;
fileRef
=
4AF928812135673D005B6C3A
/* ConcatKernel.metal */
;
};
...
@@ -126,8 +127,9 @@
...
@@ -126,8 +127,9 @@
4AA1EA91214665D700D0F791
/* ShapeOp.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ShapeOp.swift
;
sourceTree
=
"<group>"
;
};
4AA1EA91214665D700D0F791
/* ShapeOp.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ShapeOp.swift
;
sourceTree
=
"<group>"
;
};
4AA1EA932146661500D0F791
/* ShapeKernel.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ShapeKernel.swift
;
sourceTree
=
"<group>"
;
};
4AA1EA932146661500D0F791
/* ShapeKernel.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
ShapeKernel.swift
;
sourceTree
=
"<group>"
;
};
4AA1EA972146666500D0F791
/* FlattenOp.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FlattenOp.swift
;
sourceTree
=
"<group>"
;
};
4AA1EA972146666500D0F791
/* FlattenOp.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FlattenOp.swift
;
sourceTree
=
"<group>"
;
};
4AA1EA9D2148D6F900D0F791
/* ConcatKernel.metal.inc */
=
{
isa
=
PBXFileReference
;
explicitFileType
=
sourcecode.metal
;
fileEncoding
=
4
;
path
=
ConcatKernel.metal.inc
;
sourceTree
=
"<group>"
;
};
4AA1EA9D2148D6F900D0F791
/* ConcatKernel.inc.metal */
=
{
isa
=
PBXFileReference
;
explicitFileType
=
sourcecode.metal
;
fileEncoding
=
4
;
path
=
ConcatKernel.inc.metal
;
sourceTree
=
"<group>"
;
};
4AA1EA9F2148DEEE00D0F791
/* ReshapeKernel.metal.inc */
=
{
isa
=
PBXFileReference
;
explicitFileType
=
sourcecode.metal
;
fileEncoding
=
4
;
path
=
ReshapeKernel.metal.inc
;
sourceTree
=
"<group>"
;
};
4AA1EA9F2148DEEE00D0F791
/* ReshapeKernel.inc.metal */
=
{
isa
=
PBXFileReference
;
explicitFileType
=
sourcecode.metal
;
fileEncoding
=
4
;
path
=
ReshapeKernel.inc.metal
;
sourceTree
=
"<group>"
;
};
4AA1EAA1214912CC00D0F791
/* FlattenKernel.swift */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.swift
;
path
=
FlattenKernel.swift
;
sourceTree
=
"<group>"
;
};
4AF928762133F1DB005B6C3A
/* BoxCoder.metal */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.metal
;
path
=
BoxCoder.metal
;
sourceTree
=
"<group>"
;
};
4AF928762133F1DB005B6C3A
/* BoxCoder.metal */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.metal
;
path
=
BoxCoder.metal
;
sourceTree
=
"<group>"
;
};
4AF9287821341661005B6C3A
/* Softmax.metal */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.metal
;
path
=
Softmax.metal
;
sourceTree
=
"<group>"
;
};
4AF9287821341661005B6C3A
/* Softmax.metal */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.metal
;
path
=
Softmax.metal
;
sourceTree
=
"<group>"
;
};
4AF928812135673D005B6C3A
/* ConcatKernel.metal */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.metal
;
path
=
ConcatKernel.metal
;
sourceTree
=
"<group>"
;
};
4AF928812135673D005B6C3A
/* ConcatKernel.metal */
=
{
isa
=
PBXFileReference
;
fileEncoding
=
4
;
lastKnownFileType
=
sourcecode.metal
;
path
=
ConcatKernel.metal
;
sourceTree
=
"<group>"
;
};
...
@@ -395,6 +397,7 @@
...
@@ -395,6 +397,7 @@
FCD04E6720F315020007374F
/* PoolKernel.swift */
,
FCD04E6720F315020007374F
/* PoolKernel.swift */
,
FCD04E6B20F31A280007374F
/* SoftmaxKernel.swift */
,
FCD04E6B20F31A280007374F
/* SoftmaxKernel.swift */
,
FCD04E6F20F31B720007374F
/* ReshapeKernel.swift */
,
FCD04E6F20F31B720007374F
/* ReshapeKernel.swift */
,
4AA1EAA1214912CC00D0F791
/* FlattenKernel.swift */
,
FCD04E7320F3437E0007374F
/* ConvAddKernel.swift */
,
FCD04E7320F3437E0007374F
/* ConvAddKernel.swift */
,
FCBCCC5A2122F66F00D94F7E
/* ConvBNReluKernel.swift */
,
FCBCCC5A2122F66F00D94F7E
/* ConvBNReluKernel.swift */
,
FCBCCC602122FBDF00D94F7E
/* PriorBoxKernel.swift */
,
FCBCCC602122FBDF00D94F7E
/* PriorBoxKernel.swift */
,
...
@@ -442,7 +445,7 @@
...
@@ -442,7 +445,7 @@
children
=
(
children
=
(
FC27990D21341016000B6BAD
/* BoxCoder.metal */
,
FC27990D21341016000B6BAD
/* BoxCoder.metal */
,
4AF928812135673D005B6C3A
/* ConcatKernel.metal */
,
4AF928812135673D005B6C3A
/* ConcatKernel.metal */
,
4AA1EA9D2148D6F900D0F791
/* ConcatKernel.
metal.inc
*/
,
4AA1EA9D2148D6F900D0F791
/* ConcatKernel.
inc.metal
*/
,
4AF9288321357BE3005B6C3A
/* Elementwise.metal */
,
4AF9288321357BE3005B6C3A
/* Elementwise.metal */
,
FC1B16B220EC9A4F00678B91
/* Kernels.metal */
,
FC1B16B220EC9A4F00678B91
/* Kernels.metal */
,
FC4CB74820F0B954007C0C6D
/* ConvKernel.metal */
,
FC4CB74820F0B954007C0C6D
/* ConvKernel.metal */
,
...
@@ -455,7 +458,7 @@
...
@@ -455,7 +458,7 @@
FCDDC6CB212FDFDB00E5EF74
/* ReluKernel.metal */
,
FCDDC6CB212FDFDB00E5EF74
/* ReluKernel.metal */
,
FCDDC6CE212FE14700E5EF74
/* PriorBoxKernel.metal */
,
FCDDC6CE212FE14700E5EF74
/* PriorBoxKernel.metal */
,
FCA3A1622132A4AC00084FE5
/* ReshapeKernel.metal */
,
FCA3A1622132A4AC00084FE5
/* ReshapeKernel.metal */
,
4AA1EA9F2148DEEE00D0F791
/* ReshapeKernel.
metal.inc
*/
,
4AA1EA9F2148DEEE00D0F791
/* ReshapeKernel.
inc.metal
*/
,
FCA3A1642132A5EB00084FE5
/* Common.metal */
,
FCA3A1642132A5EB00084FE5
/* Common.metal */
,
FCA67B1621364EF000BD58AA
/* ConvTransposeKernel.metal */
,
FCA67B1621364EF000BD58AA
/* ConvTransposeKernel.metal */
,
FCA67CD42138272900BD58AA
/* ConvAddMetal.metal */
,
FCA67CD42138272900BD58AA
/* ConvAddMetal.metal */
,
...
@@ -477,7 +480,7 @@
...
@@ -477,7 +480,7 @@
FC4FD9792140E4980073E130
/* PaddleMobile.h in Headers */
,
FC4FD9792140E4980073E130
/* PaddleMobile.h in Headers */
,
FC292C85214257CB00CF622F
/* CPUCompute.h in Headers */
,
FC292C85214257CB00CF622F
/* CPUCompute.h in Headers */
,
FC292C5421421B2F00CF622F
/* PaddleMobileGPU.h in Headers */
,
FC292C5421421B2F00CF622F
/* PaddleMobileGPU.h in Headers */
,
4AA1EA9E2148D6F900D0F791
/* ConcatKernel.
metal.inc
in Headers */
,
4AA1EA9E2148D6F900D0F791
/* ConcatKernel.
inc.metal
in Headers */
,
FC039B6F20E11C3C0081E9F8
/* paddle_mobile.h in Headers */
,
FC039B6F20E11C3C0081E9F8
/* paddle_mobile.h in Headers */
,
);
);
runOnlyForDeploymentPostprocessing
=
0
;
runOnlyForDeploymentPostprocessing
=
0
;
...
@@ -617,6 +620,7 @@
...
@@ -617,6 +620,7 @@
FCBCCC592122F42700D94F7E
/* ConvBNReluOp.swift in Sources */
,
FCBCCC592122F42700D94F7E
/* ConvBNReluOp.swift in Sources */
,
FC039BA920E11CBC0081E9F8
/* ConvOp.swift in Sources */
,
FC039BA920E11CBC0081E9F8
/* ConvOp.swift in Sources */
,
FC9D038420E23B01000F735A
/* Texture.swift in Sources */
,
FC9D038420E23B01000F735A
/* Texture.swift in Sources */
,
4AA1EAA2214912CD00D0F791
/* FlattenKernel.swift in Sources */
,
4AA1EA982146666500D0F791
/* FlattenOp.swift in Sources */
,
4AA1EA982146666500D0F791
/* FlattenOp.swift in Sources */
,
FCBCCC652122FCD700D94F7E
/* TransposeOp.swift in Sources */
,
FCBCCC652122FCD700D94F7E
/* TransposeOp.swift in Sources */
,
FCD04E6E20F31B4B0007374F
/* ReshapeOp.swift in Sources */
,
FCD04E6E20F31B4B0007374F
/* ReshapeOp.swift in Sources */
,
...
@@ -657,7 +661,7 @@
...
@@ -657,7 +661,7 @@
FCBCCC67212306B000D94F7E
/* ConcatOp.swift in Sources */
,
FCBCCC67212306B000D94F7E
/* ConcatOp.swift in Sources */
,
FCD04E6C20F31A280007374F
/* SoftmaxKernel.swift in Sources */
,
FCD04E6C20F31A280007374F
/* SoftmaxKernel.swift in Sources */
,
FCEB684A212F00DB00D2448E
/* PreluKernel.metal in Sources */
,
FCEB684A212F00DB00D2448E
/* PreluKernel.metal in Sources */
,
4AA1EAA02148DEEE00D0F791
/* ReshapeKernel.
metal.inc
in Sources */
,
4AA1EAA02148DEEE00D0F791
/* ReshapeKernel.
inc.metal
in Sources */
,
FC9A19E32148C31300CD9CBF
/* MobilenetSSD_AR.swift in Sources */
,
FC9A19E32148C31300CD9CBF
/* MobilenetSSD_AR.swift in Sources */
,
FCDDC6CF212FE14700E5EF74
/* PriorBoxKernel.metal in Sources */
,
FCDDC6CF212FE14700E5EF74
/* PriorBoxKernel.metal in Sources */
,
FC4CB74B20F12C30007C0C6D
/* ProgramOptimize.swift in Sources */
,
FC4CB74B20F12C30007C0C6D
/* ProgramOptimize.swift in Sources */
,
...
...
metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
浏览文件 @
e71320da
...
@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam {
...
@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam {
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
input
=
try
BatchNormParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
input
=
try
BatchNormParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
if
input
.
transpose
!=
[
0
,
2
,
3
,
1
]
{
fatalError
(
"batch norm only accepts NHWC"
)
}
output
=
try
BatchNormParam
.
outputY
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
output
=
try
BatchNormParam
.
outputY
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
inputBias
=
try
BatchNormParam
.
inputBiase
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
bias
=
try
BatchNormParam
.
getFirstTensor
(
key
:
"Bias"
,
map
:
opDesc
.
paraInputs
,
from
:
inScope
)
inputMean
=
try
BatchNormParam
.
inputMean
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
mean
=
try
BatchNormParam
.
getFirstTensor
(
key
:
"Mean"
,
map
:
opDesc
.
paraInputs
,
from
:
inScope
)
inputScale
=
try
BatchNormParam
.
inputScale
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
scale
=
try
BatchNormParam
.
getFirstTensor
(
key
:
"Scale"
,
map
:
opDesc
.
paraInputs
,
from
:
inScope
)
inputVariance
=
try
BatchNormParam
.
inputVariance
(
inputs
:
opDesc
.
paraInputs
,
from
:
inScope
)
variance
=
try
BatchNormParam
.
getFirstTensor
(
key
:
"Variance"
,
map
:
opDesc
.
paraInputs
,
from
:
inScope
)
epsilon
=
try
BatchNormParam
.
getAttr
(
key
:
"epsilon"
,
attrs
:
opDesc
.
attrs
)
epsilon
=
try
BatchNormParam
.
getAttr
(
key
:
"epsilon"
,
attrs
:
opDesc
.
attrs
)
momentum
=
try
BatchNormParam
.
getAttr
(
key
:
"momentum"
,
attrs
:
opDesc
.
attrs
)
momentum
=
try
BatchNormParam
.
getAttr
(
key
:
"momentum"
,
attrs
:
opDesc
.
attrs
)
}
catch
let
error
{
}
catch
let
error
{
...
@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam {
...
@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam {
}
}
let
input
:
Texture
<
P
>
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
var
output
:
Texture
<
P
>
let
inputBias
:
Tensor
<
ParamPrecisionType
>
let
bias
:
Tensor
<
P
>
let
inputMean
:
Tensor
<
ParamPrecisionType
>
let
mean
:
Tensor
<
P
>
let
inputScale
:
Tensor
<
ParamPrecisionType
>
let
scale
:
Tensor
<
P
>
let
inputVariance
:
Tensor
<
ParamPrecisionType
>
let
variance
:
Tensor
<
P
>
let
epsilon
:
Float
let
epsilon
:
Float
let
momentum
:
Float
let
momentum
:
Float
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift
浏览文件 @
e71320da
...
@@ -14,7 +14,24 @@
...
@@ -14,7 +14,24 @@
import
Foundation
import
Foundation
class
FlattenOp
<
P
:
PrecisionType
>
:
Operator
<
ReshapeKernel
<
P
>
,
ReshapeParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
FlattenParam
<
P
:
PrecisionType
>
:
OpParam
{
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
input
=
try
FlattenParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
FlattenParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
axis
=
try
FlattenParam
.
getAttr
(
key
:
"axis"
,
attrs
:
opDesc
.
attrs
)
}
catch
let
error
{
throw
error
}
}
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
let
axis
:
Int
}
class
FlattenOp
<
P
:
PrecisionType
>
:
Operator
<
FlattenKernel
<
P
>
,
FlattenParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
typealias
OpType
=
FlattenOp
<
P
>
typealias
OpType
=
FlattenOp
<
P
>
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
浏览文件 @
e71320da
...
@@ -15,20 +15,20 @@
...
@@ -15,20 +15,20 @@
import
Foundation
import
Foundation
class
BatchNormKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
class
BatchNormKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
// var newScale: MTLBuffer
// var newBias: MTLBuffer
//
required
init
(
device
:
MTLDevice
,
param
:
BatchNormParam
<
P
>
)
{
required
init
(
device
:
MTLDevice
,
param
:
BatchNormParam
<
P
>
)
{
// guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
let
count
=
param
.
variance
.
dim
.
numel
()
// fatalError()
let
varianceP
=
param
.
variance
.
data
.
pointer
// }
let
meanP
=
param
.
mean
.
data
.
pointer
//
let
scaleP
=
param
.
scale
.
data
.
pointer
// guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
let
biasP
=
param
.
scale
.
data
.
pointer
// fatalError()
for
i
in
0
..<
count
{
// }
let
invStd
=
P
(
1
/
(
Float32
(
varianceP
[
i
])
+
param
.
epsilon
)
.
squareRoot
())
// self.newScale = newScale
biasP
[
i
]
=
biasP
[
i
]
-
meanP
[
i
]
*
invStd
*
scaleP
[
i
]
// self.newBias = newBias
scaleP
[
i
]
=
invStd
*
scaleP
[
i
]
//
}
param
.
bias
.
initBuffer
(
device
:
device
,
precision
:
computePrecision
)
param
.
scale
.
initBuffer
(
device
:
device
,
precision
:
computePrecision
)
param
.
output
.
initTexture
(
device
:
device
,
inTranspose
:
param
.
input
.
transpose
,
computePrecision
:
computePrecision
)
if
computePrecision
==
.
Float32
{
if
computePrecision
==
.
Float32
{
super
.
init
(
device
:
device
,
inFunctionName
:
"batchnorm"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"batchnorm"
)
}
else
if
computePrecision
==
.
Float16
{
}
else
if
computePrecision
==
.
Float16
{
...
@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
...
@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
}
else
{
}
else
{
fatalError
()
fatalError
()
}
}
//
// let varianceBuffer : MTLBuffer = param.inputVariance.buffer
//
// var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
// let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
// for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
// invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
// }
//
// let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
// let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
// let scale : MTLBuffer = param.inputScale.buffer
// let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
// let bias : MTLBuffer = param.inputBias.buffer
// let biasContents = bias.contents().assumingMemoryBound(to: P.self)
// let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
//
// for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
// newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
// newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
// }
}
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
BatchNormParam
<
P
>
)
throws
{
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
BatchNormParam
<
P
>
)
throws
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
" encoder is nil"
)
throw
PaddleMobileError
.
predictError
(
message
:
" encoder is nil"
)
}
}
//
encoder.setTexture(param.input.metalTexture, index: 0)
encoder
.
setTexture
(
param
.
input
.
metalTexture
,
index
:
0
)
//
encoder.setTexture(param.output.metalTexture, index: 1)
encoder
.
setTexture
(
param
.
output
.
metalTexture
,
index
:
1
)
// encoder.setBuffer(newScale
, offset: 0, index: 0)
encoder
.
setBuffer
(
param
.
scale
.
buffer
,
offset
:
0
,
index
:
0
)
// encoder.setBuffer(newBias
, offset: 0, index: 1)
encoder
.
setBuffer
(
param
.
bias
.
buffer
,
offset
:
0
,
index
:
1
)
encoder
.
dispatch
(
computePipline
:
pipline
,
outTexture
:
param
.
output
.
metalTexture
)
encoder
.
dispatch
(
computePipline
:
pipline
,
outTexture
:
param
.
output
.
metalTexture
)
encoder
.
endEncoding
()
encoder
.
endEncoding
()
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift
浏览文件 @
e71320da
...
@@ -122,10 +122,11 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
...
@@ -122,10 +122,11 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
required
init
(
device
:
MTLDevice
,
param
:
ConcatParam
<
P
>
)
{
required
init
(
device
:
MTLDevice
,
param
:
ConcatParam
<
P
>
)
{
param
.
output
.
initTexture
(
device
:
device
,
inTranspose
:
param
.
transpose
,
computePrecision
:
computePrecision
)
param
.
output
.
initTexture
(
device
:
device
,
inTranspose
:
param
.
transpose
,
computePrecision
:
computePrecision
)
let
orank
=
param
.
output
.
tensorDim
.
cout
()
if
computePrecision
==
.
Float32
{
if
computePrecision
==
.
Float32
{
super
.
init
(
device
:
device
,
inFunctionName
:
"concat"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"concat
_
\(
orank
)
_float
"
)
}
else
if
computePrecision
==
.
Float16
{
}
else
if
computePrecision
==
.
Float16
{
super
.
init
(
device
:
device
,
inFunctionName
:
"concat_half"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"concat_
\(
orank
)
_
half"
)
}
else
{
}
else
{
fatalError
()
fatalError
()
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift
0 → 100644
浏览文件 @
e71320da
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import
Foundation
struct
FlattenMetalParam
{
var
idim
:
(
Int32
,
Int32
,
Int32
,
Int32
)
var
itrans
:
(
Int32
,
Int32
,
Int32
,
Int32
)
var
odim
:
(
Int32
,
Int32
,
Int32
,
Int32
)
var
otrans
:
(
Int32
,
Int32
,
Int32
,
Int32
)
}
class
FlattenKernel
<
P
:
PrecisionType
>
:
Kernel
,
Computable
{
var
metalParam
:
FlattenMetalParam
required
init
(
device
:
MTLDevice
,
param
:
FlattenParam
<
P
>
)
{
param
.
output
.
initTexture
(
device
:
device
,
computePrecision
:
computePrecision
)
var
id
:
[
Int32
]
=
[
1
,
1
,
1
,
1
]
for
i
in
0
..<
param
.
input
.
tensorDim
.
cout
()
{
id
[
4
-
param
.
input
.
tensorDim
.
cout
()
+
i
]
=
Int32
(
param
.
input
.
tensorDim
[
i
])
}
let
it
:
[
Int32
]
=
param
.
input
.
transpose
.
map
{
Int32
(
$0
)
}
var
od
:
[
Int32
]
=
[
1
,
1
,
1
,
1
]
for
i
in
0
..<
param
.
output
.
tensorDim
.
cout
()
{
od
[
4
-
param
.
output
.
tensorDim
.
cout
()
+
i
]
=
Int32
(
param
.
output
.
tensorDim
[
i
])
}
let
ot
:
[
Int32
]
=
param
.
output
.
transpose
.
map
{
Int32
(
$0
)
}
metalParam
=
FlattenMetalParam
.
init
(
idim
:
(
id
[
0
],
id
[
1
],
id
[
2
],
id
[
3
]),
itrans
:
(
it
[
0
],
it
[
1
],
it
[
2
],
it
[
3
]),
odim
:
(
od
[
0
],
od
[
1
],
od
[
2
],
od
[
3
]),
otrans
:
(
ot
[
0
],
ot
[
1
],
ot
[
2
],
ot
[
3
])
)
let
irank
=
param
.
input
.
tensorDim
.
cout
()
let
orank
=
param
.
output
.
tensorDim
.
cout
()
assert
(
orank
==
2
)
if
computePrecision
==
.
Float32
{
super
.
init
(
device
:
device
,
inFunctionName
:
"reshape_
\(
irank
)
_2_float"
)
}
else
if
computePrecision
==
.
Float16
{
super
.
init
(
device
:
device
,
inFunctionName
:
"reshape_
\(
irank
)
_2_half"
)
}
else
{
fatalError
()
}
}
func
compute
(
commandBuffer
:
MTLCommandBuffer
,
param
:
FlattenParam
<
P
>
)
throws
{
guard
let
encoder
=
commandBuffer
.
makeComputeCommandEncoder
()
else
{
throw
PaddleMobileError
.
predictError
(
message
:
" encoder is nil"
)
}
encoder
.
setTexture
(
param
.
input
.
metalTexture
,
index
:
0
)
encoder
.
setTexture
(
param
.
output
.
metalTexture
,
index
:
1
)
encoder
.
setBytes
(
&
metalParam
,
length
:
MemoryLayout
<
ReshapeMetalParam
>.
size
,
index
:
0
)
encoder
.
dispatch
(
computePipline
:
pipline
,
outTexture
:
param
.
output
.
metalTexture
)
encoder
.
endEncoding
()
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift
浏览文件 @
e71320da
...
@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
...
@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
odim
:
(
od
[
0
],
od
[
1
],
od
[
2
],
od
[
3
]),
odim
:
(
od
[
0
],
od
[
1
],
od
[
2
],
od
[
3
]),
otrans
:
(
ot
[
0
],
ot
[
1
],
ot
[
2
],
ot
[
3
])
otrans
:
(
ot
[
0
],
ot
[
1
],
ot
[
2
],
ot
[
3
])
)
)
let
irank
=
param
.
input
.
tensorDim
.
cout
()
let
orank
=
param
.
output
.
tensorDim
.
cout
()
if
computePrecision
==
.
Float32
{
if
computePrecision
==
.
Float32
{
super
.
init
(
device
:
device
,
inFunctionName
:
"reshape"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"reshape
_
\(
irank
)
_
\(
orank
)
_float
"
)
}
else
if
computePrecision
==
.
Float16
{
}
else
if
computePrecision
==
.
Float16
{
super
.
init
(
device
:
device
,
inFunctionName
:
"reshape_half"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"reshape_
\(
irank
)
_
\(
orank
)
_
half"
)
}
else
{
}
else
{
fatalError
()
fatalError
()
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift
浏览文件 @
e71320da
...
@@ -27,7 +27,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
...
@@ -27,7 +27,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
}
}
required
init
(
device
:
MTLDevice
,
param
:
SplitParam
<
P
>
)
{
required
init
(
device
:
MTLDevice
,
param
:
SplitParam
<
P
>
)
{
param
.
output
.
initTexture
(
device
:
device
,
computePrecision
:
computePrecision
)
// param.output.initTexture(device: device, computePrecision: computePrecision)
for
output
in
param
.
outputList
{
output
.
initTexture
(
device
:
device
,
inTranspose
:
param
.
input
.
transpose
,
computePrecision
:
computePrecision
)
}
if
computePrecision
==
.
Float32
{
if
computePrecision
==
.
Float32
{
super
.
init
(
device
:
device
,
inFunctionName
:
"split"
)
super
.
init
(
device
:
device
,
inFunctionName
:
"split"
)
}
else
if
computePrecision
==
.
Float16
{
}
else
if
computePrecision
==
.
Float16
{
...
...
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal
浏览文件 @
e71320da
...
@@ -15,28 +15,28 @@
...
@@ -15,28 +15,28 @@
#include <metal_stdlib>
#include <metal_stdlib>
using namespace metal;
using namespace metal;
kernel void batchnorm
_half(texture2d_array<half
, access::read> inTexture [[texture(0)]],
kernel void batchnorm
(texture2d_array<float
, access::read> inTexture [[texture(0)]],
texture2d_array<
half
, access::write> outTexture [[texture(1)]],
texture2d_array<
float
, access::write> outTexture [[texture(1)]],
const device
half
4 * newScale [[buffer(0)]],
const device
float
4 * newScale [[buffer(0)]],
const device
half
4 * newBias [[buffer(1)]],
const device
float
4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
gid.z >= outTexture.get_array_size()) return;
const
half
4 input = inTexture.read(gid.xy, gid.z);
const
float
4 input = inTexture.read(gid.xy, gid.z);
half
4 output = input * newScale[gid.z] + newBias[gid.z];
float
4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z);
outTexture.write(output, gid.xy, gid.z);
}
}
kernel void batchnorm
(texture2d_array<float
, access::read> inTexture [[texture(0)]],
kernel void batchnorm
_half(texture2d_array<half
, access::read> inTexture [[texture(0)]],
texture2d_array<float
, access::write> outTexture [[texture(1)]],
texture2d_array<half
, access::write> outTexture [[texture(1)]],
const device float
4 * newScale [[buffer(0)]],
const device half
4 * newScale [[buffer(0)]],
const device float
4 * newBias [[buffer(1)]],
const device half
4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
gid.z >= outTexture.get_array_size()) return;
const
float
4 input = inTexture.read(gid.xy, gid.z);
const
half
4 input = inTexture.read(gid.xy, gid.z);
float
4 output = input * newScale[gid.z] + newBias[gid.z];
half
4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z);
outTexture.write(output, gid.xy, gid.z);
}
}
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.
metal.inc
→
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.
inc.metal
浏览文件 @
e71320da
#ifndef D
#ifdef P
#define D 4
#endif
#ifndef P
#define P float
#endif
#define CONCAT2(a, b) a ## b
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define FUNC(f,
d, p) CONCAT3_(f, d
, p)
#define FUNC(f,
r, p) CONCAT3_(f, r
, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_
D(f, d) CONCAT2_(f, d
)
#define FUNC_
R(f, r) CONCAT2_(f, r
)
kernel
void
FUNC
(
concat
,
D
,
P
)(
texture2d_array
<
P
,
access
::
read
>
in0
[[
texture
(
0
)]],
kernel void FUNC(concat,
R
, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
texture2d_array<P, access::read> in2 [[texture(2)]],
texture2d_array<P, access::read> in2 [[texture(2)]],
texture2d_array<P, access::read> in3 [[texture(3)]],
texture2d_array<P, access::read> in3 [[texture(3)]],
...
@@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
...
@@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
VECTOR(P, 4) r;
VECTOR(P, 4) r;
for (int i = 0; i < 4; i++) {
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn[3] = i;
#if
D
== 4
#if
R
== 4
xyzn2abcd_4(cp.odim[3], xyzn, abcd);
xyzn2abcd_4(cp.odim[3], xyzn, abcd);
#else
#else
FUNC_
D
(
xyzn2abcd
,
D
)(
xyzn
,
abcd
);
FUNC_
R(xyzn2abcd, R
)(xyzn, abcd);
#endif
#endif
int k = abcd[cp.axis] - cp.offset;
int k = abcd[cp.axis] - cp.offset;
int j = 0;
int j = 0;
...
@@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
...
@@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
int ta = cp.odim[cp.axis];
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
cp.odim[cp.axis] = cp.vdim[j];
#if
D
== 4
#if
R
== 4
abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
#else
#else
FUNC_
D
(
abcd2xyzn
,
D
)(
abcd
,
oxyzn
);
FUNC_
R(abcd2xyzn, R
)(abcd, oxyzn);
#endif
#endif
cp.odim[cp.axis] = ta;
cp.odim[cp.axis] = ta;
switch (j) {
switch (j) {
...
@@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
...
@@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
}
}
out.write(r, gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
}
}
#endif
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal
浏览文件 @
e71320da
...
@@ -26,31 +26,31 @@ struct ConcatParam {
...
@@ -26,31 +26,31 @@ struct ConcatParam {
};
};
#define P float
#define P float
#define
D
4
#define
R
4
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#define
D
3
#define
R
3
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#define
D
2
#define
R
2
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#define
D
1
#define
R
1
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#undef P
#undef P
#define P half
#define P half
#define
D
4
#define
R
4
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#define
D
3
#define
R
3
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#define
D
2
#define
R
2
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#define
D
1
#define
R
1
#include "ConcatKernel.
metal.inc
"
#include "ConcatKernel.
inc.metal
"
#undef
D
#undef
R
#undef P
#undef P
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.
metal.inc
→
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.
inc.metal
浏览文件 @
e71320da
#ifndef P
#ifdef P
#define P float
#endif
#define CONCAT2(a, b) a ## b
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define FUNC(f,
d1, d2, p) CONCAT4_(f, d1, d
2, p)
#define FUNC(f,
r1, r2, p) CONCAT4_(f, r1, r
2, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_
D(f, d) CONCAT2_(f, d
)
#define FUNC_
R(f, r) CONCAT2_(f, r
)
kernel
void
FUNC
(
reshape
,
DIN
,
D
OUT
,
P
)(
texture2d_array
<
P
,
access
::
read
>
inTexture
[[
texture
(
0
)]],
kernel void FUNC(reshape,
RIN, R
OUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]],
constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
uint3 gid [[thread_position_in_grid]]) {
...
@@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
...
@@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
VECTOR(P, 4) r;
VECTOR(P, 4) r;
for (int n = 0; n < 4; n++) {
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
oxyzn[3] = n;
#if
D
OUT == 4
#if
R
OUT == 4
xyzn2abcd_4(oC, oxyzn, oabcd);
xyzn2abcd_4(oC, oxyzn, oabcd);
#else
#else
FUNC_
D
(
xyzn2abcd
,
D
OUT
)(
oxyzn
,
oabcd
);
FUNC_
R(xyzn2abcd, R
OUT)(oxyzn, oabcd);
#endif
#endif
int tabcd[4];
int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd);
invtrans(lrp.otrans, oabcd, tabcd);
...
@@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
...
@@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
index2abcd(lrp.idim, index, tabcd);
index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd);
trans(lrp.itrans, tabcd, iabcd);
abcd2xyzn(iC, iabcd, ixyzn);
abcd2xyzn(iC, iabcd, ixyzn);
#if
D
IN == 4
#if
R
IN == 4
abcd2xyzn_4(iC, iabcd, ixyzn);
abcd2xyzn_4(iC, iabcd, ixyzn);
#else
#else
FUNC_
D
(
abcd2xyzn
,
D
IN
)(
iabcd
,
ixyzn
);
FUNC_
R(abcd2xyzn, R
IN)(iabcd, ixyzn);
#endif
#endif
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
} else {
...
@@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
...
@@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
outTexture.write(r, gid.xy, gid.z);
outTexture.write(r, gid.xy, gid.z);
}
}
#endif
metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal
浏览文件 @
e71320da
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CON
DITIONS OF ANY KIND
, either express or implied.
WITHOUT WARRANTIES OR CON
RITIONS OF ANY KINR
, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
...
@@ -25,127 +25,126 @@ struct ReshapeParam {
...
@@ -25,127 +25,126 @@ struct ReshapeParam {
};
};
#define P float
#define P float
#define
D
IN 4
#define
R
IN 4
#define
D
OUT 4
#define
R
OUT 4
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 3
#define
R
OUT 3
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 2
#define
R
OUT 2
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 1
#define
R
OUT 1
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#undef
D
IN
#undef
R
IN
#define
D
IN 3
#define
R
IN 3
#define
D
OUT 4
#define
R
OUT 4
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 3
#define
R
OUT 3
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 2
#define
R
OUT 2
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 1
#define
R
OUT 1
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#undef
D
IN
#undef
R
IN
#define
D
IN 2
#define
R
IN 2
#define
D
OUT 4
#define
R
OUT 4
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 3
#define
R
OUT 3
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 2
#define
R
OUT 2
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 1
#define
R
OUT 1
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#undef
D
IN
#undef
R
IN
#define
D
IN 1
#define
R
IN 1
#define
D
OUT 4
#define
R
OUT 4
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 3
#define
R
OUT 3
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 2
#define
R
OUT 2
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#define
D
OUT 1
#define
R
OUT 1
#include "ReshapeKernel.
metal.inc
"
#include "ReshapeKernel.
inc.metal
"
#undef
D
OUT
#undef
R
OUT
#undef
D
IN
#undef
R
IN
#undef P
#undef P
#define P half
#define P half
#define DIN 4
#define RIN 4
#define DOUT 4
#define ROUT 4
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 3
#define ROUT 3
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 2
#define ROUT 2
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 1
#define ROUT 1
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#undef DIN
#undef RIN
#define DIN 3
#define RIN 3
#define DOUT 4
#define ROUT 4
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 3
#define ROUT 3
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 2
#define ROUT 2
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 1
#define ROUT 1
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#undef DIN
#undef RIN
#define DIN 2
#define RIN 2
#define DOUT 4
#define ROUT 4
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 3
#define ROUT 3
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 2
#define ROUT 2
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 1
#define ROUT 1
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#undef DIN
#undef RIN
#define DIN 1
#define RIN 1
#define DOUT 4
#define ROUT 4
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 3
#define ROUT 3
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 2
#define ROUT 2
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#define DOUT 1
#define ROUT 1
#include "ReshapeKernel.metal.inc"
#include "ReshapeKernel.inc.metal"
#undef DOUT
#undef ROUT
#undef DIN
#undef RIN
#undef P
#undef P
metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift
浏览文件 @
e71320da
...
@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam {
...
@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam {
}
}
output
.
padToFourDim
=
Dim
.
init
(
inDim
:
dim
)
output
.
padToFourDim
=
Dim
.
init
(
inDim
:
dim
)
output
.
dim
=
output
.
padToFourDim
output
.
dim
=
output
.
padToFourDim
// inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
}
catch
let
error
{
}
catch
let
error
{
throw
error
throw
error
}
}
}
}
let
input
:
Texture
<
P
>
let
input
:
Texture
<
P
>
let
shape
:
[
Int32
]
let
shape
:
[
Int32
]
// let inplace: Bool
var
output
:
Texture
<
P
>
var
output
:
Texture
<
P
>
}
}
...
...
metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift
浏览文件 @
e71320da
...
@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam {
...
@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam {
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
output
=
try
ShapeParam
.
output
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
input
=
try
ShapeParam
.
input
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
ShapeParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
}
catch
let
error
{
}
catch
let
error
{
throw
error
throw
error
}
}
}
}
var
output
:
Texture
<
P
>
var
output
:
Texture
<
P
>
let
input
:
Texture
<
P
>
}
}
class
ShapeOp
<
P
:
PrecisionType
>
:
Operator
<
S
plitKernel
<
P
>
,
Split
Param
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
ShapeOp
<
P
:
PrecisionType
>
:
Operator
<
S
hapeKernel
<
P
>
,
Shape
Param
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
typealias
OpType
=
S
plit
Op
<
P
>
typealias
OpType
=
S
hape
Op
<
P
>
func
inferShape
()
{
func
inferShape
()
{
// para.output.dim = para.input.dim
// para.output.dim = para.input.dim
...
...
metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift
浏览文件 @
e71320da
...
@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam {
...
@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam {
typealias
ParamPrecisionType
=
P
typealias
ParamPrecisionType
=
P
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
required
init
(
opDesc
:
OpDesc
,
inScope
:
Scope
)
throws
{
do
{
do
{
// output = try SplitParam.output(outputs: opDesc.outputs, from: inScope)
input
=
try
SplitParam
.
inputX
(
inputs
:
opDesc
.
inputs
,
from
:
inScope
)
output
=
try
SplitParam
.
outputOut
(
outputs
:
opDesc
.
outputs
,
from
:
inScope
)
output
=
Texture
<
P
>.
init
(
device
:
input
.
metalTexture
!.
device
,
inDim
:
input
.
dim
)
axis
=
try
SplitParam
.
getAttr
(
key
:
"axis"
,
attrs
:
opDesc
.
attrs
)
sections
=
try
SplitParam
.
getAttr
(
key
:
"sections"
,
attrs
:
opDesc
.
attrs
)
if
axis
<
0
{
axis
=
input
.
tensorDim
.
cout
()
+
axis
}
guard
let
outlist
=
opDesc
.
outputs
[
"Out"
]
else
{
fatalError
()
}
for
out
in
outlist
{
guard
let
variant
=
inScope
[
out
],
let
v
=
variant
as?
Texture
<
P
>
else
{
fatalError
()
}
outputList
.
append
(
v
)
sections
.
append
(
Int32
(
v
.
tensorDim
.
dims
[
axis
]))
}
}
catch
let
error
{
}
catch
let
error
{
throw
error
throw
error
}
}
}
}
var
axis
:
Int
let
input
:
Texture
<
P
>
var
output
:
Texture
<
P
>
var
output
:
Texture
<
P
>
var
outputList
:
[
Texture
<
P
>
]
=
[]
var
sections
:
[
Int32
]
=
[]
}
}
class
SplitOp
<
P
:
PrecisionType
>
:
Operator
<
SplitKernel
<
P
>
,
SplitParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
class
SplitOp
<
P
:
PrecisionType
>
:
Operator
<
SplitKernel
<
P
>
,
SplitParam
<
P
>>
,
Runable
,
Creator
,
InferShaperable
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录