提交 20f5ae65 编写于 作者: Y Yanzhan Yang 提交者: GitHub

1.add flatten2, slice and nearest_interp op. 2.truncate tensor with more than...

1.add flatten2, slice and nearest_interp op. 2.truncate tensor with more than 4 dimensions to 4. (#1634)
上级 1b7329e2
......@@ -7,7 +7,9 @@
objects = {
/* Begin PBXBuildFile section */
16324D862292C4930047277D /* NearestInterpKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16324D852292C4930047277D /* NearestInterpKernel.metal */; };
165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; };
16D3F3BB22929EAD0067C45D /* SliceKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3BA22929EAD0067C45D /* SliceKernel.metal */; };
16FBFB3E22925D040025B406 /* ActivationKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3D22925D040025B406 /* ActivationKernel.metal */; };
5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; };
A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */; };
......@@ -55,7 +57,9 @@
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
16324D852292C4930047277D /* NearestInterpKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = NearestInterpKernel.metal; sourceTree = "<group>"; };
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = "<group>"; };
16D3F3BA22929EAD0067C45D /* SliceKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = SliceKernel.metal; sourceTree = "<group>"; };
16FBFB3D22925D040025B406 /* ActivationKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ActivationKernel.metal; sourceTree = "<group>"; };
33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = "<group>"; };
5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; };
......@@ -199,6 +203,8 @@
A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */,
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */,
16FBFB3D22925D040025B406 /* ActivationKernel.metal */,
16D3F3BA22929EAD0067C45D /* SliceKernel.metal */,
16324D852292C4930047277D /* NearestInterpKernel.metal */,
);
path = "paddle-mobile-metallib";
sourceTree = "<group>";
......@@ -316,12 +322,14 @@
FCC15E0B221E69E100DC3CB2 /* FetchKernel.inc.metal in Sources */,
FCC15DEE221E69E100DC3CB2 /* ConvTransposeKernel.metal in Sources */,
FCC15DFC221E69E100DC3CB2 /* ConvAddPreluKernel.metal in Sources */,
16D3F3BB22929EAD0067C45D /* SliceKernel.metal in Sources */,
FCC15E06221E69E100DC3CB2 /* BoxCoder.inc.metal in Sources */,
FCC15DF1221E69E100DC3CB2 /* BilinearInterp.inc.metal in Sources */,
FCC15E08221E69E100DC3CB2 /* Split.inc.metal in Sources */,
FCC15DF4221E69E100DC3CB2 /* ResizeBilinear.metal in Sources */,
FCC15E05221E69E100DC3CB2 /* BatchNormKernel.metal in Sources */,
165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */,
16324D862292C4930047277D /* NearestInterpKernel.metal in Sources */,
FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */,
FCC15DF6221E69E100DC3CB2 /* PoolKernel.metal in Sources */,
FCC15E09221E69E100DC3CB2 /* ConcatKernel.inc.metal in Sources */,
......
......@@ -117,4 +117,3 @@ struct MetalConvParam {
ushort dilationX;
ushort dilationY;
};
......@@ -67,6 +67,19 @@ struct ConcatParam {
#undef R
#undef V
// lens: (R=4, N=3, V=x)
#define V VX
#define R 4
#define N 3
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=3, N=2, V=y)
#define V VY
......@@ -96,6 +109,19 @@ struct ConcatParam {
#undef R
#undef V
// lens: (R=4, N=2, V=z)
#define V VZ
#define R 4
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=2, N=6, V=y)
#define V VY
......@@ -138,6 +164,19 @@ struct ConcatParam {
#undef R
#undef V
// lens: (R=2, N=3, V=normal)
#define V VNORMAL
#define R 2
#define N 3
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VY
#define R 2
......@@ -165,7 +204,3 @@ struct ConcatParam {
#undef N
#undef R
#undef V
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct NearestInterpParam {
float scale;
};
kernel void nearest_interp(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant NearestInterpParam &param [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float scale = param.scale;
uint x = uint(round(float(gid.x) / scale));
uint y = uint(round(float(gid.y) / scale));
const float4 input = inTexture.read(uint2(x, y), gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void nearest_interp_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant NearestInterpParam &param [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float scale = param.scale;
uint x = uint(round(float(gid.x) / scale));
uint y = uint(round(float(gid.y) / scale));
const half4 input = inTexture.read(uint2(x, y), gid.z);
outTexture.write(input, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct MetalSliceParam {
short start0;
short start1;
short start2;
short start3;
short end0;
short end1;
short end2;
short end3;
};
kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalSliceParam &param [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output;
for (int i = 0; i < 4; i++) {
int input_c = gid.z * 4 + i + param.start1;
int input_z = input_c / 4;
const float4 input = inTexture.read(gid.xy, input_z);
output[i] = input[input_c % 4];
}
outTexture.write(output, gid.xy, gid.z);
}
kernel void slice_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalSliceParam &param [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output;
for (int i = 0; i < 4; i++) {
int input_c = gid.z * 4 + i + param.start1;
int input_z = input_c / 4;
const float4 input = float4(inTexture.read(gid.xy, input_z));
output[i] = input[input_c % 4];
}
outTexture.write(half4(output), gid.xy, gid.z);
}
......@@ -7,8 +7,12 @@
objects = {
/* Begin PBXBuildFile section */
16324D842292ABDB0047277D /* NearestInterpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16324D832292ABDB0047277D /* NearestInterpKernel.swift */; };
165F38D32276CDEA0088E29F /* ConvAddReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */; };
165F38D52276CE7D0088E29F /* ConvAddReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */; };
16D3F3B522929C390067C45D /* SliceOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3B422929C390067C45D /* SliceOp.swift */; };
16D3F3B722929C660067C45D /* NearestInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3B622929C660067C45D /* NearestInterpOp.swift */; };
16D3F3B922929D070067C45D /* SliceKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3B822929D070067C45D /* SliceKernel.swift */; };
16FBFB36229259C60025B406 /* ExpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB35229259C60025B406 /* ExpOp.swift */; };
16FBFB3822925B030025B406 /* ExpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3722925B030025B406 /* ExpKernel.swift */; };
16FBFB3A22925C3E0025B406 /* SigmoidKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */; };
......@@ -113,8 +117,12 @@
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
16324D832292ABDB0047277D /* NearestInterpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NearestInterpKernel.swift; sourceTree = "<group>"; };
165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluOp.swift; sourceTree = "<group>"; };
165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluKernel.swift; sourceTree = "<group>"; };
16D3F3B422929C390067C45D /* SliceOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SliceOp.swift; sourceTree = "<group>"; };
16D3F3B622929C660067C45D /* NearestInterpOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NearestInterpOp.swift; sourceTree = "<group>"; };
16D3F3B822929D070067C45D /* SliceKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SliceKernel.swift; sourceTree = "<group>"; };
16FBFB35229259C60025B406 /* ExpOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ExpOp.swift; sourceTree = "<group>"; };
16FBFB3722925B030025B406 /* ExpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ExpKernel.swift; sourceTree = "<group>"; };
16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SigmoidKernel.swift; sourceTree = "<group>"; };
......@@ -353,6 +361,8 @@
16FBFB35229259C60025B406 /* ExpOp.swift */,
16FBFB3B22925C800025B406 /* SigmoidOp.swift */,
16FBFB3F229266FE0025B406 /* LeakyReluOp.swift */,
16D3F3B422929C390067C45D /* SliceOp.swift */,
16D3F3B622929C660067C45D /* NearestInterpOp.swift */,
);
path = Operators;
sourceTree = "<group>";
......@@ -413,6 +423,8 @@
16FBFB3722925B030025B406 /* ExpKernel.swift */,
16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */,
16FBFB412292684E0025B406 /* LeakyReluKernel.swift */,
16D3F3B822929D070067C45D /* SliceKernel.swift */,
16324D832292ABDB0047277D /* NearestInterpKernel.swift */,
);
path = Kernels;
sourceTree = "<group>";
......@@ -567,12 +579,14 @@
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
16D3F3B522929C390067C45D /* SliceOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */,
FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* PMProgramDesc.swift in Sources */,
FCE3A1AB2153DE8C00C37CDE /* ConvAddAddPreluKernel.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
16324D842292ABDB0047277D /* NearestInterpKernel.swift in Sources */,
FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */,
FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */,
FC9797CB21D6102D00F2FD90 /* ResizeBilinearKernel.swift in Sources */,
......@@ -605,6 +619,7 @@
4AA1EA8E2146647F00D0F791 /* SplitKernel.swift in Sources */,
FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */,
FC1CF3F721D4B4C400F7392E /* Runner.swift in Sources */,
16D3F3B722929C660067C45D /* NearestInterpOp.swift in Sources */,
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FCD04E6620F314C50007374F /* PoolOp.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
......@@ -625,6 +640,7 @@
FC803BC1214CB77A0094B8E5 /* ConvAddPreluKernel.swift in Sources */,
FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */,
FCE3A1AF2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift in Sources */,
16D3F3B922929D070067C45D /* SliceKernel.swift in Sources */,
FCE9D7B7214F869000B520C3 /* Net.swift in Sources */,
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
......
......@@ -77,6 +77,9 @@ class OpCreator<P: PrecisionProtocol> {
gExpType : ExpOp<P>.creat,
gSigmoidType : SigmoidOp<P>.creat,
gLeakyReluType : LeakyReluOp<P>.creat,
gFlatten2Type : Flatten2Op<P>.creat,
gSliceType : SliceOp<P>.creat,
gNearestInterpType : NearestInterpOp<P>.creat,
]
......
......@@ -188,6 +188,9 @@ let gRelu6Type = "relu6"
let gExpType = "exp"
let gSigmoidType = "sigmoid"
let gLeakyReluType = "leaky_relu"
let gFlatten2Type = "flatten2"
let gSliceType = "slice"
let gNearestInterpType = "nearest_interp"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
......@@ -226,4 +229,7 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out
gExpType : (inputs: ["X"], outputs: ["Out"]),
gSigmoidType : (inputs: ["X"], outputs: ["Out"]),
gLeakyReluType : (inputs: ["X"], outputs: ["Out"]),
gFlatten2Type : (inputs: ["X"], outputs: ["Out"]),
gSliceType : (inputs: ["Input"], outputs: ["Out"]),
gNearestInterpType : (inputs: ["X"], outputs: ["Out"]),
]
......@@ -20,14 +20,14 @@ class FlattenParam<P: PrecisionProtocol>: OpParam {
do {
input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs)
// axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture
var input: Texture
var output: Texture
let axis: Int
var axis: Int = 0
}
......@@ -56,8 +56,44 @@ class FlattenOp<P: PrecisionProtocol>: Operator<FlattenKernel<P>, FlattenParam<P
}
class Flatten2Param<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try Flatten2Param.inputX(inputs: opDesc.inputs, from: inScope)
output = try Flatten2Param.outputOut(outputs: opDesc.outputs, from: inScope)
let inDims = input.dim
guard inDims.cout() == 4 else {
fatalError("flatten2 can't handle dims not equal to 4")
}
let outDim = [inDims[0] * inDims[1], inDims[2] * inDims[3]]
output.dim = Dim.init(inDim: outDim)
} catch let error {
throw error
}
}
var input: Texture
var output: Texture
}
class Flatten2Op<P: PrecisionProtocol>: Operator<Flatten2Kernel<P>, Flatten2Param<P>>, Runable, Creator, InferShaperable {
typealias OpType = Flatten2Op<P>
func inferShape() {
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
let device = para.output.metalTexture!.device
let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
print(outputArray.strideArray())
}
}
......@@ -22,7 +22,7 @@ struct FlattenMetalParam {
}
class FlattenKernel<P: PrecisionProtocol>: Kernel, Computable{
class FlattenKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: FlattenMetalParam
......@@ -75,3 +75,57 @@ class FlattenKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding()
}
}
class Flatten2Kernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: FlattenMetalParam
required init(device: MTLDevice, param: Flatten2Param<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
}
let it: [Int32] = param.input.transpose.map { Int32($0) }
var od: [Int32] = [1, 1, 1, 1]
for i in 0..<param.output.tensorDim.cout() {
od[4-param.output.tensorDim.cout()+i] = Int32(param.output.tensorDim[i])
}
let ot: [Int32] = param.output.transpose.map { Int32($0) }
metalParam = FlattenMetalParam.init(
idim: (id[0], id[1], id[2], id[3]),
itrans: (it[0], it[1], it[2], it[3]),
odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3])
)
let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout()
assert(orank == 2)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_float", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_half", initContext: initContext)
} else {
fatalError()
}
}
func compute(commandBuffer: MTLCommandBuffer, param: Flatten2Param<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<ReshapeMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct NearestInterpMetalParam {
let scale: Float32
}
class NearestInterpKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: NearestInterpMetalParam
func compute(commandBuffer: MTLCommandBuffer, param: NearestInterpParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<NearestInterpMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: NearestInterpParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = NearestInterpMetalParam(scale: param.scale)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "nearest_interp", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "nearest_interp_half", initContext: initContext)
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
public struct SliceMetalParam {
let start0: Int16
let start1: Int16
let start2: Int16
let start3: Int16
let end0: Int16
let end1: Int16
let end2: Int16
let end3: Int16
}
class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: SliceMetalParam
func compute(commandBuffer: MTLCommandBuffer, param: SliceParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<SliceMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: SliceParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var ranges = [[Int16]]()
for i in 0..<4 {
if let range = param.ranges[i] {
ranges.append(range)
} else {
ranges.append([0, Int16(param.input.tensorDim[i])])
}
}
let start0 = ranges[0][0]
let start1 = ranges[1][0]
let start2 = ranges[2][0]
let start3 = ranges[3][0]
let end0 = ranges[0][1]
let end1 = ranges[1][1]
let end2 = ranges[2][1]
let end3 = ranges[3][1]
metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "slice", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "slice_half", initContext: initContext)
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class NearestInterpParam<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try NearestInterpParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try NearestInterpParam.outputOut(outputs: opDesc.outputs, from: inScope)
let inputDim = input.tensorDim
let outputDim = output.tensorDim
guard inputDim.cout() == 4 && outputDim.cout() == 4 && inputDim[0] == outputDim[0] && inputDim[1] == outputDim[1] else {
fatalError("nearest interp only support scale along width and height")
}
let scaleX = Float32(outputDim[2]) / Float32(inputDim[2])
let scaleY = Float32(outputDim[3]) / Float32(inputDim[3])
guard abs(scaleX - scaleY) <= 0.00001 else {
fatalError("nearest interp only support same scale factor")
}
scale = scaleX
print("ok")
} catch let error {
throw error
}
}
var input: Texture
var output: Texture
var scale: Float32
}
class NearestInterpOp<P: PrecisionProtocol>: Operator<NearestInterpKernel<P>, NearestInterpParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = NearestInterpOp<P>
func inferShape() {
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
let device = para.output.metalTexture!.device
let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
print(outputArray.strideArray())
}
}
......@@ -23,6 +23,12 @@ class ReshapeParam<P: PrecisionProtocol>: OpParam {
output = try ReshapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
shape = try ReshapeParam.getAttr(key: "shape", attrs: opDesc.attrs)
if shape.count > 4 {
if shape[0] == -1 {
shape.removeFirst()
}
}
var s: [Int] = shape.map { Int($0) }
var di = -1
......@@ -49,7 +55,7 @@ class ReshapeParam<P: PrecisionProtocol>: OpParam {
}
}
let input: Texture
let shape: [Int32]
var shape: [Int32]
var output: Texture
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class SliceParam<P: PrecisionProtocol>: OpParam {
//typealias ParamPrecisionType = P
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try SliceParam.input(inputs: opDesc.inputs, from: inScope)
output = try SliceParam.outputOut(outputs: opDesc.outputs, from: inScope)
starts = try SliceParam.getAttr(key: "starts", attrs: opDesc.attrs)
ends = try SliceParam.getAttr(key: "ends", attrs: opDesc.attrs)
for i in 0..<input.tensorDim.cout() {
if input.tensorDim[i] != output.tensorDim[i] {
axes.append(Int32(i))
}
}
guard axes.count == 1 && axes[0] == 1 else {
fatalError("slice only support channel axe")
}
for i in 0..<axes.count {
ranges[Int(axes[i])] = [Int16(starts[i]), Int16(ends[i])]
}
} catch let error {
throw error
}
}
let input: Texture
var output: Texture
let starts: [Int32]
let ends: [Int32]
var axes = [Int32]()
var ranges = [Int: [Int16]]()
}
class SliceOp<P: PrecisionProtocol>: Operator<SliceKernel<P>, SliceParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = SliceOp<P>
func inferShape() {
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print("\(type) output : ")
print(para.output.toTensor().strideArray())
}
}
......@@ -62,6 +62,19 @@ class TensorDesc {
let dim = Int(protoTensorDesc.dimsArray.value(at: i)) > 0 ?Int(protoTensorDesc.dimsArray.value(at: i)) :abs(Int(protoTensorDesc.dimsArray.value(at: i)))
dimsArray.append(dim)
}
if dimsCount > 4 {
let headDims = Int(dimsCount - 4)
for i in 0..<headDims {
guard dimsArray[i] <= 1 else {
fatalError("dims count is larger than 4 and can't be truncated to 4")
}
}
for _ in 0..<headDims {
dimsArray.removeFirst()
}
}
dims = dimsArray
dataType = VarTypeType.init(rawValue: Int(protoTensorDesc.dataType.rawValue)) ?? .ErrorType
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册