提交 11d8528b 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #856 from dolphin8/metal

Metal
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; };
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; };
D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; }; D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; };
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */ = {isa = PBXBuildFile; fileRef = FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */; settings = {ATTRIBUTES = (Public, ); }; }; FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */ = {isa = PBXBuildFile; fileRef = FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */; settings = {ATTRIBUTES = (Public, ); }; };
FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B9420E11C9A0081E9F8 /* Extensions.swift */; }; FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B9420E11C9A0081E9F8 /* Extensions.swift */; };
...@@ -90,6 +92,8 @@ ...@@ -90,6 +92,8 @@
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = "<group>"; };
4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = "<group>"; };
CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = "<group>"; }; CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = "<group>"; };
DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
E2A7957C92EDA5C3BEC0FFC2 /* Pods-paddle-mobile.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.release.xcconfig"; sourceTree = "<group>"; }; E2A7957C92EDA5C3BEC0FFC2 /* Pods-paddle-mobile.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.release.xcconfig"; sourceTree = "<group>"; };
...@@ -355,6 +359,8 @@ ...@@ -355,6 +359,8 @@
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC27990D21341016000B6BAD /* BoxCoder.metal */, FC27990D21341016000B6BAD /* BoxCoder.metal */,
4AF928812135673D005B6C3A /* Concat.metal */,
4AF9288321357BE3005B6C3A /* Elementwise.metal */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
4AF928762133F1DB005B6C3A /* BoxCoder.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */,
...@@ -478,6 +484,7 @@ ...@@ -478,6 +484,7 @@
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */, FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */,
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */, FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */,
FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */, FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
...@@ -515,6 +522,7 @@ ...@@ -515,6 +522,7 @@
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
4AF928822135673D005B6C3A /* Concat.metal in Sources */,
FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */, FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */,
FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */, FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */,
FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */,
......
...@@ -25,6 +25,13 @@ class ConcatParam<P: PrecisionType>: OpParam { ...@@ -25,6 +25,13 @@ class ConcatParam<P: PrecisionType>: OpParam {
guard let variant = inScope[x], let v = variant as? Texture<P> else { guard let variant = inScope[x], let v = variant as? Texture<P> else {
fatalError() fatalError()
} }
if transpose.count == 0 {
transpose = v.transpose
}
if v.transpose != transpose {
fatalError()
}
input.append(v) input.append(v)
} }
axis = try ConcatParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ConcatParam.getAttr(key: "axis", attrs: opDesc.attrs)
...@@ -35,6 +42,7 @@ class ConcatParam<P: PrecisionType>: OpParam { ...@@ -35,6 +42,7 @@ class ConcatParam<P: PrecisionType>: OpParam {
} }
var input: [Texture<P>] = [] var input: [Texture<P>] = []
var output: Texture<P> var output: Texture<P>
var transpose: [Int] = []
let axis: Int let axis: Int
} }
......
...@@ -18,36 +18,42 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -18,36 +18,42 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
} catch _ {
do {
inputYTexture = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
} catch let error {
throw error
}
}
do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
do {
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
} catch _ {
let tensorY: Tensor<P> = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
let device = inputX.metalTexture!.device
inputY = Texture.init(device: device, inDim: tensorY.dim)
let value: [P] = Array(UnsafeBufferPointer(start: tensorY.data.pointer, count: tensorY.dim.numel()))
inputY.metalTexture = device.tensor2texture(value: value, dim: tensorY.dim.dims)
}
var offset = axis
if axis == -1 {
offset = inputX.tensorDim.cout() - inputY.tensorDim.cout()
}
for i in 0..<(inputY.tensorDim.cout()) {
assert(inputX.tensorDim[offset + i] == inputY.tensorDim[i])
}
} }
var inputYTexture: Texture<P>? var inputX: Texture<P>
var inputY: Tensor<P>? var inputY: Texture<P>
var input: Texture<P>
var output: Texture<P> var output: Texture<P>
let axis: Int var axis: Int
} }
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = ElementwiseAddOp<P> typealias OpType = ElementwiseAddOp<P>
func inferShape() { func inferShape() {
para.output.dim = para.input.dim // para.output.dim = para.input.dim
} }
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
......
...@@ -15,113 +15,117 @@ ...@@ -15,113 +15,117 @@
import Foundation import Foundation
struct ConcatTestParam: TestParam { struct ConcatTestParam: TestParam {
var input: [MTLTexture] var input: [MTLTexture]
var output: MTLTexture var output: MTLTexture
var dims: [[Int]] var dims: [[Int]]
var axis: Int var axis: Int
var odim: [Int] var odim: [Int]
} }
struct ConcatMetalParam { struct ConcatMetalParam {
var odim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1) var odim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1)
var axis: Int32 = 0 var axis: Int32 = 0
var offset: Int32 = 0 var offset: Int32 = 0
var vdim: (Int32, Int32, Int32, Int32, Int32, Int32) = (0, 0, 0, 0, 0, 0) var trans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
var vdim: (Int32, Int32, Int32, Int32, Int32, Int32) = (0, 0, 0, 0, 0, 0)
} }
class ConcatKernel<P: PrecisionType>: Kernel, Computable{ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
func encodeTest(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatTestParam, _ istart: Int, _ iend: Int) { func encodeTest(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatTestParam, _ istart: Int, _ iend: Int) {
let encoder = cmdBuffer.makeComputeCommandEncoder()! let encoder = cmdBuffer.makeComputeCommandEncoder()!
var p = ConcatMetalParam.init() var p = ConcatMetalParam.init()
var odim: [Int32] = [1, 1, 1, 1] var odim: [Int32] = [1, 1, 1, 1]
for i in 0..<param.odim.count { for i in 0..<param.odim.count {
odim[4-param.odim.count+i] = Int32(param.odim[i]) odim[4-param.odim.count+i] = Int32(param.odim[i])
}
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.odim.count + param.axis)
for i in 0..<istart {
p.offset += Int32(param.dims[i][param.axis])
}
var vdim: [Int32] = []
for i in 0..<(iend - istart) {
encoder.setTexture(param.input[i+istart], index: i)
vdim.append(Int32(param.dims[i+istart][Int(param.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0], index: i)
vdim.append(0)
}
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output, index: 6)
encoder.setTexture(param.output, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding()
} }
p.odim = (odim[0], odim[1], odim[2], odim[3])
func encode(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatParam<P>, _ istart: Int, _ iend: Int) throws { p.axis = Int32(4 - param.odim.count + param.axis)
guard let encoder = cmdBuffer.makeComputeCommandEncoder() else { for i in 0..<istart {
throw PaddleMobileError.predictError(message: " encode is nil") p.offset += Int32(param.dims[i][param.axis])
}
var p = ConcatMetalParam.init()
let odim = (0..<4).map { Int32(param.output.dim[$0]) }
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis)
for i in 0..<istart {
p.offset += Int32(param.input[i+istart].dim[Int(p.axis)])
}
var vdim: [Int32] = []
for i in 0..<(iend - istart) {
encoder.setTexture(param.input[i+istart].metalTexture, index: i)
vdim.append(Int32(param.input[i+istart].dim[Int(p.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0].metalTexture, index: i)
vdim.append(0)
}
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output.metalTexture, index: 6)
encoder.setTexture(param.output.metalTexture, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
} }
var vdim: [Int32] = []
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws { for i in 0..<(iend - istart) {
for i in 0..<param.input.count { encoder.setTexture(param.input[i+istart], index: i)
for j in 0..<4 { vdim.append(Int32(param.dims[i+istart][Int(param.axis)]))
assert(param.input[i].transpose[j] == j)
}
}
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
try self.encode(commandBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encode(commandBuffer, param, 6 * group, param.input.count)
}
} }
for i in (iend-istart)..<6 {
func test(cmdBuffer: MTLCommandBuffer, param: ConcatTestParam) { encoder.setTexture(param.input[0], index: i)
let group = param.input.count / 6 vdim.append(0)
let remain = param.input.count % 6
for i in 0..<group {
try self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encodeTest(cmdBuffer, param, 6 * group, param.input.count)
}
} }
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
required init(device: MTLDevice, param: ConcatParam<P>) { encoder.setTexture(param.output, index: 6)
param.output.initTexture(device: device) encoder.setTexture(param.output, index: 7)
super.init(device: device, inFunctionName: "concat") encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding()
}
func encode(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatParam<P>, _ istart: Int, _ iend: Int) throws {
guard let encoder = cmdBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
var p = ConcatMetalParam.init()
let odim = (0..<4).map { Int32(param.output.dim[$0]) }
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis)
for i in 0..<4 {
if Int32(param.transpose[i]) == p.axis {
p.axis = Int32(i)
break
}
}
for i in 0..<istart {
p.offset += Int32(param.input[i+istart].dim[Int(p.axis)])
} }
var vdim: [Int32] = []
for i in 0..<(iend - istart) {
encoder.setTexture(param.input[i+istart].metalTexture, index: i)
vdim.append(Int32(param.input[i+istart].dim[Int(p.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0].metalTexture, index: i)
vdim.append(0)
}
p.trans = (Int32(param.transpose[0]), Int32(param.transpose[1]), Int32(param.transpose[2]), Int32(param.transpose[3]))
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output.metalTexture, index: 6)
encoder.setTexture(param.output.metalTexture, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
required init(device: MTLDevice, testParam: ConcatTestParam) { let group = param.input.count / 6
super.init(device: device, inFunctionName: "concat") let remain = param.input.count % 6
for i in 0..<group {
try self.encode(commandBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encode(commandBuffer, param, 6 * group, param.input.count)
}
}
func test(cmdBuffer: MTLCommandBuffer, param: ConcatTestParam) {
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
self.encodeTest(cmdBuffer, param, 6 * group, param.input.count)
} }
}
required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device, inTranspose: param.transpose)
super.init(device: device, inFunctionName: "concat")
}
required init(device: MTLDevice, testParam: ConcatTestParam) {
super.init(device: device, inFunctionName: "concat")
}
} }
...@@ -43,6 +43,8 @@ class ConvTransposeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -43,6 +43,8 @@ class ConvTransposeKernel<P: PrecisionType>: Kernel, Computable{
let dilationY = UInt16(param.dilations[1]) let dilationY = UInt16(param.dilations[1])
metalParam = MetalConvTransposeParam.init(kernelW: kernelWidth, kernelH: kernelHeight, strideX: strideX, strideY: strideY, paddingX: paddingX, paddingY: paddingY, dilationX: dilationX, dilationY: dilationY) metalParam = MetalConvTransposeParam.init(kernelW: kernelWidth, kernelH: kernelHeight, strideX: strideX, strideY: strideY, paddingX: paddingX, paddingY: paddingY, dilationX: dilationX, dilationY: dilationY)
param.output.initTexture(device: device, inTranspose: param.input.transpose)
} }
func compute(commandBuffer: MTLCommandBuffer, param: ConvTransposeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ConvTransposeParam<P>) throws {
......
...@@ -14,14 +14,52 @@ ...@@ -14,14 +14,52 @@
import Foundation import Foundation
struct ElementwiseAddMetalParam {
var fast: Int32 = 0
var axis: Int32 = 0
var yoff: Int32 = 0
var xdim: (Int32, Int32, Int32, Int32) = (0, 0, 0, 0)
var xtrans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
var ydim: (Int32, Int32, Int32, Int32) = (0, 0, 0, 0)
var ytrans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
}
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable { class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: ElementwiseAddParam<P>) { required init(device: MTLDevice, param: ElementwiseAddParam<P>) {
super.init(device: device, inFunctionName: "elementwise_add") super.init(device: device, inFunctionName: "elementwise_add")
param.output.initTexture(device: device, inTranspose: param.input.transpose) param.output.initTexture(device: device, inTranspose: param.inputX.transpose)
} }
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
var emp = ElementwiseAddMetalParam.init()
encoder.setTexture(param.inputX.metalTexture, index: 0)
encoder.setTexture(param.inputY.metalTexture, index: 1)
encoder.setTexture(param.output.metalTexture, index: 2)
let xdim: [Int32] = (0..<4).map { Int32(param.inputX.dim[$0]) }
let ydim: [Int32] = (0..<4).map { Int32(param.inputY.dim[$0]) }
let xtrans: [Int32] = (0..<4).map { Int32(param.inputX.transpose[$0]) }
let ytrans: [Int32] = (0..<4).map { Int32(param.inputY.transpose[$0]) }
emp.xdim = (xdim[0], xdim[1], xdim[2], xdim[3])
emp.ydim = (ydim[0], ydim[1], ydim[2], ydim[3])
emp.xtrans = (xtrans[0], xtrans[1], xtrans[2], xtrans[3])
emp.ytrans = (ytrans[0], ytrans[1], ytrans[2], ytrans[3])
if param.axis == -1 {
emp.axis = 4 - Int32(param.inputY.tensorDim.cout())
} else {
emp.axis = 4 - Int32(param.inputX.tensorDim.cout()) + Int32(param.axis)
}
emp.yoff = 4 - Int32(param.inputY.tensorDim.cout())
if (param.inputX.dim == param.inputY.dim) && (param.inputX.transpose == param.inputY.transpose) {
emp.fast = 1
}
encoder.setBytes(&emp, length: MemoryLayout<ElementwiseAddMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[6];
};
kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
texture2d_array<float, access::read> in1 [[texture(1)]],
texture2d_array<float, access::read> in2 [[texture(2)]],
texture2d_array<float, access::read> in3 [[texture(3)]],
texture2d_array<float, access::read> in4 [[texture(4)]],
texture2d_array<float, access::read> in5 [[texture(5)]],
texture2d_array<float, access::read> inx [[texture(6)]],
texture2d_array<float, access::write> out [[texture(7)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
float4 r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn2abcd(cp.odim[3], xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
if (k < 0) {
r[i] = inx.read(gid.xy, gid.z)[i];
} else {
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
abcd2xyzn(cp.odim[3], abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
}
}
}
out.write(r, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ElementwiseAddParam {
int32_t fast;
int32_t axis;
int32_t yoff;
int32_t xdim[4];
int32_t xtrans[4];
int32_t ydim[4];
int32_t ytrans[4];
};
kernel void elementwise_add(texture2d_array<float, access::read> inputX [[texture(0)]],
texture2d_array<float, access::read> inputY [[texture(1)]],
texture2d_array<float, access::write> outTexture [[texture(2)]],
constant ElementwiseAddParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
float4 rx, ry;
if (pm.fast == 1) {
rx = inputX.read(gid.xy, gid.z);
ry = inputY.read(gid.xy, gid.z);
} else {
rx = inputX.read(gid.xy, gid.z);
int32_t x_xyzn[4] = {int32_t(gid.x), int32_t(gid.y), int32_t(gid.z), 0}, x_abcd[4], t_abcd[4];
int32_t y_abcd[4] = {1, 1, 1, 1}, y_xyzn[4];
int32_t xtrans[4] = {pm.xtrans[0], pm.xtrans[1], pm.xtrans[2], pm.xtrans[3]};
int32_t ytrans[4] = {pm.ytrans[0], pm.ytrans[1], pm.ytrans[2], pm.ytrans[3]};
for (int n = 0; n < 4; n++) {
xyzn2abcd(pm.xdim[3], x_xyzn, x_abcd);
invtrans(xtrans, x_abcd, t_abcd);
for (int k = pm.axis; k < (4 - pm.yoff); k++) {
y_abcd[k+pm.yoff] = t_abcd[k];
}
trans(ytrans, y_abcd, t_abcd);
abcd2xyzn(pm.ydim[3], t_abcd, y_xyzn);
ry[n] = inputY.read(uint2(y_xyzn[0], y_xyzn[1]), y_xyzn[2])[y_xyzn[3]];
}
}
float4 r = rx + ry;
outTexture.write(r, gid.xy, gid.z);
}
...@@ -43,17 +43,6 @@ kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]], ...@@ -43,17 +43,6 @@ kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z); outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z);
} }
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
//kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]], //kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]],
...@@ -200,55 +189,3 @@ kernel void transpose(texture2d_array<float, access::read> inTexture [[texture(0 ...@@ -200,55 +189,3 @@ kernel void transpose(texture2d_array<float, access::read> inTexture [[texture(0
outTexture.write(r, gid.xy, gid.z); outTexture.write(r, gid.xy, gid.z);
} }
} }
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t vdim[6];
};
kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
texture2d_array<float, access::read> in1 [[texture(1)]],
texture2d_array<float, access::read> in2 [[texture(2)]],
texture2d_array<float, access::read> in3 [[texture(3)]],
texture2d_array<float, access::read> in4 [[texture(4)]],
texture2d_array<float, access::read> in5 [[texture(5)]],
texture2d_array<float, access::read> inx [[texture(6)]],
texture2d_array<float, access::write> out [[texture(7)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
float4 r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn2abcd(cp.odim[3], xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
if (k < 0) {
r[i] = inx.read(gid.xy, gid.z)[i];
} else {
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
abcd2xyzn(cp.odim[3], abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
}
}
}
out.write(r, gid.xy, gid.z);
}
...@@ -26,6 +26,7 @@ class PoolParam<P: PrecisionType>: OpParam { ...@@ -26,6 +26,7 @@ class PoolParam<P: PrecisionType>: OpParam {
padding = try PoolParam.getAttr(key: "paddings", attrs: opDesc.attrs) padding = try PoolParam.getAttr(key: "paddings", attrs: opDesc.attrs)
ceilMode = try PoolParam.getAttr(key: "ceil_mode", attrs: opDesc.attrs) ceilMode = try PoolParam.getAttr(key: "ceil_mode", attrs: opDesc.attrs)
globalPooling = try PoolParam.getAttr(key: "global_pooling", attrs: opDesc.attrs) globalPooling = try PoolParam.getAttr(key: "global_pooling", attrs: opDesc.attrs)
assert(input.transpose == [0, 2, 3, 1])
} catch let error { } catch let error {
throw error throw error
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册