提交 b5248db0 编写于 作者: L liuruilong

correct cnn implementation

上级 2cee66eb
<?xml version="1.0" encoding="UTF-8"?>
<Scheme
LastUpgradeVersion = "0940"
version = "1.3">
<BuildAction
parallelizeBuildables = "YES"
buildImplicitDependencies = "YES">
<BuildActionEntries>
<BuildActionEntry
buildForTesting = "YES"
buildForRunning = "YES"
buildForProfiling = "YES"
buildForArchiving = "YES"
buildForAnalyzing = "YES">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "FC039B7D20E11C550081E9F8"
BuildableName = "paddle-mobile-demo.app"
BlueprintName = "paddle-mobile-demo"
ReferencedContainer = "container:paddle-mobile-demo.xcodeproj">
</BuildableReference>
</BuildActionEntry>
</BuildActionEntries>
</BuildAction>
<TestAction
buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
shouldUseLaunchSchemeArgsEnv = "YES">
<Testables>
</Testables>
<MacroExpansion>
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "FC039B7D20E11C550081E9F8"
BuildableName = "paddle-mobile-demo.app"
BlueprintName = "paddle-mobile-demo"
ReferencedContainer = "container:paddle-mobile-demo.xcodeproj">
</BuildableReference>
</MacroExpansion>
<AdditionalOptions>
</AdditionalOptions>
</TestAction>
<LaunchAction
buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0"
useCustomWorkingDirectory = "NO"
ignoresPersistentStateOnLaunch = "NO"
debugDocumentVersioning = "YES"
debugServiceExtension = "internal"
allowLocationSimulation = "YES">
<BuildableProductRunnable
runnableDebuggingMode = "0">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "FC039B7D20E11C550081E9F8"
BuildableName = "paddle-mobile-demo.app"
BlueprintName = "paddle-mobile-demo"
ReferencedContainer = "container:paddle-mobile-demo.xcodeproj">
</BuildableReference>
</BuildableProductRunnable>
<AdditionalOptions>
</AdditionalOptions>
</LaunchAction>
<ProfileAction
buildConfiguration = "Release"
shouldUseLaunchSchemeArgsEnv = "YES"
savedToolIdentifier = ""
useCustomWorkingDirectory = "NO"
debugDocumentVersioning = "YES">
<BuildableProductRunnable
runnableDebuggingMode = "0">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "FC039B7D20E11C550081E9F8"
BuildableName = "paddle-mobile-demo.app"
BlueprintName = "paddle-mobile-demo"
ReferencedContainer = "container:paddle-mobile-demo.xcodeproj">
</BuildableReference>
</BuildableProductRunnable>
</ProfileAction>
<AnalyzeAction
buildConfiguration = "Debug">
</AnalyzeAction>
<ArchiveAction
buildConfiguration = "Release"
revealArchiveInOrganizer = "YES">
</ArchiveAction>
</Scheme>
...@@ -7,7 +7,15 @@ ...@@ -7,7 +7,15 @@
<key>paddle-mobile-demo.xcscheme</key> <key>paddle-mobile-demo.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>4</integer> <integer>3</integer>
</dict>
</dict>
<key>SuppressBuildableAutocreation</key>
<dict>
<key>FC039B7D20E11C550081E9F8</key>
<dict>
<key>primary</key>
<true/>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -40,6 +40,7 @@ class ViewController: UIViewController { ...@@ -40,6 +40,7 @@ class ViewController: UIViewController {
let dest = device.makeTexture(descriptor: tmpTextureDes) let dest = device.makeTexture(descriptor: tmpTextureDes)
let scale = MPSImageLanczosScale.init(device: device) let scale = MPSImageLanczosScale.init(device: device)
let buffer = queue.makeCommandBuffer() let buffer = queue.makeCommandBuffer()
scale.encode(commandBuffer: buffer!, sourceTexture: input, destinationTexture: dest!) scale.encode(commandBuffer: buffer!, sourceTexture: input, destinationTexture: dest!)
buffer?.addCompletedHandler({ (buffer) in buffer?.addCompletedHandler({ (buffer) in
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key> <key>paddle-mobile.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>3</integer> <integer>4</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -103,11 +103,11 @@ public extension MTLTexture { ...@@ -103,11 +103,11 @@ public extension MTLTexture {
str += "2d array count : \(width * height * depth * 4) \n" str += "2d array count : \(width * height * depth * 4) \n"
if stridable { if stridable {
for j in stride(from: 0, to: width * height * depth * 4 , by: width * height * depth * 4 / 100){ for j in stride(from: 0, to: width * height * depth * 4 , by: width * height * depth * 4 / 100){
str += " \(p[j])" str += " index \(j): \(p[j])"
} }
} else { } else {
for j in 0..<width * height * depth * 4 { for j in 0..<width * height * depth * 4 {
str += " \(p[j])" str += " index \(j): \(p[j])"
} }
} }
......
...@@ -55,7 +55,7 @@ public class Executor<P: PrecisionType> { ...@@ -55,7 +55,7 @@ public class Executor<P: PrecisionType> {
device = inDevice device = inDevice
queue = inQueue queue = inQueue
for block in inProgram.programDesc.blocks { for block in inProgram.programDesc.blocks {
for i in 0..<2 { for i in 0..<block.ops.count {
let op = block.ops[i] let op = block.ops[i]
do { do {
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
......
...@@ -107,16 +107,17 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer ...@@ -107,16 +107,17 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer
} }
func delogOutput() { func delogOutput() {
let _: P? = para.input.metalTexture.logDesc(header: "conv add batchnorm relu input: ", stridable: false) // let _: P? = para.input.metalTexture.logDesc(header: "conv add batchnorm relu input: ", stridable: false)
para.filter.logDataPointer(header: "filter data pointer: ") // para.filter.logDataPointer(header: "filter data pointer: ")
print("filter: \(para.filter)") //
// print("filter: \(para.filter)")
print("biase: \(para.bias)") // print("biase: \(para.bias)")
// print("padding: \(para.paddings)")
let _: P? = para.newBiase?.logDesc(header: "new biase: ", stridable: false) // print("stride: \(para.stride)")
let _: P? = para.newScale?.logDesc(header: "new scale: ", stridable: false) //
// let _: P? = para.newBiase?.logDesc(header: "new biase: ", stridable: false)
let _: P? = para.output.metalTexture.logDesc(header: "conv add batchnorm relu output: ", stridable: true) // let _: P? = para.newScale?.logDesc(header: "new scale: ", stridable: false)
// let _: P? = para.output.metalTexture.logDesc(header: "conv add batchnorm relu output: ", stridable: true)
} }
} }
...@@ -18,10 +18,22 @@ class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable { ...@@ -18,10 +18,22 @@ class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>) { required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>) {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3")
let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0]) if param.filter.width == 1 && param.filter.height == 1 {
let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1]) super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1")
} else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_1x1")
} else {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3")
}
let offsetX = param.filter.width/2 - Int(param.paddings[0])
let offsetY = param.filter.height/2 - Int(param.paddings[1])
print("offset x: \(offsetX)")
print("offset y: \(offsetY)")
let offsetZ = 0.0 let offsetZ = 0.0
metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3])) metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3]))
...@@ -69,6 +81,4 @@ class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable { ...@@ -69,6 +81,4 @@ class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable {
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding() encoder.endEncoding()
} }
} }
...@@ -16,7 +16,7 @@ import Foundation ...@@ -16,7 +16,7 @@ import Foundation
class ConvAddKernel<P: PrecisionType>: Kernel, Computable { class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: ConvAddParam<P>) { required init(device: MTLDevice, param: ConvAddParam<P>) {
super.init(device: device, inFunctionName: "conv3x3") super.init(device: device, inFunctionName: "conv_add_1x1")
} }
......
...@@ -24,41 +24,6 @@ struct MetalConvParam { ...@@ -24,41 +24,6 @@ struct MetalConvParam {
}; };
kernel void conv3x3(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint wightSliceCount = 36;
uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
half4 output = 0.0;
for (uint i = 0; i < inTexture.get_array_size(); ++i) {
half4 input[9];
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
half4 weight = weights[weithTo + wightSliceCount * i + j * 4];
output += dot(input[j], weight);
}
}
outTexture.write(output, gid.xy, gid.z);
}
//kernel void conv_add_batch_norm_relu_3x3(texture2d_array<half, access::sample> inTexture [[texture(0)]], //kernel void conv_add_batch_norm_relu_3x3(texture2d_array<half, access::sample> inTexture [[texture(0)]],
// texture2d_array<half, access::write> outTexture [[texture(1)]], // texture2d_array<half, access::write> outTexture [[texture(1)]],
// constant MetalConvParam &param [[buffer(0)]], // constant MetalConvParam &param [[buffer(0)]],
...@@ -119,30 +84,172 @@ kernel void conv_add_batch_norm_relu_3x3(texture2d_array<float, access::sample> ...@@ -119,30 +84,172 @@ kernel void conv_add_batch_norm_relu_3x3(texture2d_array<float, access::sample>
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY); short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint wightSliceCount = 36; const uint kernelHXW = 9;
uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
float4 output = 0.0; uint input_arr_size = inTexture.get_array_size();
for (uint i = 0; i < inTexture.get_array_size(); ++i) { uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 input[9];
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i); float4 output = float4(0.0);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i); float4 input[9];
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i); for (uint i = 0; i < input_arr_size; ++i) {
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i); input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i); input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i); input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i); input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i); input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) { for (int j = 0; j < 9; ++j) {
float4 weight = weights[weithTo + wightSliceCount * i + j * 4]; float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output += dot(input[j], weight); output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
} }
} }
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0); output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
} }
kernel void depthwise_conv_add_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW;
float4 output = float4(0.0);
float4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
float4 input = inputs[j];
float4 weight = weights[weithTo + j];
output.x += input.x * weight.x;
output.y += input.y * weight.y;
output.z += input.z * weight.z;
output.w += input.w * weight.w;
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
...@@ -27,7 +27,7 @@ struct MetalConvParam { ...@@ -27,7 +27,7 @@ struct MetalConvParam {
class ConvKernel<P: PrecisionType>: Kernel, Computable { class ConvKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvParam<P>) { required init(device: MTLDevice, param: ConvParam<P>) {
super.init(device: device, inFunctionName: "conv3x3") super.init(device: device, inFunctionName: "conv_add_1x1")
let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0]) let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0])
let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1]) let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1])
let offsetZ = 0.0 let offsetZ = 0.0
......
...@@ -98,6 +98,8 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -98,6 +98,8 @@ class Tensor<P: PrecisionType>: Tensorial {
buffer = device.makeBuffer(length: count * MemoryLayout<P>.stride) buffer = device.makeBuffer(length: count * MemoryLayout<P>.stride)
if C == paddedC { if C == paddedC {
buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout<P>.stride) buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout<P>.stride)
} else if C == 1 {
buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout<P>.stride)
} else { } else {
var tmpPointer = data.pointer var tmpPointer = data.pointer
var dstPtr = buffer?.contents().bindMemory(to: P.self, capacity: count) var dstPtr = buffer?.contents().bindMemory(to: P.self, capacity: count)
...@@ -121,6 +123,37 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -121,6 +123,37 @@ class Tensor<P: PrecisionType>: Tensorial {
data.release() data.release()
} }
var width: Int {
get {
if dim.cout() == 4 {
return dim[1]
} else {
fatalError()
}
}
}
var height: Int {
get {
if dim.cout() == 4 {
return dim[2]
} else {
fatalError()
}
}
}
var channel: Int {
get {
if dim.cout() == 4 {
return dim[3]
} else {
fatalError()
}
}
}
func NCHW2NHWC(newPtr: UnsafeMutablePointer<P>) { func NCHW2NHWC(newPtr: UnsafeMutablePointer<P>) {
let N = dim[0] let N = dim[0]
let C = dim[1] let C = dim[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册