未验证 提交 1658b359 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #526 from codeWorm2015/metal

add texture 2d to 2d array kernel
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key> <key>paddle-mobile-demo.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>3</integer> <integer>4</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -15,14 +15,41 @@ ...@@ -15,14 +15,41 @@
import UIKit import UIKit
import MetalKit import MetalKit
import paddle_mobile import paddle_mobile
import MetalPerformanceShaders
func Test<T>() -> T? {
return nil
}
class ViewController: UIViewController { class ViewController: UIViewController {
let device: MTLDevice! = MTLCreateSystemDefaultDevice() let device: MTLDevice! = MTLCreateSystemDefaultDevice()
var textureLoader: MTKTextureLoader! var textureLoader: MTKTextureLoader!
// let queue: MTLCommandQueue // let queue: MTLCommandQueue
func scaleTexture(queue: MTLCommandQueue, input: MTLTexture, complete: @escaping (MTLTexture) -> Void) {
let tmpTextureDes = MTLTextureDescriptor.init()
tmpTextureDes.width = 227
tmpTextureDes.height = 227
tmpTextureDes.depth = 1
tmpTextureDes.usage = [.shaderRead, .shaderWrite]
tmpTextureDes.pixelFormat = .rgba16Float
tmpTextureDes.textureType = .type2D
tmpTextureDes.storageMode = .shared
tmpTextureDes.cpuCacheMode = .defaultCache
let dest = device.makeTexture(descriptor: tmpTextureDes)
let scale = MPSImageLanczosScale.init(device: device)
let buffer = queue.makeCommandBuffer()
scale.encode(commandBuffer: buffer!, sourceTexture: input, destinationTexture: dest!)
buffer?.addCompletedHandler({ (buffer) in
complete(dest!)
})
buffer?.commit()
}
override func viewDidLoad() { override func viewDidLoad() {
super.viewDidLoad() super.viewDidLoad()
let queue = device.makeCommandQueue() let queue = device.makeCommandQueue()
textureLoader = MTKTextureLoader.init(device: device) textureLoader = MTKTextureLoader.init(device: device)
...@@ -35,18 +62,24 @@ class ViewController: UIViewController { ...@@ -35,18 +62,24 @@ class ViewController: UIViewController {
guard let inTexture = texture else { guard let inTexture = texture else {
fatalError(" texture is nil !") fatalError(" texture is nil !")
} }
let loader = Loader<Float16>.init() scaleTexture(queue: queue!, input: inTexture) { (inputTexture) in
do { let loader = Loader<Float16>.init()
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" do {
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath) let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let executor = try Executor<Float16>.init(inDevice: device, inQueue: queue!, inProgram: program) let program = try loader.load(device: self.device, modelPath: modelPath, paraPath: paraPath)
let output = try executor.predict(input: inTexture, expect: [1, 227, 227, 3]) let executor = try Executor<Float16>.init(inDevice: self.device, inQueue: queue!, inProgram: program)
print(output) let output = try executor.predict(input: inputTexture, expect: [1, 227, 227, 3])
} catch let error { // print(output)
print(error) } catch let error {
print(error)
}
} }
} }
} }
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; }; FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; }; FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; }; FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; }; FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; }; FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; }; FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
...@@ -79,6 +80,7 @@ ...@@ -79,6 +80,7 @@
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; }; FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; }; FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; }; FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; }; FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; }; FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; }; FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
...@@ -212,6 +214,7 @@ ...@@ -212,6 +214,7 @@
FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */, FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */,
FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */, FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */,
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */, FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */,
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */,
); );
path = Kernels; path = Kernels;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -356,6 +359,7 @@ ...@@ -356,6 +359,7 @@
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */, FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */,
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */, FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */,
FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */, FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */,
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */, FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */,
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key> <key>paddle-mobile.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>4</integer> <integer>3</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -29,7 +29,6 @@ extension MTLDevice { ...@@ -29,7 +29,6 @@ extension MTLDevice {
fatalError("Counld't find paddle mobile library") fatalError("Counld't find paddle mobile library")
} }
do { do {
print(path)
paddleMobileMetalLibrary = try makeLibrary(filepath: path) paddleMobileMetalLibrary = try makeLibrary(filepath: path)
} catch _ { } catch _ {
fatalError("Counld't load paddle mobile library") fatalError("Counld't load paddle mobile library")
...@@ -61,22 +60,21 @@ extension MTLDevice { ...@@ -61,22 +60,21 @@ extension MTLDevice {
extension MTLComputeCommandEncoder { extension MTLComputeCommandEncoder {
func dispatch(computePipline: MTLComputePipelineState, outTexture: MTLTexture) { func dispatch(computePipline: MTLComputePipelineState, outTexture: MTLTexture) {
let slices = (outTexture.depth + 3)/4 let slices = (outTexture.arrayLength * 4 + 3)/4
let width = computePipline.threadExecutionWidth let width = computePipline.threadExecutionWidth
let height = computePipline.maxTotalThreadsPerThreadgroup/width let height = computePipline.maxTotalThreadsPerThreadgroup/width
let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1) let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
print(" threads per group: \(threadsPerGroup) ") // print(" thread: threads per group: \(threadsPerGroup) ")
// print(" thread: out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
print(" out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
let groupWidth = (outTexture.width + width - 1)/width let groupWidth = (outTexture.width + width - 1)/width
let groupHeight = (outTexture.height + height - 1)/height let groupHeight = (outTexture.height + height - 1)/height
let groupDepth = slices let groupDepth = slices
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth) let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
print("groups: \(groups) ") // print("groups: \(groups) ")
setComputePipelineState(computePipline) setComputePipelineState(computePipline)
dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup) dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup)
...@@ -84,6 +82,60 @@ extension MTLComputeCommandEncoder { ...@@ -84,6 +82,60 @@ extension MTLComputeCommandEncoder {
} }
public extension MTLTexture {
func logDesc<T>(header: String = "", stridable: Bool = true) -> T? {
print(header)
print("texture: \(self)")
if textureType == .type2DArray {
for i in 0..<arrayLength{
var str: String = "slice: \(i): "
let bytes = UnsafeMutableRawPointer.allocate(byteCount: width * height * 4 * MemoryLayout<T>.size, alignment: MemoryLayout<T>.alignment)
let bytesPerRow = width * depth * 4 * MemoryLayout<T>.size
let bytesPerImage = width * height * depth * 4 * MemoryLayout<T>.size
let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: width, height: height, depth: depth))
getBytes(bytes, bytesPerRow: bytesPerRow, bytesPerImage: bytesPerImage, from: region, mipmapLevel: 0, slice: i)
let p = bytes.assumingMemoryBound(to: T.self)
str += "2d array count : \(width * height * depth * 4) \n"
if stridable {
for j in stride(from: 0, to: width * height * depth * 4 , by: width * height * depth * 4 / 100){
str += " \(p[j])"
}
} else {
for j in 0..<width * height * depth * 4 {
str += " \(p[j])"
}
}
bytes.deallocate()
print(str)
}
} else if textureType == .type2D {
var str: String = "texture 2D: "
let bytes = UnsafeMutableRawPointer.allocate(byteCount: width * height * 4 * MemoryLayout<T>.size, alignment: MemoryLayout<T>.alignment)
let bytesPerRow = width * depth * 4 * MemoryLayout<T>.size
let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: width, height: height, depth: depth))
getBytes(bytes, bytesPerRow: bytesPerRow, from: region, mipmapLevel: 0)
let p = bytes.assumingMemoryBound(to: T.self)
str += "2d count : \(width * width * 4) \n"
if stridable {
for j in stride(from: 0, to: width * height * 4, by: width * height * 4 / 100){
str += " \(p[j])"
}
} else {
for j in 0..<width * height * 4 {
str += " \(p[j])"
}
}
print(str)
bytes.deallocate()
}
return nil
}
}
......
...@@ -84,9 +84,14 @@ public class Executor<P: PrecisionType> { ...@@ -84,9 +84,14 @@ public class Executor<P: PrecisionType> {
} }
buffer.addCompletedHandler { (commandbuffer) in buffer.addCompletedHandler { (commandbuffer) in
// for op in self.ops {
// op.delogOutput()
// }
let afterDate = Date.init() let afterDate = Date.init()
print(afterDate.timeIntervalSince(beforeDate)) print(" encoder end ! time: \(afterDate.timeIntervalSince(beforeDate))")
print(" encoder end ! ")
} }
buffer.commit() buffer.commit()
......
...@@ -18,6 +18,7 @@ import Foundation ...@@ -18,6 +18,7 @@ import Foundation
protocol Runable { protocol Runable {
func run(device: MTLDevice, buffer: MTLCommandBuffer) throws func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws
func delogOutput()
} }
extension Runable where Self: OperatorProtocol{ extension Runable where Self: OperatorProtocol{
...@@ -27,8 +28,11 @@ extension Runable where Self: OperatorProtocol{ ...@@ -27,8 +28,11 @@ extension Runable where Self: OperatorProtocol{
} catch let error { } catch let error {
throw error throw error
} }
// print(type + ": " + para.outputDesc())
print(type + ": " + para.outputDesc()) }
func delogOutput() {
print(type + ": has no implementation" )
} }
} }
......
...@@ -48,7 +48,6 @@ class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>, BatchNormKernel ...@@ -48,7 +48,6 @@ class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>, BatchNormKernel
} }
typealias OpType = BatchNormOp<P> typealias OpType = BatchNormOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is BatchNormOp")
} }
} }
......
...@@ -64,6 +64,17 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable, ...@@ -64,6 +64,17 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable,
typealias OpType = ConvOp<P> typealias OpType = ConvOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is conv") do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print("conv output : ")
print(para.output.metalTexture)
// let _: Float16? = para.output.metalTexture.logDesc()
} }
} }
...@@ -40,7 +40,6 @@ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>, Eleme ...@@ -40,7 +40,6 @@ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>, Eleme
typealias OpType = ElementwiseAddOp<P> typealias OpType = ElementwiseAddOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is ElementwiseAddOp")
} }
} }
......
...@@ -33,26 +33,37 @@ struct FeedParam<P: PrecisionType>: OpParam{ ...@@ -33,26 +33,37 @@ struct FeedParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
} }
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable { class FeedOp<P: PrecisionType>: Operator<FeedParam<P>, Texture2DTo2DArrayKernel<P>>, Runable, Creator, InferShaperable {
typealias OpType = FeedOp<P> typealias OpType = FeedOp<P>
func inferShape() { func inferShape() {
// print("feed input: \(para.input.expectDim)") // print("feed input: \(para.input.expectDim)")
print("feed output: \(para.output.dim)") print("feed output: \(para.output.dim)")
// para.output.dim = // para.output.dim =
// para.output.dim = para.input.expectDim // para.output.dim = para.input.expectDim
} }
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
let resizeKernel = ResizeKernel<P>.init(device: device) let locPara = Texture2DTo2DArrayParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim)
let resizeParam = ResizeParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim)
do { do {
print("feed op to compute ") try kernel.compute(commandBuffer: buffer, param: locPara)
try resizeKernel.compute(commandBuffer: buffer, param: resizeParam)
print("feed op end compute ")
} catch let error { } catch let error {
throw error throw error
} }
// let resizeKernel = ResizeKernel<P>.init(device: device)
// let resizeParam = ResizeParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim)
// do {
// try resizeKernel.compute(commandBuffer: buffer, param: resizeParam)
// } catch let error {
// throw error
// }
}
func delogOutput() {
// para.input.mtlTexture.logDesc()
let _: Float16? = para.input.mtlTexture.logDesc(header: "feed input: ")
let _: Float16? = para.output.metalTexture.logDesc(header: "feed output: ")
} }
} }
...@@ -32,13 +32,11 @@ struct FetchParam<P: PrecisionType>: OpParam{ ...@@ -32,13 +32,11 @@ struct FetchParam<P: PrecisionType>: OpParam{
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable{ class FetchOp<P: PrecisionType>: Operator<FetchParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
print(para.input.dim) print(para.input.dim)
} }
typealias OpType = FetchOp<P> typealias OpType = FetchOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("fetch op")
} }
} }
...@@ -11,10 +11,18 @@ import Foundation ...@@ -11,10 +11,18 @@ import Foundation
class ConvKernel<P: PrecisionType>: Kernel, Computable { class ConvKernel<P: PrecisionType>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
} }
required init(device: MTLDevice) { required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv") super.init(device: device, inFunctionName: "conv")
} }
} }
...@@ -36,7 +36,6 @@ kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]], ...@@ -36,7 +36,6 @@ kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z); outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z);
} }
kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]], kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
...@@ -49,7 +48,6 @@ kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]], ...@@ -49,7 +48,6 @@ kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]],
outTexture.write(half4(relu), gid.xy, gid.z); outTexture.write(half4(relu), gid.xy, gid.z);
} }
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]], kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]], const device half4 *biasTerms [[buffer(0)]],
...@@ -62,10 +60,8 @@ kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[text ...@@ -62,10 +60,8 @@ kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[text
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]], kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
...@@ -75,17 +71,27 @@ kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]], ...@@ -75,17 +71,27 @@ kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]], kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z); const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height()){
return;
}
const half4 input = inTexture.read(gid.xy);
outTexture.write(input, gid.xy, 0);
}
...@@ -19,7 +19,6 @@ class ReluKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -19,7 +19,6 @@ class ReluKernel<P: PrecisionType>: Kernel, Computable{
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil") throw PaddleMobileError.predictError(message: " encode is nil")
} }
print(" the usage of input of relu \(param.input.metalTexture.usage)")
encoder.setTexture(param.input.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1) encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
import Foundation import Foundation
import MetalPerformanceShaders
struct ResizeParam { struct ResizeParam {
...@@ -29,23 +30,29 @@ struct OutputDim { ...@@ -29,23 +30,29 @@ struct OutputDim {
} }
class ResizeKernel<P: PrecisionType>: Kernel, Computable{ class ResizeKernel<P: PrecisionType>: Kernel, Computable{
var lanczos: MPSImageLanczosScale
required init(device: MTLDevice) {
lanczos = MPSImageLanczosScale.init(device: device)
super.init(device: device, inFunctionName: "resize")
}
func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws { func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { // guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil") // throw PaddleMobileError.predictError(message: " encode is nil")
} // }
lanczos.encode(commandBuffer: commandBuffer, sourceTexture: param.input, destinationTexture: param.output)
encoder.setTexture(param.input, index: 0) // encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1) // encoder.setTexture(param.output, index: 1)
let strideX = param.input.width/param.expectDim[2] // let strideX = param.input.width/param.expectDim[2]
let strideY = param.input.height/param.expectDim[1] // let strideY = param.input.height/param.expectDim[1]
var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY)) // var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
encoder.setBytes(&outputDim, length: MemoryLayout<OutputDim>.size, index: 0) // encoder.setBytes(&outputDim, length: MemoryLayout<OutputDim>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output) // encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding() // encoder.endEncoding()
} }
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "resize")
}
} }
//
// Texture2DTo2DArrayKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/6.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
struct Texture2DTo2DArrayParam {
let input: MTLTexture
let output: MTLTexture
let expectDim: Dim
}
class Texture2DTo2DArrayKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: Texture2DTo2DArrayParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input)
encoder.endEncoding()
}
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "texture2d_to_2d_array")
}
}
...@@ -62,14 +62,16 @@ public class Texture<P: PrecisionType>: Tensorial { ...@@ -62,14 +62,16 @@ public class Texture<P: PrecisionType>: Tensorial {
fatalError(" didn't support yet") fatalError(" didn't support yet")
} }
if MemoryLayout<P>.size == 1 { if MemoryLayout<P>.size == 1 {
tmpTextureDes.pixelFormat = .r8Sint tmpTextureDes.pixelFormat = .rgba8Unorm
} else if MemoryLayout<P>.size == 2 { } else if MemoryLayout<P>.size == 2 {
tmpTextureDes.pixelFormat = .r16Float tmpTextureDes.pixelFormat = .rgba16Float
} else if MemoryLayout<P>.size == 4 { } else if MemoryLayout<P>.size == 4 {
tmpTextureDes.pixelFormat = .r32Float // tmpTextureDes.pixelFormat = .r32Float
tmpTextureDes.pixelFormat = .rgba32Float
} }
tmpTextureDes.usage = .unknown tmpTextureDes.usage = [.shaderRead, .shaderWrite]
tmpTextureDes.storageMode = .shared tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil " metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
...@@ -123,6 +125,7 @@ extension Texture { ...@@ -123,6 +125,7 @@ extension Texture {
public var debugDescription: String{ public var debugDescription: String{
var str = "" var str = ""
str += "Dim: \(dim) \n value:[ " str += "Dim: \(dim) \n value:[ "
// str += "\(metalTexture)"
str += " ]" str += " ]"
return str return str
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册