提交 4e945c18 编写于 作者: D dolphin8

resize_bilinear

上级 d522cd52
...@@ -60,7 +60,8 @@ class OpCreator<P: PrecisionType> { ...@@ -60,7 +60,8 @@ class OpCreator<P: PrecisionType> {
gTransposeType : TransposeOp<P>.creat, gTransposeType : TransposeOp<P>.creat,
gPriorBoxType : PriorBoxOp<P>.creat, gPriorBoxType : PriorBoxOp<P>.creat,
gPreluType : PreluOp<P>.creat, gPreluType : PreluOp<P>.creat,
gConv2dTransposeType : ConvTransposeOp<P>.creat] gConv2dTransposeType : ConvTransposeOp<P>.creat,
gResizeBilinearType : ResizeBilinearOp<P>.creat]
private init(){} private init(){}
} }
...@@ -139,6 +139,7 @@ let gConvBnReluType = "conv_bn_relu" ...@@ -139,6 +139,7 @@ let gConvBnReluType = "conv_bn_relu"
let gDwConvBnReluType = "depth_conv_bn_relu" let gDwConvBnReluType = "depth_conv_bn_relu"
let gPreluType = "prelu" let gPreluType = "prelu"
let gConv2dTransposeType = "conv2d_transpose" let gConv2dTransposeType = "conv2d_transpose"
let gResizeBilinearType = "resize_bilinear"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
...@@ -161,5 +162,6 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out ...@@ -161,5 +162,6 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out
gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]), gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]),
gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]), gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]),
gPreluType : (inputs: ["X"], outputs: ["Out"]), gPreluType : (inputs: ["X"], outputs: ["Out"]),
gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"]) gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"]),
gResizeBilinearType : (inputs: ["X"], outputs: ["Out"])
] ]
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct ResizeBilinearMetalParam {
var ratio_h: Float32
var ratio_w: Float32
}
class ResizeBilinearKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ResizeBilinearParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
let ratio_h: Float32 = Float32(param.input.tensorDim.dims[2]) / Float32(param.output.tensorDim.dims[2])
let ratio_w: Float32 = Float32(param.input.tensorDim.dims[3]) / Float32(param.output.tensorDim.dims[3])
var p = ResizeBilinearMetalParam.init(ratio_h: ratio_h, ratio_w: ratio_w)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: ResizeBilinearParam<P>) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "resize_bilinear")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "resize_bilinear_half")
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct resize_bilinear_param {
// int32_t out_h;
// int32_t out_w;
float ratio_h;
float ratio_w;
};
kernel void resize_bilinear(texture2d_array<float, access::read> input [[texture(0)]],
texture2d_array<float, access::write> output [[texture(2)]],
constant resize_bilinear_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
float4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z)
} else {
float w = gid.x * pm.ratio_w;
float h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
float w1lambda = w - w0, h1lambda = h - h0;
float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
float4 r0 = input.read(uint2(w0, h0), gid.z);
float4 r1 = input.read(uint2(w1, h0), gid.z);
float4 r2 = input.read(uint2(w0, h1), gid.z);
float4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r3 + w1lambda * r4);
}
output.write(r, gid.xy, gid.z);
}
kernel void resize_bilinear_half(texture2d_array<half, access::read> input [[texture(0)]],
texture2d_array<half, access::write> output [[texture(2)]],
constant resize_bilinear_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
half4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z)
} else {
half w = gid.x * pm.ratio_w;
half h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
half w1lambda = w - w0, h1lambda = h - h0;
half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
half4 r0 = input.read(uint2(w0, h0), gid.z);
half4 r1 = input.read(uint2(w1, h0), gid.z);
half4 r2 = input.read(uint2(w0, h1), gid.z);
half4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r3 + w1lambda * r4);
}
output.write(r, gid.xy, gid.z);
output.write(r, gid.xy, gid.z);
}
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
class ResizeBilinearParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try ResizeBilinearParam.inputX(inputs: opDesc.inputs, from: inScope)
if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
fatalError()
}
output = try ResizeBilinearParam.outputOut(outputs: opDesc.outputs, from: inScope)
out_h = try ResizeBilinearParam.getAttr(key: "out_h", attrs: opDesc.attrs)
out_w = try ResizeBilinearParam.getAttr(key: "out_w", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
let out_h: Int32
let out_w: Int32
}
class ResizeBilinearOp<P: PrecisionType>: Operator<ResizeBilinearKernel<P>, ResizeBilinearParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = ResizeBilinearOp<P>
func inferShape() {
// para.output.dim = para.input.dim
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册