From 46fcbeeb28b59ba4be06ff296b7a041019b57194 Mon Sep 17 00:00:00 2001 From: dolphin8 Date: Mon, 27 Aug 2018 17:29:26 +0800 Subject: [PATCH] boxcoder fix --- .../Operators/Kernels/metal/BoxCoder.metal | 45 +++++++++++++++++++ .../Operators/Kernels/metal/Kernels.metal | 15 ------- 2 files changed, 45 insertions(+), 15 deletions(-) create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BoxCoder.metal diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BoxCoder.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BoxCoder.metal new file mode 100644 index 0000000000..9a17748886 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BoxCoder.metal @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +using namespace metal; + +kernel void boxcoder(texture2d_array priorBox [[texture(0)]], + texture2d_array priorBoxVar [[texture(1)]], + texture2d_array targetBox [[texture(2)]], + texture2d_array output[[texture(3)]], + uint3 gid [[thread_position_in_grid]]) { + float4 t = targetBox.read(gid.xy, gid.z); + float4 p = priorBox.read(gid.xy, gid.z); + float4 pv = priorBoxVar.read(gid.xy, gid.z); + + float px = (p.x + p.z) / 2; + float py = (p.y + p.w) / 2; + float pw = p.z - p.x; + float ph = p.w - p.y; + + float tx = pv.x * t.x * pw + px; + float ty = pv.y * t.y * ph + py; + float tw = exp(pv.z * t.z) * pw; + float th = exp(pv.w * t.w) * ph; + + + float4 r; + r.x = tx - tw / 2; + r.y = ty - th / 2; + r.z = tx + tw / 2; + r.w = ty + th / 2; + + output.write(r, gid.xy, gid.z); +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Kernels.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Kernels.metal index 4032d0c364..c94f0551ba 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Kernels.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Kernels.metal @@ -294,18 +294,3 @@ kernel void concat(texture2d_array in0 [[texture(0)]], } out.write(r, gid.xy, gid.z); } - -kernel void boxcoder(texture2d_array priorBox [[texture(0)]], - texture2d_array priorBoxVar [[texture(1)]], - texture2d_array targetBox [[texture(2)]], - texture2d_array output[[texture(3)]], - uint3 gid [[thread_position_in_grid]]) { - float4 t = targetBox.read(gid.xy, gid.z); - float4 p = priorBox.read(gid.xy, gid.z); - float4 pv = priorBoxVar.read(gid.xy, gid.z); - float ox = (p.z * pv.x * t.x + p.x) - t.z / 2; - float oy = (p.w * pv.y * t.y + p.y) - t.w / 2; - float ow = exp(pv.z * t.z) * p.z + t.z / 2; - float oh = exp(pv.w * t.w) * p.w + t.w / 2; - output.write(float4(ox, oy, ow, oh), gid.xy, gid.z); -} -- GitLab