未验证 提交 7c3a2051 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #856 from dolphin8/metal

Metal
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; };
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; };
D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; }; D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; };
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */ = {isa = PBXBuildFile; fileRef = FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */; settings = {ATTRIBUTES = (Public, ); }; }; FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */ = {isa = PBXBuildFile; fileRef = FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */; settings = {ATTRIBUTES = (Public, ); }; };
FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B9420E11C9A0081E9F8 /* Extensions.swift */; }; FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B9420E11C9A0081E9F8 /* Extensions.swift */; };
...@@ -90,6 +92,8 @@ ...@@ -90,6 +92,8 @@
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = "<group>"; };
4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = "<group>"; };
CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = "<group>"; }; CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = "<group>"; };
DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
E2A7957C92EDA5C3BEC0FFC2 /* Pods-paddle-mobile.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.release.xcconfig"; sourceTree = "<group>"; }; E2A7957C92EDA5C3BEC0FFC2 /* Pods-paddle-mobile.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.release.xcconfig"; sourceTree = "<group>"; };
...@@ -355,6 +359,8 @@ ...@@ -355,6 +359,8 @@
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC27990D21341016000B6BAD /* BoxCoder.metal */, FC27990D21341016000B6BAD /* BoxCoder.metal */,
4AF928812135673D005B6C3A /* Concat.metal */,
4AF9288321357BE3005B6C3A /* Elementwise.metal */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
4AF928762133F1DB005B6C3A /* BoxCoder.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */,
...@@ -478,6 +484,7 @@ ...@@ -478,6 +484,7 @@
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */, FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */,
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */, FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */,
FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */, FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
...@@ -515,6 +522,7 @@ ...@@ -515,6 +522,7 @@
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
4AF928822135673D005B6C3A /* Concat.metal in Sources */,
FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */, FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */,
FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */, FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */,
FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */,
......
...@@ -25,6 +25,13 @@ class ConcatParam<P: PrecisionType>: OpParam { ...@@ -25,6 +25,13 @@ class ConcatParam<P: PrecisionType>: OpParam {
guard let variant = inScope[x], let v = variant as? Texture<P> else { guard let variant = inScope[x], let v = variant as? Texture<P> else {
fatalError() fatalError()
} }
if transpose.count == 0 {
transpose = v.transpose
}
if v.transpose != transpose {
fatalError()
}
input.append(v) input.append(v)
} }
axis = try ConcatParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ConcatParam.getAttr(key: "axis", attrs: opDesc.attrs)
...@@ -35,6 +42,7 @@ class ConcatParam<P: PrecisionType>: OpParam { ...@@ -35,6 +42,7 @@ class ConcatParam<P: PrecisionType>: OpParam {
} }
var input: [Texture<P>] = [] var input: [Texture<P>] = []
var output: Texture<P> var output: Texture<P>
var transpose: [Int] = []
let axis: Int let axis: Int
} }
......
...@@ -18,36 +18,42 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -18,36 +18,42 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
} catch _ {
do {
inputYTexture = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
} catch let error {
throw error
}
}
do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
do {
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
} catch _ {
let tensorY: Tensor<P> = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
let device = inputX.metalTexture!.device
inputY = Texture.init(device: device, inDim: tensorY.dim)
let value: [P] = Array(UnsafeBufferPointer(start: tensorY.data.pointer, count: tensorY.dim.numel()))
inputY.metalTexture = device.tensor2texture(value: value, dim: tensorY.dim.dims)
} }
var inputYTexture: Texture<P>? var offset = axis
var inputY: Tensor<P>? if axis == -1 {
var input: Texture<P> offset = inputX.tensorDim.cout() - inputY.tensorDim.cout()
}
for i in 0..<(inputY.tensorDim.cout()) {
assert(inputX.tensorDim[offset + i] == inputY.tensorDim[i])
}
}
var inputX: Texture<P>
var inputY: Texture<P>
var output: Texture<P> var output: Texture<P>
let axis: Int var axis: Int
} }
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = ElementwiseAddOp<P> typealias OpType = ElementwiseAddOp<P>
func inferShape() { func inferShape() {
para.output.dim = para.input.dim // para.output.dim = para.input.dim
} }
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
......
...@@ -26,6 +26,7 @@ struct ConcatMetalParam { ...@@ -26,6 +26,7 @@ struct ConcatMetalParam {
var odim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1) var odim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1)
var axis: Int32 = 0 var axis: Int32 = 0
var offset: Int32 = 0 var offset: Int32 = 0
var trans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
var vdim: (Int32, Int32, Int32, Int32, Int32, Int32) = (0, 0, 0, 0, 0, 0) var vdim: (Int32, Int32, Int32, Int32, Int32, Int32) = (0, 0, 0, 0, 0, 0)
} }
...@@ -68,6 +69,12 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -68,6 +69,12 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
let odim = (0..<4).map { Int32(param.output.dim[$0]) } let odim = (0..<4).map { Int32(param.output.dim[$0]) }
p.odim = (odim[0], odim[1], odim[2], odim[3]) p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis) p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis)
for i in 0..<4 {
if Int32(param.transpose[i]) == p.axis {
p.axis = Int32(i)
break
}
}
for i in 0..<istart { for i in 0..<istart {
p.offset += Int32(param.input[i+istart].dim[Int(p.axis)]) p.offset += Int32(param.input[i+istart].dim[Int(p.axis)])
} }
...@@ -80,6 +87,8 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -80,6 +87,8 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
encoder.setTexture(param.input[0].metalTexture, index: i) encoder.setTexture(param.input[0].metalTexture, index: i)
vdim.append(0) vdim.append(0)
} }
p.trans = (Int32(param.transpose[0]), Int32(param.transpose[1]), Int32(param.transpose[2]), Int32(param.transpose[3]))
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5]) p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output.metalTexture, index: 6) encoder.setTexture(param.output.metalTexture, index: 6)
encoder.setTexture(param.output.metalTexture, index: 7) encoder.setTexture(param.output.metalTexture, index: 7)
...@@ -89,11 +98,6 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -89,11 +98,6 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
} }
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
for i in 0..<param.input.count {
for j in 0..<4 {
assert(param.input[i].transpose[j] == j)
}
}
let group = param.input.count / 6 let group = param.input.count / 6
let remain = param.input.count % 6 let remain = param.input.count % 6
...@@ -109,15 +113,15 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -109,15 +113,15 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
let group = param.input.count / 6 let group = param.input.count / 6
let remain = param.input.count % 6 let remain = param.input.count % 6
for i in 0..<group { for i in 0..<group {
try self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1)) self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1))
} }
if remain > 0 { if remain > 0 {
try self.encodeTest(cmdBuffer, param, 6 * group, param.input.count) self.encodeTest(cmdBuffer, param, 6 * group, param.input.count)
} }
} }
required init(device: MTLDevice, param: ConcatParam<P>) { required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device) param.output.initTexture(device: device, inTranspose: param.transpose)
super.init(device: device, inFunctionName: "concat") super.init(device: device, inFunctionName: "concat")
} }
......
...@@ -43,6 +43,8 @@ class ConvTransposeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -43,6 +43,8 @@ class ConvTransposeKernel<P: PrecisionType>: Kernel, Computable{
let dilationY = UInt16(param.dilations[1]) let dilationY = UInt16(param.dilations[1])
metalParam = MetalConvTransposeParam.init(kernelW: kernelWidth, kernelH: kernelHeight, strideX: strideX, strideY: strideY, paddingX: paddingX, paddingY: paddingY, dilationX: dilationX, dilationY: dilationY) metalParam = MetalConvTransposeParam.init(kernelW: kernelWidth, kernelH: kernelHeight, strideX: strideX, strideY: strideY, paddingX: paddingX, paddingY: paddingY, dilationX: dilationX, dilationY: dilationY)
param.output.initTexture(device: device, inTranspose: param.input.transpose)
} }
func compute(commandBuffer: MTLCommandBuffer, param: ConvTransposeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ConvTransposeParam<P>) throws {
......
...@@ -14,14 +14,52 @@ ...@@ -14,14 +14,52 @@
import Foundation import Foundation
// Host-side parameter block for the `elementwise_add` Metal kernel.
// ElementwiseAddKernel.compute fills one instance and uploads it with
// `encoder.setBytes(..., index: 0)`. The field order and Int32 widths mirror
// `struct ElementwiseAddParam` in the Metal source, so keep both in sync.
struct ElementwiseAddMetalParam {
// Set to 1 by compute() when X and Y have equal `dim` and equal `transpose`;
// the kernel then reads both textures at the same coordinate.
var fast: Int32 = 0
// Axis where Y starts inside X, expressed in the 4-slot index space:
// compute() stores `4 - inputY.tensorDim.cout()` when the op axis is -1,
// otherwise `4 - inputX.tensorDim.cout() + axis`.
var axis: Int32 = 0
// compute() stores `4 - inputY.tensorDim.cout()` here
// (presumably the number of leading padded slots of Y — confirm against Dim).
var yoff: Int32 = 0
// The four entries of `inputX.dim`.
var xdim: (Int32, Int32, Int32, Int32) = (0, 0, 0, 0)
// The four entries of `inputX.transpose`; defaults to the identity order.
var xtrans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
// The four entries of `inputY.dim`.
var ydim: (Int32, Int32, Int32, Int32) = (0, 0, 0, 0)
// The four entries of `inputY.transpose`; defaults to the identity order.
var ytrans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
}
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable { class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: ElementwiseAddParam<P>) { required init(device: MTLDevice, param: ElementwiseAddParam<P>) {
super.init(device: device, inFunctionName: "elementwise_add") super.init(device: device, inFunctionName: "elementwise_add")
param.output.initTexture(device: device, inTranspose: param.input.transpose) param.output.initTexture(device: device, inTranspose: param.inputX.transpose)
} }
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
var emp = ElementwiseAddMetalParam.init()
encoder.setTexture(param.inputX.metalTexture, index: 0)
encoder.setTexture(param.inputY.metalTexture, index: 1)
encoder.setTexture(param.output.metalTexture, index: 2)
let xdim: [Int32] = (0..<4).map { Int32(param.inputX.dim[$0]) }
let ydim: [Int32] = (0..<4).map { Int32(param.inputY.dim[$0]) }
let xtrans: [Int32] = (0..<4).map { Int32(param.inputX.transpose[$0]) }
let ytrans: [Int32] = (0..<4).map { Int32(param.inputY.transpose[$0]) }
emp.xdim = (xdim[0], xdim[1], xdim[2], xdim[3])
emp.ydim = (ydim[0], ydim[1], ydim[2], ydim[3])
emp.xtrans = (xtrans[0], xtrans[1], xtrans[2], xtrans[3])
emp.ytrans = (ytrans[0], ytrans[1], ytrans[2], ytrans[3])
if param.axis == -1 {
emp.axis = 4 - Int32(param.inputY.tensorDim.cout())
} else {
emp.axis = 4 - Int32(param.inputX.tensorDim.cout()) + Int32(param.axis)
}
emp.yoff = 4 - Int32(param.inputY.tensorDim.cout())
if (param.inputX.dim == param.inputY.dim) && (param.inputX.transpose == param.inputY.transpose) {
emp.fast = 1
}
encoder.setBytes(&emp, length: MemoryLayout<ElementwiseAddMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
// Parameter block for the `concat` kernel, bound at buffer(0).
// The host-side ConcatMetalParam struct declares the same fields in the same
// order; keep both definitions in sync.
struct ConcatParam {
// Output dims; the kernel passes odim[3] to xyzn2abcd/abcd2xyzn and
// temporarily overwrites odim[axis] while addressing an input texture.
int32_t odim[4];
// Index (0..3) of the dimension being concatenated along.
int32_t axis;
// Position along `axis` where the first input of this pass starts;
// elements before it are copied from the existing output (`inx`).
int32_t offset;
// Transpose order supplied by the host; not read by the kernel in this file.
int32_t trans[4];
// Extent along `axis` of each of the six input slots of this pass.
int32_t vdim[6];
};
// Concatenates up to six input textures along `pm.axis` into `out`.
//
// The host runs this kernel once per batch of six inputs and binds the output
// texture to both `inx` (read) and `out` (write). For every component of the
// output texel handled by this thread:
//   * position before `pm.offset` along the axis -> keep the value already
//     stored in the output (written by an earlier batch);
//   * position inside one of the six `vdim` buckets -> fetch it from the
//     matching input texture;
//   * position past all six buckets -> it belongs to a later batch, so the
//     existing value is kept as well.
//
// xyzn2abcd / abcd2xyzn are not defined in this file (presumably they come
// from Common.metal) and convert between texture coordinates and 4-D indices.
// `pm.trans` is not consulted here.
kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
                   texture2d_array<float, access::read> in1 [[texture(1)]],
                   texture2d_array<float, access::read> in2 [[texture(2)]],
                   texture2d_array<float, access::read> in3 [[texture(3)]],
                   texture2d_array<float, access::read> in4 [[texture(4)]],
                   texture2d_array<float, access::read> in5 [[texture(5)]],
                   texture2d_array<float, access::read> inx [[texture(6)]],
                   texture2d_array<float, access::write> out [[texture(7)]],
                   constant ConcatParam & pm [[buffer(0)]],
                   uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output texture have nothing to do (same guard as
    // the elementwise_add kernel).
    if (gid.x >= out.get_width() ||
        gid.y >= out.get_height() ||
        gid.z >= out.get_array_size()) return;

    // Local copy: odim[axis] is patched temporarily while addressing inputs.
    ConcatParam cp = pm;
    int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];

    // Start from what the output already holds so that every component this
    // pass does not own is written back unchanged instead of uninitialised.
    float4 r = inx.read(gid.xy, gid.z);

    for (int i = 0; i < 4; i++) {
        xyzn[3] = i;
        xyzn2abcd(cp.odim[3], xyzn, abcd);

        int k = abcd[cp.axis] - cp.offset;
        if (k < 0) {
            continue;  // produced by an earlier pass; keep r[i]
        }

        // Find the input slot whose range along the axis contains k.
        int j = 0;
        for (; j < 6; j++) {
            if (k < cp.vdim[j]) {
                break;
            }
            k -= cp.vdim[j];
        }
        if (j >= 6) {
            // Beyond this batch: avoid indexing vdim[6] and leave r[i] as is.
            continue;
        }

        // Address the element inside input j, whose extent along the axis is
        // vdim[j] rather than the output's extent.
        int ta = cp.odim[cp.axis];
        abcd[cp.axis] = k;
        cp.odim[cp.axis] = cp.vdim[j];
        abcd2xyzn(cp.odim[3], abcd, oxyzn);
        cp.odim[cp.axis] = ta;

        switch (j) {
            case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
            case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
            case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
            case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
            case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
            case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
        }
    }
    out.write(r, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
// Parameter block for the `elementwise_add` kernel, bound at buffer(0).
// The host-side ElementwiseAddMetalParam struct declares the same fields in
// the same order; keep both definitions in sync.
struct ElementwiseAddParam {
// 1 -> X and Y are read at the same texture coordinate (no index remapping).
int32_t fast;
// First index (in the 4-slot index space) of X that Y is aligned with.
int32_t axis;
// Offset added to an index when writing into Y's 4-slot index array;
// the host sets it to 4 minus inputY.tensorDim.cout().
int32_t yoff;
// Dims of the X texture; xdim[3] is passed to xyzn2abcd.
int32_t xdim[4];
// Transpose order of X, undone with invtrans() in the kernel.
int32_t xtrans[4];
// Dims of the Y texture; ydim[3] is passed to abcd2xyzn.
int32_t ydim[4];
// Transpose order of Y, applied with trans() in the kernel.
int32_t ytrans[4];
};
// Computes outTexture = inputX + inputY, broadcasting Y over X when needed.
//
// Fast path (pm.fast == 1): X and Y share layout, so both are read at the
// thread's own coordinate.
//
// Broadcast path: for each of the four components of the X texel the kernel
//   1. converts the texture coordinate to a 4-D index (xyzn2abcd),
//   2. undoes X's transpose (invtrans),
//   3. copies the indices of the axes Y covers into Y's right-aligned 4-slot
//      index, leaving the padded leading slots at 0,
//   4. applies Y's transpose (trans) and converts back to a texture
//      coordinate (abcd2xyzn) to fetch the Y element.
// xyzn2abcd / abcd2xyzn / invtrans / trans are not defined in this file
// (presumably they come from Common.metal).
//
// Assumes the host guarantees pm.axis + (4 - pm.yoff) <= 4, i.e. Y fits
// inside X starting at pm.axis.
kernel void elementwise_add(texture2d_array<float, access::read> inputX [[texture(0)]],
                            texture2d_array<float, access::read> inputY [[texture(1)]],
                            texture2d_array<float, access::write> outTexture [[texture(2)]],
                            constant ElementwiseAddParam &pm [[buffer(0)]],
                            uint3 gid [[thread_position_in_grid]]) {
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height() ||
        gid.z >= outTexture.get_array_size()) return;

    const float4 rx = inputX.read(gid.xy, gid.z);
    float4 ry;

    if (pm.fast == 1) {
        ry = inputY.read(gid.xy, gid.z);
    } else {
        int32_t x_xyzn[4] = {int32_t(gid.x), int32_t(gid.y), int32_t(gid.z), 0}, x_abcd[4], t_abcd[4];
        // Padded (broadcast) slots of Y have extent 1, so their index is 0.
        int32_t y_abcd[4] = {0, 0, 0, 0}, y_xyzn[4];
        int32_t xtrans[4] = {pm.xtrans[0], pm.xtrans[1], pm.xtrans[2], pm.xtrans[3]};
        int32_t ytrans[4] = {pm.ytrans[0], pm.ytrans[1], pm.ytrans[2], pm.ytrans[3]};

        // Y occupies `ylen` consecutive axes of X starting at pm.axis and is
        // stored right-aligned in its own 4-slot index (first real slot is
        // pm.yoff), hence the shift between the two index spaces.
        const int32_t ylen = 4 - pm.yoff;
        const int32_t yshift = pm.yoff - pm.axis;

        for (int n = 0; n < 4; n++) {
            // Select the n-th component of the texel; without this every
            // component would resolve to the same Y element.
            x_xyzn[3] = n;
            xyzn2abcd(pm.xdim[3], x_xyzn, x_abcd);
            invtrans(xtrans, x_abcd, t_abcd);
            for (int k = pm.axis; k < pm.axis + ylen; k++) {
                y_abcd[k + yshift] = t_abcd[k];
            }
            trans(ytrans, y_abcd, t_abcd);
            abcd2xyzn(pm.ydim[3], t_abcd, y_xyzn);
            ry[n] = inputY.read(uint2(y_xyzn[0], y_xyzn[1]), y_xyzn[2])[y_xyzn[3]];
        }
    }

    float4 r = rx + ry;
    outTexture.write(r, gid.xy, gid.z);
}
...@@ -43,17 +43,6 @@ kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]], ...@@ -43,17 +43,6 @@ kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z); outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z);
} }
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
//kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]], //kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]],
...@@ -200,55 +189,3 @@ kernel void transpose(texture2d_array<float, access::read> inTexture [[texture(0 ...@@ -200,55 +189,3 @@ kernel void transpose(texture2d_array<float, access::read> inTexture [[texture(0
outTexture.write(r, gid.xy, gid.z); outTexture.write(r, gid.xy, gid.z);
} }
} }
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t vdim[6];
};
kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
texture2d_array<float, access::read> in1 [[texture(1)]],
texture2d_array<float, access::read> in2 [[texture(2)]],
texture2d_array<float, access::read> in3 [[texture(3)]],
texture2d_array<float, access::read> in4 [[texture(4)]],
texture2d_array<float, access::read> in5 [[texture(5)]],
texture2d_array<float, access::read> inx [[texture(6)]],
texture2d_array<float, access::write> out [[texture(7)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
float4 r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn2abcd(cp.odim[3], xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
if (k < 0) {
r[i] = inx.read(gid.xy, gid.z)[i];
} else {
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
abcd2xyzn(cp.odim[3], abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
}
}
}
out.write(r, gid.xy, gid.z);
}
...@@ -26,6 +26,7 @@ class PoolParam<P: PrecisionType>: OpParam { ...@@ -26,6 +26,7 @@ class PoolParam<P: PrecisionType>: OpParam {
padding = try PoolParam.getAttr(key: "paddings", attrs: opDesc.attrs) padding = try PoolParam.getAttr(key: "paddings", attrs: opDesc.attrs)
ceilMode = try PoolParam.getAttr(key: "ceil_mode", attrs: opDesc.attrs) ceilMode = try PoolParam.getAttr(key: "ceil_mode", attrs: opDesc.attrs)
globalPooling = try PoolParam.getAttr(key: "global_pooling", attrs: opDesc.attrs) globalPooling = try PoolParam.getAttr(key: "global_pooling", attrs: opDesc.attrs)
assert(input.transpose == [0, 2, 3, 1])
} catch let error { } catch let error {
throw error throw error
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册