未验证 提交 2919d335 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #521 from codeWorm2015/metal

 add relu kernel
......@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key>
<dict>
<key>orderHint</key>
<integer>4</integer>
<integer>3</integer>
</dict>
</dict>
</dict>
......
......@@ -36,13 +36,13 @@ class ViewController: UIViewController {
fatalError(" texture is nil !")
}
let loader = Loader<Float>.init()
let loader = Loader<Float16>.init()
do {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(inProgram: program)
let output = try executor.predict(input: inTexture, expect: [1, 224, 224, 3])
let executor = try Executor<Float16>.init(inDevice: device, inQueue: queue!, inProgram: program)
let output = try executor.predict(input: inTexture, expect: [1, 227, 227, 3])
print(output)
} catch let error {
print(error)
......
......@@ -30,6 +30,10 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */; };
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */; };
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */; };
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
......@@ -69,6 +73,10 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReluKernel.swift; sourceTree = "<group>"; };
FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvKernel.swift; sourceTree = "<group>"; };
FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BatchNormKernel.swift; sourceTree = "<group>"; };
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
......@@ -197,9 +205,13 @@
FC086BA520E67E8500D85EF7 /* Kernels */ = {
isa = PBXGroup;
children = (
FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */,
FCF2D73720E64E70007AC5F5 /* Kernel.swift */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */,
FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */,
FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */,
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */,
);
path = Kernels;
sourceTree = "<group>";
......@@ -316,12 +328,14 @@
files = (
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
......@@ -335,7 +349,9 @@
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */,
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
......
......@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key>
<dict>
<key>orderHint</key>
<integer>3</integer>
<integer>4</integer>
</dict>
</dict>
</dict>
......
......@@ -29,11 +29,11 @@ extension MTLDevice {
fatalError("Counld't find paddle mobile library")
}
do {
print(path)
paddleMobileMetalLibrary = try makeLibrary(filepath: path)
} catch _ {
fatalError("Counld't load paddle mobile library")
}
paddleMobileMetalLibrary = makeDefaultLibrary()
}
if let inPaddleMobileLib = paddleMobileMetalLibrary {
......@@ -67,11 +67,17 @@ extension MTLComputeCommandEncoder {
let height = computePipline.maxTotalThreadsPerThreadgroup/width
let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
print(" threads per group: \(threadsPerGroup) ")
print(" out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
let groupWidth = (outTexture.width + width - 1)/width
let groupHeight = (outTexture.height + height - 1)/height
let groupDepth = slices
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
print("groups: \(groups) ")
setComputePipelineState(computePipline)
dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup)
}
......
......@@ -14,14 +14,22 @@
import Foundation
//typealias Float16 = Int16
//extension Float16: PrecisionType {
//}
public typealias Float16 = Int16
extension Float16: PrecisionType {
public init(inFloat: Float32) {
self = Int16(inFloat)
}
}
public protocol PrecisionType {
init(inFloat: Float32)
}
extension Float32: PrecisionType {
public init(inFloat: Float32) {
self = inFloat
}
}
public enum DataLayout {
......
......@@ -48,13 +48,16 @@ extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible {
public class Executor<P: PrecisionType> {
var ops: [Runable & InferShaperable] = []
let program: Program
public init(inProgram: Program) throws {
let device: MTLDevice
let queue: MTLCommandQueue
public init(inDevice:MTLDevice, inQueue: MTLCommandQueue, inProgram: Program) throws {
program = inProgram
device = inDevice
queue = inQueue
for block in inProgram.programDesc.blocks {
for op in block.ops {
do {
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: inProgram.scope)
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
op.inferShape()
ops.append(op)
} catch let error {
......@@ -65,12 +68,29 @@ public class Executor<P: PrecisionType> {
}
public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder<P> {
let beforeDate = Date.init()
let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect))
program.scope.setInput(input: inputTexture)
guard let buffer = queue.makeCommandBuffer() else {
throw PaddleMobileError.predictError(message: "CommandBuffer is nil")
}
for op in ops {
op.run()
do {
try op.run(device: device, buffer: buffer)
} catch let error {
throw error
}
}
buffer.addCompletedHandler { (commandbuffer) in
let afterDate = Date.init()
print(afterDate.timeIntervalSince(beforeDate))
print(" encoder end ! ")
}
buffer.commit()
guard let outputVar = program.scope.output() else {
throw PaddleMobileError.netError(message: "output nil")
}
......@@ -78,6 +98,8 @@ public class Executor<P: PrecisionType> {
guard let output = outputVar as? ResultHolder<P> else {
throw PaddleMobileError.netError(message: "output var type error")
}
return output
}
......
......@@ -68,11 +68,24 @@ public class Loader<P: PrecisionType> {
/*
这里没有根据 Data Type 去判断, 而是从外部泛型直接指定了精度
*/
let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
guard bytesRead == tensor.data.size else {
throw PaddleMobileError.loaderError(message: "param read size error")
//现在模型传入模型为 Float 类型, 这块应该根据模型来
let tmpCapacity = MemoryLayout<Float>.size * tensor.numel()
let tmpPointer = UnsafeMutablePointer<Float>.allocate(capacity: tmpCapacity);
// let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
// guard bytesRead == tensor.data.size else {
// throw PaddleMobileError.loaderError(message: "param read size error")
// }
// TODO: use script to convert
let bytesRead = fread(tmpPointer, 1, tmpCapacity, file)
for i in 0..<tensor.numel() {
tensor.data[i] = P.init(inFloat: tmpPointer[i])
}
tmpPointer.deinitialize(count: tmpCapacity)
tmpPointer.deallocate()
nowIndex += bytesRead
}
......@@ -125,9 +138,9 @@ public class Loader<P: PrecisionType> {
throw PaddleMobileError.loaderError(message: "get tensor desc failed")
}
guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
throw PaddleMobileError.memoryError(message: "PrecisionType not support")
}
// guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
// throw PaddleMobileError.memoryError(message: "PrecisionType not support")
// }
if (varDesc.persistable
&& varDesc.type != .FeedMiniBatch
......@@ -149,7 +162,7 @@ public class Loader<P: PrecisionType> {
scope[varDesc.name] = tensor
} else {
let dim = Dim.init(inDim: tensorDesc.NHWCDim)
scope[varDesc.name] = Texture.init(device: device, inDim: dim)
scope[varDesc.name] = Texture<P>.init(device: device, inDim: dim)
}
} else {
if varDesc.name == fetchKey {
......
......@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> {
}
}
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
func creat(device: MTLDevice, opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
}
do {
return try opCreator(opDesc, scope)
return try opCreator(device, opDesc, scope)
} catch let error {
throw error
}
}
let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] =
let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] =
[gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat,
......
......@@ -12,29 +12,35 @@
See the License for the specific language governing permissions and
limitations under the License. */
import Metal
import Foundation
protocol Runable {
func run()
func runImpl()
func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws
}
extension Runable where Self: OperatorProtocol{
func run() {
runImpl()
func run(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try runImpl(device: device, buffer: buffer)
} catch let error {
throw error
}
print(type + ": " + para.outputDesc())
}
}
protocol Creator where Self: OperatorProtocol{
associatedtype OpType: OperatorProtocol & Runable & InferShaperable
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType
}
extension Creator where Self: OperatorProtocol {
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType {
static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType {
do {
return try OpType.provide(opDesc: opDesc, inScope: inScope)
return try OpType.provide(device:device, opDesc: opDesc, inScope: inScope)
} catch let error {
throw error
}
......@@ -47,19 +53,21 @@ protocol InferShaperable {
protocol OperatorProtocol {
associatedtype ParamType: OpParam
associatedtype KerType: Computable
var type: String { get }
var inputs: [String : [String]] { get }
var paraInputs: [String : [String]] { get }
var outpus: [String : [String]] { get }
var attrs: [String : Attr] { get }
var para: ParamType { get }
init(opDesc: OpDesc, inScope: Scope) throws
var kernel: KerType { get }
init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws
}
extension OperatorProtocol {
static func provide(opDesc: OpDesc, inScope: Scope) throws -> Self {
static func provide(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> Self {
do {
return try Self.init(opDesc: opDesc, inScope: inScope)
return try Self.init(device: device, opDesc: opDesc, inScope: inScope)
} catch let error {
throw error
}
......@@ -67,20 +75,23 @@ extension OperatorProtocol {
}
class Operator <ParameterType: OpParam>: OperatorProtocol{
typealias ParamType = ParameterType
class Operator <ParameterType: OpParam, KernelType: Computable>: OperatorProtocol{
typealias ParamType = ParameterType
typealias KerType = KernelType
let type: String
let inputs: [String : [String]]
let paraInputs: [String : [String]]
let outpus: [String : [String]]
let attrs: [String : Attr]
let para: ParamType
required init(opDesc: OpDesc, inScope: Scope) throws {
var kernel: KerType
required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
type = opDesc.type
inputs = opDesc.inputs
outpus = opDesc.outputs
attrs = opDesc.attrs
paraInputs = opDesc.paraInputs
kernel = KerType.init(device: device)
do {
para = try ParamType.init(opDesc:opDesc, inScope: inScope)
} catch let error {
......
......@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
throw error
}
}
let input: Texture
var output: Texture
let input: Texture<P>
var output: Texture<P>
let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType>
......@@ -42,12 +42,12 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
let is_test: Bool
}
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator, InferShaperable{
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>, BatchNormKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = BatchNormOp<P>
func runImpl() {
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is BatchNormOp")
}
}
......
......@@ -30,8 +30,8 @@ struct ConvParam<P: PrecisionType>: OpParam {
}
}
let input: Texture
var output: Texture
let input: Texture<P>
var output: Texture<P>
let filter: Tensor<ParamPrecisionType>
let stride: [Int32]
let paddings: [Int32]
......@@ -39,7 +39,7 @@ struct ConvParam<P: PrecisionType>: OpParam {
let groups: Int
}
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferShaperable {
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable, Creator, InferShaperable {
func inferShape() {
let inDims = para.input.dim
let filterDim = para.filter.dim
......@@ -63,7 +63,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferS
}
typealias OpType = ConvOp<P>
func runImpl() {
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is conv")
}
}
......@@ -26,20 +26,20 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
throw error
}
}
let input: Texture
let input: Texture<P>
let inputY: Tensor<P>
var output: Texture
var output: Texture<P>
let axis: Int
}
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>, ElementwiseAddKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ElementwiseAddOp<P>
func runImpl() {
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is ElementwiseAddOp")
}
}
......
......@@ -15,7 +15,7 @@
import Foundation
struct FeedParam<P: PrecisionType>: OpParam{
var output: Texture
var output: Texture<P>
var input: InputTexture {
return scope.input() as! InputTexture
}
......@@ -33,19 +33,26 @@ struct FeedParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P
}
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferShaperable {
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable {
typealias OpType = FeedOp<P>
func inferShape() {
// print("feed input: \(para.input.expectDim)")
print("feed output: \(para.output.dim)")
// para.ou/tput.dim = para.input.expectDim
// para.output.dim =
// para.output.dim = para.input.expectDim
}
func runImpl() {
print("feed op")
// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>)
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
let resizeKernel = ResizeKernel<P>.init(device: device)
let resizeParam = ResizeParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim)
do {
print("feed op to compute ")
try resizeKernel.compute(commandBuffer: buffer, param: resizeParam)
print("feed op end compute ")
} catch let error {
throw error
}
}
}
......@@ -16,7 +16,7 @@ import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
var output: ResultHolder<P> = ResultHolder.init(inDim: [], inResult: [])
let input: Texture
let input: Texture<P>
let scope: Scope
init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
......@@ -30,14 +30,14 @@ struct FetchParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P
}
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() {
print(para.input.dim)
}
typealias OpType = FetchOp<P>
func runImpl() {
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("fetch op")
}
}
......
//
// BatchNormKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "batchnorm")
}
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
}
}
//
// ConvKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ConvKernel<P: PrecisionType>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
}
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv")
}
}
//
// ElementwiseAddKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv")
}
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
}
}
......@@ -18,6 +18,12 @@ import Foundation
protocol Computable {
associatedtype ParamType
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
init(device: MTLDevice)
}
protocol KernelProtocol {
var pipline: MTLComputePipelineState { get set }
var functionName: String { get set }
}
class Kernel {
......
//
// Kernels.metal
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
......@@ -16,19 +22,70 @@ struct OutputDim {
ushort strideY;
};
kernel void resize(
texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]) {
kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
return;
}
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z, 0.0h), gid);
outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z);
}
kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float4 relu = fmax((float4)input, 0.0);
outTexture.write(half4(relu), gid.xy, gid.z);
}
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ReluKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ReluParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
print(" the usage of input of relu \(param.input.metalTexture.usage)")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "relu")
}
}
//
// ResizeKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
......@@ -22,15 +28,14 @@ struct OutputDim {
let strideY: UInt16
}
class ResizeKernel: Kernel, Computable{
class ResizeKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1)
encoder.setTexture(param.output, index: 1)
let strideX = param.input.width/param.expectDim[2]
let strideY = param.input.height/param.expectDim[1]
var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
......@@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{
encoder.endEncoding()
}
init(device: MTLDevice) {
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "resize")
}
}
......
......@@ -24,19 +24,23 @@ struct ReluParam<P: PrecisionType>: OpParam {
throw error
}
}
let input: Texture
var output: Texture
let input: Texture<P>
var output: Texture<P>
}
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>, ReluKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ReluOp<P>
func runImpl() {
print("this is ReluOp")
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
}
......
......@@ -38,7 +38,7 @@ extension InputTexture {
}
}
public class Texture: Tensorial {
public class Texture<P: PrecisionType>: Tensorial {
var dim: Dim
let textureDesc: MTLTextureDescriptor
var metalTexture: MTLTexture
......@@ -61,7 +61,15 @@ public class Texture: Tensorial {
} else {
fatalError(" didn't support yet")
}
tmpTextureDes.pixelFormat = .r32Float
if MemoryLayout<P>.size == 1 {
tmpTextureDes.pixelFormat = .r8Sint
} else if MemoryLayout<P>.size == 2 {
tmpTextureDes.pixelFormat = .r16Float
} else if MemoryLayout<P>.size == 4 {
tmpTextureDes.pixelFormat = .r32Float
}
tmpTextureDes.usage = .unknown
tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册