提交 1b7329e2 编写于 作者: Y Yanzhan Yang 提交者: GitHub

add exp, sigmoid and leaky_relu op (#1633)

上级 ef827162
......@@ -8,6 +8,7 @@
/* Begin PBXBuildFile section */
165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; };
16FBFB3E22925D040025B406 /* ActivationKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3D22925D040025B406 /* ActivationKernel.metal */; };
5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; };
A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */; };
FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */; };
......@@ -55,6 +56,7 @@
/* Begin PBXFileReference section */
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = "<group>"; };
16FBFB3D22925D040025B406 /* ActivationKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ActivationKernel.metal; sourceTree = "<group>"; };
33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = "<group>"; };
5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; };
A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ScaleKernel.metal; sourceTree = "<group>"; };
......@@ -196,6 +198,7 @@
FCC15DDA221E69E000DC3CB2 /* TransposeKernel.metal */,
A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */,
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */,
16FBFB3D22925D040025B406 /* ActivationKernel.metal */,
);
path = "paddle-mobile-metallib";
sourceTree = "<group>";
......@@ -299,6 +302,7 @@
FCC15DFB221E69E100DC3CB2 /* Softmax.inc.metal in Sources */,
FCC15E03221E69E100DC3CB2 /* TransposeKernel.metal in Sources */,
FCC15DFE221E69E100DC3CB2 /* ReshapeKernel.metal in Sources */,
16FBFB3E22925D040025B406 /* ActivationKernel.metal in Sources */,
FCC15E0D221E69E100DC3CB2 /* ConvAddMetal.metal in Sources */,
FCC15DF7221E69E100DC3CB2 /* ReshapeKernel.inc.metal in Sources */,
FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include <metal_math>
using namespace metal;
kernel void exp(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float4 output = exp(input);
outTexture.write(output, gid.xy, gid.z);
}
kernel void exp_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = float4(inTexture.read(gid.xy, gid.z));
const float4 output = exp(input);
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void sigmoid(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float4 output = 1.0 / (1.0 + exp(-input));
outTexture.write(output, gid.xy, gid.z);
}
kernel void sigmoid_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = float4(inTexture.read(gid.xy, gid.z));
const float4 output = 1.0 / (1.0 + exp(-input));
outTexture.write(half4(output), gid.xy, gid.z);
}
......@@ -19,6 +19,10 @@ struct Relu6Param {
float threshold;
};
struct LeakyReluParam {
float alpha;
};
kernel void relu_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
......@@ -70,3 +74,31 @@ kernel void relu6(texture2d_array<float, access::sample> inTexture [[texture(0)]
const float4 relu = fmin(fmax((float4)input, 0.0), threshold);
outTexture.write(float4(relu), gid.xy, gid.z);
}
kernel void leaky_relu(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant LeakyReluParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float alpha = pm.alpha;
const float4 output = fmax(input, input * alpha);
outTexture.write(output, gid.xy, gid.z);
}
kernel void leaky_relu_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant LeakyReluParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = float4(inTexture.read(gid.xy, gid.z));
const float alpha = pm.alpha;
const float4 output = fmax(input, input * alpha);
outTexture.write(half4(output), gid.xy, gid.z);
}
......@@ -9,6 +9,12 @@
/* Begin PBXBuildFile section */
165F38D32276CDEA0088E29F /* ConvAddReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */; };
165F38D52276CE7D0088E29F /* ConvAddReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */; };
16FBFB36229259C60025B406 /* ExpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB35229259C60025B406 /* ExpOp.swift */; };
16FBFB3822925B030025B406 /* ExpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3722925B030025B406 /* ExpKernel.swift */; };
16FBFB3A22925C3E0025B406 /* SigmoidKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */; };
16FBFB3C22925C800025B406 /* SigmoidOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3B22925C800025B406 /* SigmoidOp.swift */; };
16FBFB40229266FE0025B406 /* LeakyReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3F229266FE0025B406 /* LeakyReluOp.swift */; };
16FBFB422292684E0025B406 /* LeakyReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB412292684E0025B406 /* LeakyReluKernel.swift */; };
456BB7B421F5B356001474E2 /* Framework.pbobjc.m in Sources */ = {isa = PBXBuildFile; fileRef = 456BB7B221F5B356001474E2 /* Framework.pbobjc.m */; settings = {COMPILER_FLAGS = "-fno-objc-arc"; }; };
456BB7B521F5B356001474E2 /* Framework.pbobjc.h in Headers */ = {isa = PBXBuildFile; fileRef = 456BB7B321F5B356001474E2 /* Framework.pbobjc.h */; settings = {ATTRIBUTES = (Public, ); }; };
4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */; };
......@@ -109,6 +115,12 @@
/* Begin PBXFileReference section */
165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluOp.swift; sourceTree = "<group>"; };
165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluKernel.swift; sourceTree = "<group>"; };
16FBFB35229259C60025B406 /* ExpOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ExpOp.swift; sourceTree = "<group>"; };
16FBFB3722925B030025B406 /* ExpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ExpKernel.swift; sourceTree = "<group>"; };
16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SigmoidKernel.swift; sourceTree = "<group>"; };
16FBFB3B22925C800025B406 /* SigmoidOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SigmoidOp.swift; sourceTree = "<group>"; };
16FBFB3F229266FE0025B406 /* LeakyReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LeakyReluOp.swift; sourceTree = "<group>"; };
16FBFB412292684E0025B406 /* LeakyReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LeakyReluKernel.swift; sourceTree = "<group>"; };
456BB7B221F5B356001474E2 /* Framework.pbobjc.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = Framework.pbobjc.m; sourceTree = "<group>"; };
456BB7B321F5B356001474E2 /* Framework.pbobjc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Framework.pbobjc.h; sourceTree = "<group>"; };
4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpOp.swift; sourceTree = "<group>"; };
......@@ -338,6 +350,9 @@
165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */,
A73DC748227F1C7A001EB663 /* ScaleOp.swift */,
A7F26FD922842EF200365D47 /* Relu6Op.swift */,
16FBFB35229259C60025B406 /* ExpOp.swift */,
16FBFB3B22925C800025B406 /* SigmoidOp.swift */,
16FBFB3F229266FE0025B406 /* LeakyReluOp.swift */,
);
path = Operators;
sourceTree = "<group>";
......@@ -395,6 +410,9 @@
165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */,
A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */,
A7F26FDB2284301500365D47 /* Relu6Kernel.swift */,
16FBFB3722925B030025B406 /* ExpKernel.swift */,
16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */,
16FBFB412292684E0025B406 /* LeakyReluKernel.swift */,
);
path = Kernels;
sourceTree = "<group>";
......@@ -545,6 +563,7 @@
A73DC74B227F1EDE001EB663 /* ScaleOpKernel.swift in Sources */,
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */,
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
16FBFB3A22925C3E0025B406 /* SigmoidKernel.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
......@@ -565,7 +584,9 @@
FCEB684C212F093800D2448E /* PreluOp.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */,
16FBFB422292684E0025B406 /* LeakyReluKernel.swift in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
16FBFB36229259C60025B406 /* ExpOp.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
......@@ -596,6 +617,7 @@
4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */,
FCBCCC6D2123073A00D94F7E /* BoxcoderKernel.swift in Sources */,
FCB40E5921E0DCAB0075EC91 /* FetchKernel.swift in Sources */,
16FBFB3822925B030025B406 /* ExpKernel.swift in Sources */,
FCBCCC69212306D300D94F7E /* ConcatKernel.swift in Sources */,
FCDDC6C8212FA3CA00E5EF74 /* ConvTransposeKernel.swift in Sources */,
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
......@@ -613,6 +635,7 @@
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
16FBFB3C22925C800025B406 /* SigmoidOp.swift in Sources */,
165F38D32276CDEA0088E29F /* ConvAddReluOp.swift in Sources */,
FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
......@@ -625,6 +648,7 @@
FCD04E6820F315020007374F /* PoolKernel.swift in Sources */,
FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */,
FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */,
16FBFB40229266FE0025B406 /* LeakyReluOp.swift in Sources */,
FC039BBE20E11CC20081E9F8 /* PMOpDesc.swift in Sources */,
FC9797C921D6101D00F2FD90 /* ResizeBilinearOp.swift in Sources */,
4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */,
......
......@@ -73,7 +73,10 @@ class OpCreator<P: PrecisionProtocol> {
gReshape2Type : ReshapeOp<P>.creat,
gTranspose2Type : TransposeOp<P>.creat,
gScaleType : ScaleOp<P>.creat,
gRelu6Type : Relu6Op<P>.creat
gRelu6Type : Relu6Op<P>.creat,
gExpType : ExpOp<P>.creat,
gSigmoidType : SigmoidOp<P>.creat,
gLeakyReluType : LeakyReluOp<P>.creat,
]
......
......@@ -185,6 +185,9 @@ let gReshape2Type = "reshape2"
let gTranspose2Type = "transpose2"
let gScaleType = "scale"
let gRelu6Type = "relu6"
let gExpType = "exp"
let gSigmoidType = "sigmoid"
let gLeakyReluType = "leaky_relu"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
......@@ -220,4 +223,7 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out
gTranspose2Type : (inputs: ["X"], outputs: ["Out"]),
gScaleType : (inputs: ["X"], outputs: ["Out"]),
gRelu6Type : (inputs: ["X"], outputs: ["Out"]),
gExpType : (inputs: ["X"], outputs: ["Out"]),
gSigmoidType : (inputs: ["X"], outputs: ["Out"]),
gLeakyReluType : (inputs: ["X"], outputs: ["Out"]),
]
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ExpParam<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try ExpParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ExpParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture
var output: Texture
}
class ExpOp<P: PrecisionProtocol>: Operator<ExpKernel<P>, ExpParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = ExpOp<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ExpKernel<P: PrecisionProtocol>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ExpParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: ExpParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "exp", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "exp_half", initContext: initContext)
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct LeakyReluMetalParam {
let alpha: Float32
}
class LeakyReluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: LeakyReluMetalParam
func compute(commandBuffer: MTLCommandBuffer, param: LeakyReluParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<Relu6MetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: LeakyReluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = LeakyReluMetalParam(alpha: param.alpha)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "leaky_relu", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "leaky_relu_half", initContext: initContext)
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class SigmoidKernel<P: PrecisionProtocol>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: SigmoidParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: SigmoidParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "sigmoid", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "sigmoid_half", initContext: initContext)
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class LeakyReluParam<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try LeakyReluParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try LeakyReluParam.outputOut(outputs: opDesc.outputs, from: inScope)
alpha = try LeakyReluParam.getAttr(key: "alpha", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture
var output: Texture
let alpha: Float32
}
class LeakyReluOp<P: PrecisionProtocol>: Operator<LeakyReluKernel<P>, LeakyReluParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = LeakyReluOp<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class SigmoidParam<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try SigmoidParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try SigmoidParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture
var output: Texture
}
class SigmoidOp<P: PrecisionProtocol>: Operator<SigmoidKernel<P>, SigmoidParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = SigmoidOp<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册