未验证 提交 58c804e4 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #608 from dolphin8/metal

reshape & softmax
...@@ -70,7 +70,7 @@ kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0) ...@@ -70,7 +70,7 @@ kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
const half4 input = inTexture.read(gid.xy, gid.z); const half4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z]; half4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(input, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
//kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]], //kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[texture(0)]],
......
...@@ -15,11 +15,30 @@ ...@@ -15,11 +15,30 @@
import Foundation import Foundation
class PoolKernel<P: PrecisionType>: Kernel, Computable{ class PoolKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: PoolParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: PoolParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
print("Pool compute")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(UnsafeRawPointer(param.ksize), length: param.ksize.count * 4, index: 0)
encoder.setBytes(UnsafeRawPointer(param.stride), length: param.stride.count * 4, index: 1)
encoder.setBytes(UnsafeRawPointer(param.padding), length: param.padding.count * 4, index: 2)
var poolType: Int32
switch param.poolType {
case "max":
poolType = 0
case "avg":
poolType = 1
default:
throw PaddleMobileError.predictError(message: " unknown pooltype " + param.poolType)
}
encoder.setBytes(&poolType, length: 4, index: 3)
encoder.endEncoding()
} }
required init(device: MTLDevice, param: PoolParam<P>) { required init(device: MTLDevice, param: PoolParam<P>) {
super.init(device: device, inFunctionName: "relu") super.init(device: device, inFunctionName: "pool")
} }
} }
...@@ -16,11 +16,16 @@ import Foundation ...@@ -16,11 +16,16 @@ import Foundation
class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: ReshapeParam<P>) { required init(device: MTLDevice, param: ReshapeParam<P>) {
super.init(device: device, inFunctionName: "relu") super.init(device: device, inFunctionName: "reshape")
} }
func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
print("Reshape compute")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.endEncoding()
} }
} }
...@@ -17,9 +17,16 @@ import Foundation ...@@ -17,9 +17,16 @@ import Foundation
class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: SoftmaxParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: SoftmaxParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
print("softmax compute")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.endEncoding()
} }
required init(device: MTLDevice, param: SoftmaxParam<P>) { required init(device: MTLDevice, param: SoftmaxParam<P>) {
super.init(device: device, inFunctionName: "relu") super.init(device: device, inFunctionName: "softmax")
} }
} }
...@@ -20,18 +20,31 @@ class PoolParam<P: PrecisionType>: OpParam { ...@@ -20,18 +20,31 @@ class PoolParam<P: PrecisionType>: OpParam {
do { do {
input = try PoolParam.inputX(inputs: opDesc.inputs, from: inScope) input = try PoolParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try PoolParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try PoolParam.outputOut(outputs: opDesc.outputs, from: inScope)
poolType = try PoolParam.getAttr(key: "pooling_type", attrs: opDesc.attrs)
ksize = try PoolParam.getAttr(key: "ksize", attrs: opDesc.attrs)
stride = try PoolParam.getAttr(key: "strides", attrs: opDesc.attrs)
padding = try PoolParam.getAttr(key: "paddings", attrs: opDesc.attrs)
ceilMode = try PoolParam.getAttr(key: "ceil_mode", attrs: opDesc.attrs)
globalPooling = try PoolParam.getAttr(key: "global_pooling", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
// let buffer = input.metalTexture.buffer.contents().assumingMemoryBound(to: P.self)
} }
let input: Texture<P> let input: Texture<P>
var output: Texture<P> var output: Texture<P>
var ksize: [Int32]
var stride: [Int32]
var padding: [Int32]
var poolType: String
var ceilMode: Bool
var globalPooling: Bool
} }
class PoolOp<P: PrecisionType>: Operator<PoolKernel<P>, PoolParam<P>>, Runable, Creator, InferShaperable{ class PoolOp<P: PrecisionType>: Operator<PoolKernel<P>, PoolParam<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
para.output.dim = para.input.dim // para.output.dim = para.input.dim
} }
typealias OpType = PoolOp<P> typealias OpType = PoolOp<P>
...@@ -42,4 +55,14 @@ class PoolOp<P: PrecisionType>: Operator<PoolKernel<P>, PoolParam<P>>, Runable, ...@@ -42,4 +55,14 @@ class PoolOp<P: PrecisionType>: Operator<PoolKernel<P>, PoolParam<P>>, Runable,
throw error throw error
} }
} }
func delogOutput() {
print("pool2d delog")
let _: P? = para.input.metalTexture.logDesc(header: "pool2d input: ", stridable: false)
print(para.ksize)
print(para.stride)
print(para.padding)
print(para.poolType)
let _: P? = para.output.metalTexture.logDesc(header: "pool2d output: ", stridable: false)
}
} }
...@@ -31,7 +31,7 @@ class ReshapeParam<P: PrecisionType>: OpParam { ...@@ -31,7 +31,7 @@ class ReshapeParam<P: PrecisionType>: OpParam {
class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable{ class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
para.output.dim = para.input.dim // para.output.dim = para.input.dim
} }
typealias OpType = ReshapeOp<P> typealias OpType = ReshapeOp<P>
......
...@@ -19,9 +19,9 @@ limitations under the License. */ ...@@ -19,9 +19,9 @@ limitations under the License. */
int main() { int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader; paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto time1 = time(); auto time1 = time();
// auto program = loader.Load(g_mobilenet_combine, true); // auto program = loader.Load(g_mobilenet_combine, true);
auto program = loader.Load(g_mobilenet_combine + "/model", auto program = loader.Load(g_mobilenet_combine + "/model",
g_mobilenet_combine + "/params", true); g_mobilenet_combine + "/params", true);
auto time2 = time(); auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time1) << "ms"; DLOG << "load cost :" << time_diff(time1, time1) << "ms";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册