未验证 提交 16c3a457 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #982 from dolphin8/metal

softmax & transpose
......@@ -23,6 +23,8 @@
4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA5214B5F6800D0F791 /* Shape.metal */; };
4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */; };
4AA1EAAA214F53D800D0F791 /* BoxCoder.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA9214F53D800D0F791 /* BoxCoder.inc.metal */; };
4AA1EAAC214F55C800D0F791 /* Softmax.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAAB214F55C800D0F791 /* Softmax.inc.metal */; };
4AA1EAAE214F5FD900D0F791 /* TransposeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAAD214F5FD900D0F791 /* TransposeKernel.inc.metal */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
......@@ -138,6 +140,8 @@
4AA1EAA5214B5F6800D0F791 /* Shape.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Shape.metal; sourceTree = "<group>"; };
4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BilinearInterp.inc.metal; sourceTree = "<group>"; };
4AA1EAA9214F53D800D0F791 /* BoxCoder.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.inc.metal; sourceTree = "<group>"; };
4AA1EAAB214F55C800D0F791 /* Softmax.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.inc.metal; sourceTree = "<group>"; };
4AA1EAAD214F5FD900D0F791 /* TransposeKernel.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = TransposeKernel.inc.metal; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
......@@ -463,6 +467,7 @@
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */,
4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */,
4AF9287821341661005B6C3A /* Softmax.metal */,
4AA1EAAB214F55C800D0F791 /* Softmax.inc.metal */,
FCEB6849212F00DB00D2448E /* PreluKernel.metal */,
FCDDC6C9212FDF6800E5EF74 /* BatchNormKernel.metal */,
FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */,
......@@ -475,6 +480,7 @@
FCA67CD6213827AC00BD58AA /* ConvAddBNReluKernel.metal */,
FCA67CD82138287B00BD58AA /* ConvBNReluKernel.metal */,
FC0226552138F33800F395E2 /* TransposeKernel.metal */,
4AA1EAAD214F5FD900D0F791 /* TransposeKernel.inc.metal */,
FC0226572138F38D00F395E2 /* PoolKernel.metal */,
);
path = metal;
......@@ -595,6 +601,7 @@
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
4AA1EAAE214F5FD900D0F791 /* TransposeKernel.inc.metal in Sources */,
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */,
......@@ -612,6 +619,7 @@
4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */,
FC292C81214255BD00CF622F /* CPUCompute.mm in Sources */,
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */,
4AA1EAAC214F55C800D0F791 /* Softmax.inc.metal in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */,
FC33B0F02147659000714A93 /* MobileNet.swift in Sources */,
......
......@@ -195,23 +195,23 @@ public class PaddleMobileUnitTest {
// let tx: [Float32] = self.device.texture2tensor(texture: outputTexture, dim: [3, 3, 2, 4])
// self.tensorPrint(tensor: tx, dim: [3, 3, 2, 4])
// }
let input: [Float32] = (0..<24).map { Float32($0) }
let inputTexture = device.tensor2texture(value: input, dim: [2, 3, 4])
let outputTexture = device.tensor2texture(value: [Float](), dim: [3, 4, 2])
let param = TransposeTestParam.init(inputTexture: inputTexture, outputTexture: outputTexture, iC: 4, oC: 2, axis: [0, 2, 3, 1])
let transposeKernel = TransposeKernel<Float32>.init(device: device, testParam: param)
transposeKernel.test(commandBuffer: buffer, param: param)
buffer.addCompletedHandler { (buffer) in
let _: Float32? = inputTexture.logDesc(header: "input texture", stridable: false)
let _: Float32? = outputTexture.logDesc(header: "output texture", stridable: false)
self.tensorPrint(tensor: input, dim: [2, 3, 4])
let tx: [Float32] = self.device.texture2tensor(texture: outputTexture, dim: [3, 4, 2])
self.tensorPrint(tensor: tx, dim: [3, 4, 2])
}
//
// let input: [Float32] = (0..<24).map { Float32($0) }
// let inputTexture = device.tensor2texture(value: input, dim: [2, 3, 4])
// let outputTexture = device.tensor2texture(value: [Float](), dim: [3, 4, 2])
// let param = TransposeTestParam.init(inputTexture: inputTexture, outputTexture: outputTexture, iC: 4, oC: 2, axis: [0, 2, 3, 1])
// let transposeKernel = TransposeKernel<Float32>.init(device: device, testParam: param)
//
// transposeKernel.test(commandBuffer: buffer, param: param)
//
// buffer.addCompletedHandler { (buffer) in
// let _: Float32? = inputTexture.logDesc(header: "input texture", stridable: false)
// let _: Float32? = outputTexture.logDesc(header: "output texture", stridable: false)
// self.tensorPrint(tensor: input, dim: [2, 3, 4])
// let tx: [Float32] = self.device.texture2tensor(texture: outputTexture, dim: [3, 4, 2])
// self.tensorPrint(tensor: tx, dim: [3, 4, 2])
// }
//
buffer.commit()
}
......
......@@ -49,6 +49,9 @@ class FlattenOp<P: PrecisionType>: Operator<FlattenKernel<P>, FlattenParam<P>>,
func delogOutput() {
print(" \(type) output: ")
let device = para.output.metalTexture!.device
let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
print(outputArray.strideArray())
}
}
......
......@@ -29,7 +29,7 @@ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
K: Int32(param.input.tensorDim[1])
)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "softmax")
super.init(device: device, inFunctionName: "softmax_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "softmax_half")
} else {
......
......@@ -17,73 +17,52 @@ import Foundation
struct TransposeMetalParam {
var iC: Int32 = 0
var oC: Int32 = 0
var i0: Int32
var i1: Int32
var i2: Int32
var i3: Int32
init(_ i0: Int32, _ i1: Int32, _ i2: Int32, _ i3: Int32) {
self.i0 = i0
self.i1 = i1
self.i2 = i2
self.i3 = i3
}
init(_ axis: [Int]) {
self.init(Int32(axis[0]), Int32(axis[1]), Int32(axis[2]), Int32(axis[3]))
}
}
struct TransposeTestParam: TestParam {
let inputTexture: MTLTexture
let outputTexture: MTLTexture
let iC: Int
let oC: Int
let axis: [Int]
var axis: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
}
class TransposeKernel<P: PrecisionType>: Kernel, Computable, Testable {
class TransposeKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: TransposeMetalParam = TransposeMetalParam.init()
required init(device: MTLDevice, param: TransposeParam<P>) {
param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: computePrecision)
if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "transpose_half")
} else if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "transpose")
} else {
fatalError()
}
var invT: [Int] = [0, 1, 2, 3]
for (i, v) in param.input.transpose.enumerated() {
invT[v] = i
}
param.output.initTexture(device: device, computePrecision: computePrecision)
let rank = param.input.tensorDim.cout()
var axis: [Int] = [0, 1, 2, 3]
for i in 0..<param.axis.count {
axis[4-param.axis.count+i] = 4 - param.axis.count + Int(param.axis[i])
axis[4-rank+i] = 4 - rank + Int(param.axis[i])
}
let realAxis = axis.map {invT[$0]}
var tmp = TransposeMetalParam.init(realAxis)
tmp.iC = Int32(param.input.dim[param.input.transpose[3]])
tmp.oC = Int32(param.output.dim[3])
if realAxis == [0, 1, 2, 3] {
// print("====> transpose! FAST :)")
} else {
// print("====> transpose! SLOW :(")
var naxis: [Int] = [0, 0, 0, 0]
for i in 0..<4 {
for j in 0..<4 {
if param.input.transpose[j] == axis[i] {
naxis[i] = j
break
}
}
}
metalParam = tmp
}
required init(device: MTLDevice, testParam: TransposeTestParam) {
metalParam.iC = Int32(param.input.dim[param.input.transpose[3]])
metalParam.oC = Int32(param.output.dim[3])
metalParam.axis = (Int32(naxis[0]), Int32(naxis[1]), Int32(naxis[2]), Int32(naxis[3]))
var kernelFunc = "transpose_undefined"
if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "transpose_half")
if param.input.transpose == axis {
kernelFunc = "transpose_copy_half"
} else {
kernelFunc = "transpose_\(rank)_half"
}
} else if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "transpose")
if param.input.transpose == axis {
kernelFunc = "transpose_copy_float"
} else {
kernelFunc = "transpose_\(rank)_float"
}
} else {
fatalError()
}
print("===========>", kernelFunc)
print(metalParam)
super.init(device: device, inFunctionName: kernelFunc)
}
var metalParam: TransposeMetalParam!
func compute(commandBuffer: MTLCommandBuffer, param: TransposeParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
......@@ -95,20 +74,4 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable, Testable {
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
public func test(commandBuffer: MTLCommandBuffer, param: TransposeTestParam) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
fatalError()
}
encoder.setTexture(param.inputTexture, index: 0)
encoder.setTexture(param.outputTexture, index: 1)
var tmp = TransposeMetalParam.init(param.axis)
tmp.iC = Int32(param.iC)
tmp.oC = Int32(param.oC)
encoder.setBytes(&tmp, length: MemoryLayout<TransposeMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.outputTexture)
encoder.endEncoding()
}}
}
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define FUNC(f, p) CONCAT2_(f, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(softmax, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant SoftmaxParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
// int zsize = inTexture.get_array_size();
P maxv = inTexture.read(uint2(0, gid.y), 0)[0];
int group = sp.K / 4;
int remain = sp.K % 4;
for (int x = 0; x < group; x++) {
VECTOR(P, 4) r = inTexture.read(uint2(x, gid.y), 0);
maxv = max(maxv, max(r[0], max(r[1], max(r[2], r[3]))));
}
if (remain > 0) {
VECTOR(P, 4) r = inTexture.read(uint2(group, gid.y), 0);
for (int i = 0; i < remain; i++) {
maxv = max(maxv, r[i]);
}
}
VECTOR(P, 4) rsum = {0, 0, 0, 0};
for (int x = 0; x < group; x++) {
VECTOR(P, 4) r = inTexture.read(uint2(x, gid.y), 0);
rsum += exp(r - maxv);
}
P sum = rsum[0] + rsum[1] + rsum[2] + rsum[3];
if (remain > 0) {
VECTOR(P, 4) r = inTexture.read(uint2(group, gid.y), 0);
for (int i = 0; i < remain; i++) {
sum += exp(r[i] - maxv);
}
}
VECTOR(P, 4) rr = inTexture.read(gid.xy, gid.z);
rr = exp(rr - maxv) / sum;
outTexture.write(rr, gid.xy, gid.z);
}
#endif
......@@ -20,81 +20,10 @@ struct SoftmaxParam {
int K;
};
kernel void softmax(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant SoftmaxParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
// int zsize = inTexture.get_array_size();
float maxv = inTexture.read(gid.xy, 0)[0];
int group = sp.K / 4;
int remain = sp.K % 4;
for (int z = 0; z < group; z++) {
float4 r = inTexture.read(gid.xy, z);
maxv = max(maxv, max(r[0], max(r[1], max(r[2], r[3]))));
}
if (remain > 0) {
float4 r = inTexture.read(gid.xy, group);
for (int i = 0; i < remain; i++) {
maxv = max(maxv, r[i]);
}
}
float4 rsum = {0, 0, 0, 0};
for (int z = 0; z < group; z++) {
float4 r = inTexture.read(gid.xy, z);
rsum += exp(r - maxv);
}
float sum = rsum[0] + rsum[1] + rsum[2] + rsum[3];
if (remain > 0) {
float4 r = inTexture.read(gid.xy, group);
for (int i = 0; i < remain; i++) {
sum += exp(r[i] - maxv);
}
}
float4 rr = inTexture.read(gid.xy, gid.z);
rr = exp(rr - maxv) / sum;
outTexture.write(rr, gid.xy, gid.z);
}
kernel void softmax_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant SoftmaxParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
// int zsize = inTexture.get_array_size();
half maxv = inTexture.read(gid.xy, 0)[0];
int group = sp.K / 4;
int remain = sp.K % 4;
for (int z = 0; z < group; z++) {
half4 r = inTexture.read(gid.xy, z);
maxv = max(maxv, max(r[0], max(r[1], max(r[2], r[3]))));
}
if (remain > 0) {
half4 r = inTexture.read(gid.xy, group);
for (int i = 0; i < remain; i++) {
maxv = max(maxv, r[i]);
}
}
float4 rsum = {0, 0, 0, 0};
for (int z = 0; z < group; z++) {
half4 r = inTexture.read(gid.xy, z);
rsum += exp(float4(r) - float4(maxv));
}
float sum = rsum[0] + rsum[1] + rsum[2] + rsum[3];
if (remain > 0) {
half4 r = inTexture.read(gid.xy, group);
for (int i = 0; i < remain; i++) {
sum += exp(float(r[i]) - float(maxv));
}
}
half4 rr = inTexture.read(gid.xy, gid.z);
rr = half4(exp(float4(rr) - float(maxv)) / sum);
outTexture.write(rr, gid.xy, gid.z);
}
#define P float
#include "Softmax.inc.metal"
#undef P
#define P half
#include "Softmax.inc.metal"
#undef P
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define FUNC(f, r, p) CONCAT3_(f, r, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(transpose, R, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r;
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0};
int iabcd[4], oabcd[4], ixyzn[4];
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
#if R == 4
xyzn2abcd_4(pm.oC, oxyzn, iabcd);
#endif // R == 4
#if R == 3
xyzn2abcd_3(oxyzn, oabcd);
#endif // R == 3
#if R == 2
xyzn2abcd_2(oxyzn, oabcd);
#endif // R == 2
iabcd[pm.axis[0]] = oabcd[0];
iabcd[pm.axis[1]] = oabcd[1];
iabcd[pm.axis[2]] = oabcd[2];
iabcd[pm.axis[3]] = oabcd[3];
#if R == 4
abcd2xyzn_4(pm.iC, iabcd, ixyzn);
#endif // R == 4
#if R == 3
abcd2xyzn_3(iabcd, ixyzn);
#endif // R == 3
#if R == 2
abcd2xyzn_2(iabcd, ixyzn);
#endif // R == 2
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
}
outTexture.write(r, gid.xy, gid.z);
}
#endif
......@@ -22,59 +22,42 @@ struct TransposeParam {
int axis[4];
};
kernel void transpose(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if ((pm.axis[0] == 0) && (pm.axis[1] == 1) && (pm.axis[2] == 2) && (pm.axis[3] == 3)) {
// do nothing
float4 r = inTexture.read(gid.xy, gid.z);
outTexture.write(r, gid.xy, gid.z);
} else {
float4 r;
for (int n = 0; n < 4; n++) {
int ixyzn[] = {int(gid.x), int(gid.y), int(gid.z), n};
int iabcd[4], oabcd[4], oxyzn[4];
xyzn2abcd(pm.oC, ixyzn, iabcd);
oabcd[pm.axis[0]] = iabcd[0];
oabcd[pm.axis[1]] = iabcd[1];
oabcd[pm.axis[2]] = iabcd[2];
oabcd[pm.axis[3]] = iabcd[3];
abcd2xyzn(pm.iC, oabcd, oxyzn);
float4 rt = inTexture.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2]);
r[n] = rt[oxyzn[3]];
}
outTexture.write(r, gid.xy, gid.z);
}
kernel void transpose_copy_float(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
outTexture.write(inTexture.read(gid.xy, gid.z), gid.xy, gid.z);
}
kernel void transpose_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
kernel void transpose_copy_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if ((pm.axis[0] == 0) && (pm.axis[1] == 1) && (pm.axis[2] == 2) && (pm.axis[3] == 3)) {
// do nothing
half4 r = inTexture.read(gid.xy, gid.z);
outTexture.write(r, gid.xy, gid.z);
} else {
half4 r;
for (int n = 0; n < 4; n++) {
int ixyzn[] = {int(gid.x), int(gid.y), int(gid.z), n};
int iabcd[4], oabcd[4], oxyzn[4];
xyzn2abcd(pm.oC, ixyzn, iabcd);
oabcd[pm.axis[0]] = iabcd[0];
oabcd[pm.axis[1]] = iabcd[1];
oabcd[pm.axis[2]] = iabcd[2];
oabcd[pm.axis[3]] = iabcd[3];
abcd2xyzn(pm.iC, oabcd, oxyzn);
half4 rt = inTexture.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2]);
r[n] = rt[oxyzn[3]];
}
outTexture.write(r, gid.xy, gid.z);
}
outTexture.write(inTexture.read(gid.xy, gid.z), gid.xy, gid.z);
}
#define R 4
#define P float
#include "TransposeKernel.inc.metal"
#undef P
#define P half
#include "TransposeKernel.inc.metal"
#undef P
#undef R
#define R 3
#define P float
#include "TransposeKernel.inc.metal"
#undef P
#define P half
#include "TransposeKernel.inc.metal"
#undef P
#undef R
#define R 2
#define P float
#include "TransposeKernel.inc.metal"
#undef P
#define P half
#include "TransposeKernel.inc.metal"
#undef P
#undef R
......@@ -48,15 +48,6 @@ class TransposeOp<P: PrecisionType>: Operator<TransposeKernel<P>, TransposeParam
func delogOutput() {
print(" \(type) output: ")
let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] {
let outputArray = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] {
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
} else {
print(" not implement")
}
let device = para.output.metalTexture!.device
let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
print(outputArray.strideArray())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册