未验证 提交 4359398d 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #518 from codeWorm2015/metal

add kernel imp
......@@ -77,6 +77,7 @@ metal/Pods/
SwiftProtobuf.framework
paddle-mobile.xcworkspace
metal/models/
metal/images/
......
......@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key>
<dict>
<key>orderHint</key>
<integer>3</integer>
<integer>4</integer>
</dict>
</dict>
</dict>
......
......@@ -13,19 +13,36 @@
limitations under the License. */
import UIKit
import MetalKit
import paddle_mobile
class ViewController: UIViewController {
let device: MTLDevice! = MTLCreateSystemDefaultDevice()
var textureLoader: MTKTextureLoader!
// let queue: MTLCommandQueue
override func viewDidLoad() {
super.viewDidLoad()
let queue = device.makeCommandQueue()
textureLoader = MTKTextureLoader.init(device: device)
guard let appleImage = UIImage.init(named: "apple.jpg"), let cgImage = appleImage.cgImage else {
fatalError(" image nil !")
}
let texture = try? textureLoader.newTexture(cgImage: cgImage, options: [:]) ?! " texture loader error"
guard let inTexture = texture else {
fatalError(" texture is nil !")
}
let loader = Loader<Float>.init()
do {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(modelPath: modelPath, paraPath: paraPath)
let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(inProgram: program)
let output = try executor.predict(input: Texture.init())
let output = try executor.predict(input: inTexture, expect: [1, 224, 224, 3])
print(output)
} catch let error {
print(error)
......
......@@ -30,11 +30,15 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; };
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
......@@ -65,11 +69,15 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; };
FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Kernel.swift"; sourceTree = SOURCE_ROOT; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
......@@ -140,6 +148,7 @@
FC039B9420E11C9A0081E9F8 /* Extensions.swift */,
FC039B9520E11C9A0081E9F8 /* Errors.swift */,
FC039B9620E11C9A0081E9F8 /* Types.swift */,
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */,
);
path = Common;
sourceTree = "<group>";
......@@ -157,8 +166,8 @@
FC039BA320E11CBC0081E9F8 /* Operators */ = {
isa = PBXGroup;
children = (
FC086BA520E67E8500D85EF7 /* Kernels */,
FCD592FA20E248EC00252966 /* Base */,
FCD592F920E248EC00252966 /* Kernels */,
FC039BA420E11CBC0081E9F8 /* ConvOp.swift */,
FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
......@@ -185,9 +194,12 @@
path = Program;
sourceTree = "<group>";
};
FCD592F920E248EC00252966 /* Kernels */ = {
FC086BA520E67E8500D85EF7 /* Kernels */ = {
isa = PBXGroup;
children = (
FCF2D73720E64E70007AC5F5 /* Kernel.swift */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */,
);
path = Kernels;
sourceTree = "<group>";
......@@ -308,6 +320,10 @@
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */,
......@@ -461,17 +477,19 @@
CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic;
DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = Z5M2UUN5YV;
DEVELOPMENT_TEAM = A798K58VVL;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
......@@ -487,17 +505,19 @@
CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic;
DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = Z5M2UUN5YV;
DEVELOPMENT_TEAM = A798K58VVL;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
......
......@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key>
<dict>
<key>orderHint</key>
<integer>4</integer>
<integer>3</integer>
</dict>
</dict>
</dict>
......
......@@ -20,4 +20,5 @@ public enum PaddleMobileError: Error{
case memoryError(message: String)
case paramError(message: String)
case opError(message: String)
case predictError(message: String)
}
//
// MetalExtension.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/2.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
fileprivate var defaultMetalLibrary: MTLLibrary?
fileprivate var paddleMobileMetalLibrary: MTLLibrary?
extension MTLDevice {
func defaultLibrary() -> MTLLibrary {
if defaultMetalLibrary == nil {
defaultMetalLibrary = makeDefaultLibrary()
}
if let inDefaultLib = defaultMetalLibrary {
return inDefaultLib
} else {
fatalError(" default metal libary is nil")
}
}
func paddleMobileLibrary() -> MTLLibrary {
if paddleMobileMetalLibrary == nil {
guard let path = Bundle.init(for: Kernel.self).path(forResource: "default", ofType: "metallib") else {
fatalError("Counld't find paddle mobile library")
}
do {
paddleMobileMetalLibrary = try makeLibrary(filepath: path)
} catch _ {
fatalError("Counld't load paddle mobile library")
}
paddleMobileMetalLibrary = makeDefaultLibrary()
}
if let inPaddleMobileLib = paddleMobileMetalLibrary {
return inPaddleMobileLib
} else {
fatalError("PaddleMobile metal libary is nil")
}
}
func pipeLine(funcName: String, inPaddleMobileLib: Bool = true) -> MTLComputePipelineState {
let useLib = inPaddleMobileLib ? paddleMobileLibrary() : defaultLibrary()
guard let function = useLib.makeFunction(name: funcName) else {
fatalError(" function " + funcName + " not found")
}
do {
let pipLine = try makeComputePipelineState(function: function)
return pipLine
} catch _ {
fatalError("make pip line error occured")
}
}
}
extension MTLComputeCommandEncoder {
func dispatch(computePipline: MTLComputePipelineState, outTexture: MTLTexture) {
let slices = (outTexture.depth + 3)/4
let width = computePipline.threadExecutionWidth
let height = computePipline.maxTotalThreadsPerThreadgroup/width
let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
let groupWidth = (outTexture.width + width - 1)/width
let groupHeight = (outTexture.height + height - 1)/height
let groupDepth = slices
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
setComputePipelineState(computePipline)
dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup)
}
}
......@@ -40,3 +40,10 @@ extension Texture: Variant {
extension ResultHolder: Variant {
}
extension InputTexture: Variant {
}
extension MTLTexture where Self: Variant {
}
......@@ -64,17 +64,23 @@ public class Executor<P: PrecisionType> {
}
}
public func predict(input: Texture) throws -> ResultHolder<P> {
program.scope[program.feedKey] = input
public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder<P> {
let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect))
program.scope.setInput(input: inputTexture)
for op in ops {
op.run()
}
let outputVar = program.scope[program.fetchKey]
guard let outputVar = program.scope.output() else {
throw PaddleMobileError.netError(message: "output nil")
}
guard let output = outputVar as? ResultHolder<P> else {
throw PaddleMobileError.netError(message: "output var type error")
}
return output
}
}
//public let paddle_executor: Executor = Executor.init()
......@@ -81,7 +81,7 @@ public class Loader<P: PrecisionType> {
}
}
public init(){}
public func load(modelPath: String, paraPath: String) throws -> Program{
public func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program{
guard let modelData = try? Data.init(contentsOf: URL.init(fileURLWithPath: modelPath)) else {
throw PaddleMobileError.loaderError(message: "load " + modelPath + " failed !")
}
......@@ -89,7 +89,6 @@ public class Loader<P: PrecisionType> {
do {
let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
serializedData: modelData)
let scope = Scope.init()
let programDesc = ProgramDesc.init(protoProgram: protoProgram)
guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else {
......@@ -116,6 +115,8 @@ public class Loader<P: PrecisionType> {
throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
}
let scope = Scope.init(inFeedKey: feedKey, inFetchKey: fetchKey)
// to load memory
for block in programDesc.blocks {
for varDesc in block.vars {
......@@ -148,19 +149,18 @@ public class Loader<P: PrecisionType> {
scope[varDesc.name] = tensor
} else {
let dim = Dim.init(inDim: tensorDesc.NHWCDim)
scope[varDesc.name] = Texture.init(inDim: dim, inLayout: .NHWC)
scope[varDesc.name] = Texture.init(device: device, inDim: dim)
}
} else {
if varDesc.name == fetchKey {
scope[varDesc.name] = ResultHolder<P>.init(inDim: [], inResult: [])
} else if varDesc.name == feedKey {
scope[varDesc.name] = Texture.init()
}
}
}
}
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey)
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope)
return program
} catch _ {
......
......@@ -23,11 +23,11 @@ import Foundation
protocol OpParam {
associatedtype OutputType: Variant
var output: OutputType { get }
var output: OutputType { get set }
func outputDesc() -> String
associatedtype ParamPrecisionType: PrecisionType
init(opDesc: OpDesc, scope: Scope) throws
init(opDesc: OpDesc, inScope: Scope) throws
static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType
static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
......
......@@ -82,7 +82,7 @@ class Operator <ParameterType: OpParam>: OperatorProtocol{
attrs = opDesc.attrs
paraInputs = opDesc.paraInputs
do {
para = try ParamType.init(opDesc:opDesc, scope: inScope)
para = try ParamType.init(opDesc:opDesc, inScope: inScope)
} catch let error {
throw error
}
......
......@@ -16,14 +16,14 @@ import Foundation
struct BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope)
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope)
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: scope)
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
is_test = try BatchNormParam.getAttr(key: "is_test", attrs: opDesc.attrs)
......@@ -32,7 +32,7 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
}
}
let input: Texture
let output: Texture
var output: Texture
let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType>
......
......@@ -16,11 +16,11 @@ import Foundation
struct ConvParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
init(opDesc: OpDesc, inScope: Scope) throws {
do {
filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: scope)
input = try ConvParam.input(inputs: opDesc.inputs, from: scope)
output = try ConvParam.output(outputs: opDesc.outputs, from: scope)
filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope)
input = try ConvParam.input(inputs: opDesc.inputs, from: inScope)
output = try ConvParam.output(outputs: opDesc.outputs, from: inScope)
stride = try ConvParam.getAttr(key: "strides", attrs: opDesc.attrs)
paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs)
......@@ -31,7 +31,7 @@ struct ConvParam<P: PrecisionType>: OpParam {
}
let input: Texture
let output: Texture
var output: Texture
let filter: Tensor<ParamPrecisionType>
let stride: [Int32]
let paddings: [Int32]
......
......@@ -16,11 +16,11 @@ import Foundation
struct ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope)
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: inScope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
throw error
......@@ -28,7 +28,7 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
}
let input: Texture
let inputY: Tensor<P>
let output: Texture
var output: Texture
let axis: Int
}
......
......@@ -16,12 +16,15 @@ import Foundation
struct FeedParam<P: PrecisionType>: OpParam{
var output: Texture
var input: Texture
var input: InputTexture {
return scope.input() as! InputTexture
}
let scope: Scope
init(opDesc: OpDesc, scope: Scope) throws {
init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
do {
input = try FeedParam.inputX(inputs: opDesc.inputs, from: scope)
output = try FeedParam.outputOut(outputs: opDesc.outputs, from: scope)
output = try FeedParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
......@@ -34,10 +37,15 @@ class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferS
typealias OpType = FeedOp<P>
func inferShape() {
// print("feed input: \(para.input.expectDim)")
print("feed output: \(para.output.dim)")
// para.ou/tput.dim = para.input.expectDim
}
func runImpl() {
print("feed op")
// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>)
}
}
......@@ -15,13 +15,13 @@
import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
let output: ResultHolder<P>
var output: ResultHolder<P> = ResultHolder.init(inDim: [], inResult: [])
let input: Texture
init(opDesc: OpDesc, scope: Scope) throws {
let scope: Scope
init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
do {
input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope)
output = try FetchParam.outputOut(outputs: opDesc.outputs, from: scope)
input = try FetchParam.inputX(inputs: opDesc.inputs, from: inScope)
} catch let error {
throw error
}
......@@ -32,6 +32,7 @@ struct FetchParam<P: PrecisionType>: OpParam{
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
print(para.input.dim)
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Metal
import Foundation
protocol Computable {
associatedtype ParamType
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
}
class Kernel {
let pipline: MTLComputePipelineState
let functionName: String
init(device: MTLDevice, inFunctionName: String) {
pipline = device.pipeLine(funcName: inFunctionName)
functionName = inFunctionName
}
}
//
// Kernels.metal
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
#include <metal_stdlib>
using namespace metal;
struct OutputDim {
ushort width;
ushort height;
ushort strideX;
ushort strideY;
};
kernel void resize(
texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
return;
}
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z, 0.0h), gid);
}
//
// ResizeKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
struct ResizeParam {
let input: MTLTexture
let output: MTLTexture
let expectDim: Dim
}
struct OutputDim {
let width: UInt16
let height: UInt16
let strideX: UInt16
let strideY: UInt16
}
class ResizeKernel: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1)
let strideX = param.input.width/param.expectDim[2]
let strideY = param.input.height/param.expectDim[1]
var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
encoder.setBytes(&outputDim, length: MemoryLayout<OutputDim>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding()
}
init(device: MTLDevice) {
super.init(device: device, inFunctionName: "resize")
}
}
......@@ -16,16 +16,16 @@ import Foundation
struct ReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try ReluParam.inputX(inputs: opDesc.inputs, from: scope)
output = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope)
input = try ReluParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ReluParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture
let output: Texture
var output: Texture
}
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{
......
......@@ -16,15 +16,11 @@ import Foundation
public struct Program {
let paramPath: String
let feedKey: String
let fetchKey: String
let programDesc: ProgramDesc
let scope: Scope
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) {
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) {
programDesc = ProgramDesc.init(protoProgram: protoProgramDesc)
paramPath = inParamPath
scope = inScope
feedKey = inFeedKey
fetchKey = inFetchKey
}
}
......@@ -15,6 +15,29 @@
import Foundation
class Scope {
let feedKey: String
let fetchKey: String
func setInput(input: Variant) {
vars[feedKey] = input
}
func setOutput(output: Variant) {
vars[fetchKey] = output
}
func input() -> Variant? {
return vars[feedKey];
}
func output() -> Variant? {
return vars[fetchKey];
}
init(inFeedKey: String, inFetchKey: String) {
feedKey = inFeedKey
fetchKey = inFetchKey
}
var vars: [String : Variant] = [:]
subscript(key: String) -> Variant?{
get {
......
......@@ -15,7 +15,7 @@
import Foundation
public struct Dim {
init(inDim: [Int]) {
public init(inDim: [Int]) {
dims = inDim
}
......
......@@ -18,10 +18,8 @@ protocol Tensorial: CustomStringConvertible, CustomDebugStringConvertible{
var dim: Dim { get set }
func numel() -> Int
var layout: DataLayout { get }
init(inDim: Dim, inLayout: DataLayout)
}
extension Tensorial {
func numel() -> Int {
return dim.numel()
......
......@@ -15,29 +15,96 @@
import Metal
import Foundation
class InputTexture {
let mtlTexture: MTLTexture
let expectDim: Dim
init(inMTLTexture: MTLTexture, inExpectDim: Dim) {
mtlTexture = inMTLTexture
expectDim = inExpectDim
}
}
extension InputTexture {
var description: String {
get{
return mtlTexture.description
}
}
var debugDescription: String {
get {
return mtlTexture.debugDescription ?? " MetalTexture "
}
}
}
public class Texture: Tensorial {
var dim: Dim
let textureDesc: MTLTextureDescriptor
var metalTexture: MTLTexture
required public init(inDim: Dim, inLayout: DataLayout = .NHWC) {
init(device: MTLDevice, inDim: Dim, inLayout: DataLayout = .NHWC) {
dim = inDim
layout = inLayout
let tmpTextureDes = MTLTextureDescriptor.init()
if inDim.cout() == 1 {
tmpTextureDes.width = inDim[0]
tmpTextureDes.textureType = .type1D
} else if inDim.cout() == 4 {
tmpTextureDes.height = inDim[1]
tmpTextureDes.width = inDim[2]
// print("n : \(inDim[0])")
// print(inDim[3] * inDim[0])
tmpTextureDes.depth = 1
tmpTextureDes.arrayLength = (inDim[3] * inDim[0] + 3)/4
tmpTextureDes.textureType = .type2DArray
} else {
fatalError(" didn't support yet")
}
tmpTextureDes.pixelFormat = .r32Float
tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
}
private(set) var layout: DataLayout
// let texture: MTLTexture
public init(inTexture: MTLTexture, inDim: Dim) {
// texture = inTexture
dim = inDim
layout = .NHWC
}
// required public init(inDim: Dim, inLayout: DataLayout = .NHWC, inTexture: MTLTexture) {
// dim = inDim
// layout = inLayout
// metalTexture = inTexture
// let tmpTextureDes = MTLTextureDescriptor.init()
//
// if inDim.cout() == 1 {
// tmpTextureDes.width = inDim[0]
// tmpTextureDes.textureType = .type1D
// } else if inDim.cout() == 2 {
// tmpTextureDes.height = inDim[0]
// tmpTextureDes.width = inDim[1]
// tmpTextureDes.textureType = .type2D
// } else if inDim.cout() == 3 {
// fatalError(" not support texture dim 3")
// } else if inDim.cout() == 4 {
// tmpTextureDes.height = inDim[1]
// tmpTextureDes.width = inDim[2]
// tmpTextureDes.depth = inDim[3] * inDim[1]
// tmpTextureDes.textureType = .type2DArray
// }
//
// tmpTextureDes.pixelFormat = .r32Float
// tmpTextureDes.storageMode = .shared
// textureDesc = tmpTextureDes
// let device = MTLCreateSystemDefaultDevice()
// metalTexture = device!.makeTexture(descriptor: tmpTextureDes)!
// }
public init(inLayout: DataLayout = .NHWC) {
dim = Dim.init(inDim: [])
layout = inLayout
}
// init() {
// dim = Dim.init(inDim: [])
// layout = .NCHW
// let device = MTLCreateSystemDefaultDevice()
// textureDesc = MTLTextureDescriptor.init()
// metalTexture = device!.makeTexture(descriptor: textureDesc)!
// }
private(set) var layout: DataLayout
}
extension Texture {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册