提交 e71320da 编写于 作者: D dolphin8


上级 2bcb1135
......@@ -16,8 +16,9 @@
4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; };
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; };
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; };
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */; };
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */; };
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; };
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; };
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
......@@ -126,8 +127,9 @@
4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; };
4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; };
4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; };
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.metal.inc; sourceTree = "<group>"; };
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.metal.inc; sourceTree = "<group>"; };
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = "<group>"; };
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = "<group>"; };
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
......@@ -395,6 +397,7 @@
FCD04E6720F315020007374F /* PoolKernel.swift */,
FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */,
FCD04E6F20F31B720007374F /* ReshapeKernel.swift */,
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */,
FCD04E7320F3437E0007374F /* ConvAddKernel.swift */,
FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */,
FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */,
......@@ -442,7 +445,7 @@
children = (
FC27990D21341016000B6BAD /* BoxCoder.metal */,
4AF928812135673D005B6C3A /* ConcatKernel.metal */,
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */,
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */,
4AF9288321357BE3005B6C3A /* Elementwise.metal */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
......@@ -455,7 +458,7 @@
FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */,
FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */,
FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */,
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */,
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */,
FCA3A1642132A5EB00084FE5 /* Common.metal */,
FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */,
FCA67CD42138272900BD58AA /* ConvAddMetal.metal */,
......@@ -477,7 +480,7 @@
FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */,
FC292C85214257CB00CF622F /* CPUCompute.h in Headers */,
FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */,
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */,
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */,
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */,
runOnlyForDeploymentPostprocessing = 0;
......@@ -617,6 +620,7 @@
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */,
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */,
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */,
FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */,
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,
......@@ -657,7 +661,7 @@
FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */,
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */,
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */,
FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */,
FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
......@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam {
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
if input.transpose != [0, 2, 3, 1] {
fatalError("batch norm only accepts NHWC")
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
bias = try BatchNormParam.getFirstTensor(key: "Bias", map: opDesc.paraInputs, from: inScope)
mean = try BatchNormParam.getFirstTensor(key: "Mean", map: opDesc.paraInputs, from: inScope)
scale = try BatchNormParam.getFirstTensor(key: "Scale", map: opDesc.paraInputs, from: inScope)
variance = try BatchNormParam.getFirstTensor(key: "Variance", map: opDesc.paraInputs, from: inScope)
epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
} catch let error {
......@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam {
let input: Texture<P>
var output: Texture<P>
let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType>
let inputVariance: Tensor<ParamPrecisionType>
let bias: Tensor<P>
let mean: Tensor<P>
let scale: Tensor<P>
let variance: Tensor<P>
let epsilon: Float
let momentum: Float
......@@ -14,7 +14,24 @@
import Foundation
class FlattenOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable{
class FlattenParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
throw error
let input: Texture<P>
var output: Texture<P>
let axis: Int
class FlattenOp<P: PrecisionType>: Operator<FlattenKernel<P>, FlattenParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = FlattenOp<P>
......@@ -15,20 +15,20 @@
import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
// var newScale: MTLBuffer
// var newBias: MTLBuffer
required init(device: MTLDevice, param: BatchNormParam<P>) {
// guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
// fatalError()
// }
// guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
// fatalError()
// }
// self.newScale = newScale
// self.newBias = newBias
// Number of per-channel parameters, taken from the variance tensor.
let count = param.variance.dim.numel()
// Raw element pointers into the parameter tensors; the loop that follows folds
// mean/variance into scale and bias in place through these pointers.
let varianceP = param.variance.data.pointer
let meanP = param.mean.data.pointer
let scaleP = param.scale.data.pointer
// Fix: the bias pointer must come from `param.bias`. It previously aliased
// `param.scale`, so the folding loop wrote the new bias into the scale storage
// (which was then read back as the scale) and left `param.bias` unmodified
// before `param.bias.initBuffer` is called.
let biasP = param.bias.data.pointer
for i in 0..<count {
let invStd = P(1 / (Float32(varianceP[i]) + param.epsilon).squareRoot())
biasP[i] = biasP[i] - meanP[i] * invStd * scaleP[i]
scaleP[i] = invStd * scaleP[i]
param.bias.initBuffer(device: device, precision: computePrecision)
param.scale.initBuffer(device: device, precision: computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm")
} else if computePrecision == .Float16 {
......@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
} else {
// let varianceBuffer : MTLBuffer = param.inputVariance.buffer
// var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
// let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
// for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
// invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
// }
// let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
// let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
// let scale : MTLBuffer = param.inputScale.buffer
// let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
// let bias : MTLBuffer = param.inputBias.buffer
// let biasContents = bias.contents().assumingMemoryBound(to: P.self)
// let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
// for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
// newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
// newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
// }
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
// encoder.setTexture(param.input.metalTexture, index: 0)
// encoder.setTexture(param.output.metalTexture, index: 1)
// encoder.setBuffer(newScale, offset: 0, index: 0)
// encoder.setBuffer(newBias, offset: 0, index: 1)
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBuffer(param.scale.buffer, offset: 0, index: 0)
encoder.setBuffer(param.bias.buffer, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
......@@ -122,10 +122,11 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "concat")
super.init(device: device, inFunctionName: "concat_\(orank)_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "concat_half")
super.init(device: device, inFunctionName: "concat_\(orank)_half")
} else {
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
// Parameter block copied to the flatten compute kernel with `setBytes` at buffer index 0.
// Each field is a 4-tuple of Int32 values filled in by `FlattenKernel.init`.
struct FlattenMetalParam {
// Input tensor dimensions, right-aligned into four slots and padded with leading 1s.
var idim: (Int32, Int32, Int32, Int32)
// Input texture transpose (first four entries of `param.input.transpose`).
var itrans: (Int32, Int32, Int32, Int32)
// Output tensor dimensions, right-aligned into four slots and padded with leading 1s.
var odim: (Int32, Int32, Int32, Int32)
// Output texture transpose (first four entries of `param.output.transpose`).
var otrans: (Int32, Int32, Int32, Int32)
// Kernel wrapper for the flatten operator: prepares the output texture, builds the
// `FlattenMetalParam` block and selects the `reshape_<irank>_2_<precision>` function.
class FlattenKernel<P: PrecisionType>: Kernel, Computable{
// Dimension/transpose block handed to the compute encoder in `compute`.
var metalParam: FlattenMetalParam
// Builds `metalParam` from the input/output textures of `param` and initializes the
// superclass with the function name matching the input rank and compute precision.
required init(device: MTLDevice, param: FlattenParam<P>) {
param.output.initTexture(device: device, computePrecision: computePrecision)
// Right-align the input dims into a 4-element array; unused leading slots stay 1.
// NOTE(review): assumes the input rank is at most 4, otherwise the index goes negative.
var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
// NOTE(review): indexed as it[0]...it[3] below, so transpose is assumed to have four entries.
let it: [Int32] = param.input.transpose.map { Int32($0) }
// Same right-aligned padding for the output dims.
var od: [Int32] = [1, 1, 1, 1]
for i in 0..<param.output.tensorDim.cout() {
od[4-param.output.tensorDim.cout()+i] = Int32(param.output.tensorDim[i])
let ot: [Int32] = param.output.transpose.map { Int32($0) }
metalParam = FlattenMetalParam.init(
idim: (id[0], id[1], id[2], id[3]),
itrans: (it[0], it[1], it[2], it[3]),
odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3])
let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout()
// The function names below hard-code an output rank of 2, so it is asserted here.
assert(orank == 2)
// Pick the float or half variant of the reshape function for the input rank.
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_half")
} else {
func compute(commandBuffer: MTLCommandBuffer, param: FlattenParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
// Bind the input texture (index 0) and output texture (index 1).
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
// Fix: size the byte copy from the type of `metalParam` itself (FlattenMetalParam)
// instead of the unrelated ReshapeMetalParam, so the length cannot silently drift
// if the two structs ever diverge.
encoder.setBytes(&metalParam, length: MemoryLayout<FlattenMetalParam>.size, index: 0)
// Dispatch the pipeline over the output texture.
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
......@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3])
let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape")
super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_half")
super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half")
} else {
......@@ -27,7 +27,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: SplitParam<P>) {
param.output.initTexture(device: device, computePrecision: computePrecision)
// param.output.initTexture(device: device, computePrecision: computePrecision)
for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "split")
} else if computePrecision == .Float16 {
......@@ -15,28 +15,28 @@
#include <metal_stdlib>
using namespace metal;
kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 * newScale [[buffer(0)]],
const device half4 * newBias [[buffer(1)]],
kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 * newScale [[buffer(0)]],
const device float4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
const half4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z];
const float4 input = inTexture.read(gid.xy, gid.z);
float4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z);
kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 * newScale [[buffer(0)]],
const device float4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 * newScale [[buffer(0)]],
const device half4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
const float4 input = inTexture.read(gid.xy, gid.z);
float4 output = input * newScale[gid.z] + newBias[gid.z];
const half4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z);
#ifndef D
#define D 4
#ifndef P
#define P float
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define FUNC(f, d, p) CONCAT3_(f, d, p)
#define FUNC(f, r, p) CONCAT3_(f, r, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_D(f, d) CONCAT2_(f, d)
#define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
kernel void FUNC(concat, R, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
texture2d_array<P, access::read> in2 [[texture(2)]],
texture2d_array<P, access::read> in3 [[texture(3)]],
......@@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
VECTOR(P, 4) r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
#if D == 4
#if R == 4
xyzn2abcd_4(cp.odim[3], xyzn, abcd);
FUNC_D(xyzn2abcd, D)(xyzn, abcd);
FUNC_R(xyzn2abcd, R)(xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
......@@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
#if D == 4
#if R == 4
abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
FUNC_D(abcd2xyzn, D)(abcd, oxyzn);
FUNC_R(abcd2xyzn, R)(abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
......@@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
out.write(r, gid.xy, gid.z);
......@@ -26,31 +26,31 @@ struct ConcatParam {
#define P float
#define D 4
#include "ConcatKernel.metal.inc"
#undef D
#define D 3
#include "ConcatKernel.metal.inc"
#undef D
#define D 2
#include "ConcatKernel.metal.inc"
#undef D
#define D 1
#include "ConcatKernel.metal.inc"
#undef D
#define R 4
#include "ConcatKernel.inc.metal"
#undef R
#define R 3
#include "ConcatKernel.inc.metal"
#undef R
#define R 2
#include "ConcatKernel.inc.metal"
#undef R
#define R 1
#include "ConcatKernel.inc.metal"
#undef R
#undef P
#define P half
#define D 4
#include "ConcatKernel.metal.inc"
#undef D
#define D 3
#include "ConcatKernel.metal.inc"
#undef D
#define D 2
#include "ConcatKernel.metal.inc"
#undef D
#define D 1
#include "ConcatKernel.metal.inc"
#undef D
#define R 4
#include "ConcatKernel.inc.metal"
#undef R
#define R 3
#include "ConcatKernel.inc.metal"
#undef R
#define R 2
#include "ConcatKernel.inc.metal"
#undef R
#define R 1
#include "ConcatKernel.inc.metal"
#undef R
#undef P
#ifndef P
#define P float
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define FUNC(f, d1, d2, p) CONCAT4_(f, d1, d2, p)
#define FUNC(f, r1, r2, p) CONCAT4_(f, r1, r2, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_D(f, d) CONCAT2_(f, d)
#define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
kernel void FUNC(reshape, RIN, ROUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
......@@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
VECTOR(P, 4) r;
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
#if DOUT == 4
#if ROUT == 4
xyzn2abcd_4(oC, oxyzn, oabcd);
FUNC_D(xyzn2abcd, DOUT)(oxyzn, oabcd);
FUNC_R(xyzn2abcd, ROUT)(oxyzn, oabcd);
int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd);
......@@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd);
abcd2xyzn(iC, iabcd, ixyzn);
#if DIN == 4
#if RIN == 4
abcd2xyzn_4(iC, iabcd, ixyzn);
FUNC_D(abcd2xyzn, DIN)(iabcd, ixyzn);
FUNC_R(abcd2xyzn, RIN)(iabcd, ixyzn);
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
......@@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
outTexture.write(r, gid.xy, gid.z);
......@@ -8,7 +8,7 @@
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
......@@ -25,127 +25,126 @@ struct ReshapeParam {
#define P float
#define DIN 4
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 4
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define DIN 3
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 3
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define DIN 2
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 2
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define DIN 1
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 1
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#undef P
#define P half
#define DIN 4
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 4
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define DIN 3
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 3
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define DIN 2
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define DIN 1
#define DOUT 4
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 3
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 2
#include "ReshapeKernel.metal.inc"
#undef DOUT
#define DOUT 1
#include "ReshapeKernel.metal.inc"
#undef DOUT
#undef DIN
#define RIN 2
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 1
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#undef P
......@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam {
output.padToFourDim = Dim.init(inDim: dim)
output.dim = output.padToFourDim
// inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
} catch let error {
throw error
let input: Texture<P>
let shape: [Int32]
// let inplace: Bool
var output: Texture<P>
......@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope)
input = try ShapeParam.input(inputs: opDesc.inputs, from: inScope)
output = try ShapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
var output: Texture<P>
let input: Texture<P>
class ShapeOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{
class ShapeOp<P: PrecisionType>: Operator<ShapeKernel<P>, ShapeParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = SplitOp<P>
typealias OpType = ShapeOp<P>
func inferShape() {
// para.output.dim = para.input.dim
......@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
// output = try SplitParam.output(outputs: opDesc.outputs, from: inScope)
output = try SplitParam.outputOut(outputs: opDesc.outputs, from: inScope)
input = try SplitParam.inputX(inputs: opDesc.inputs, from: inScope)
output = Texture<P>.init(device: input.metalTexture!.device, inDim: input.dim)
axis = try SplitParam.getAttr(key: "axis", attrs: opDesc.attrs)
sections = try SplitParam.getAttr(key: "sections", attrs: opDesc.attrs)
if axis < 0 {
axis = input.tensorDim.cout() + axis
guard let outlist = opDesc.outputs["Out"] else {
for out in outlist {
guard let variant = inScope[out], let v = variant as? Texture<P> else {
} catch let error {
throw error
var axis: Int
let input: Texture<P>
var output: Texture<P>
var outputList: [Texture<P>] = []
var sections: [Int32] = []
class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册