提交 8e06da6c 编写于 作者: D dolphin8

softmax

上级 630acb7a
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void softmax(__read_only image2d_t input, __kernel void softmax(__read_only image2d_t input,
__write_only image2d_t output, __write_only image2d_t output,
__private const int d0, __private const int d0,
...@@ -24,18 +26,19 @@ __kernel void softmax(__read_only image2d_t input, ...@@ -24,18 +26,19 @@ __kernel void softmax(__read_only image2d_t input,
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP | CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST; CLK_FILTER_NEAREST;
half4 maxv = read_imageh(input, sampler, int2(z * d3, y)); half4 cv = read_imageh(input, sampler, (int2)(x, y));
half4 buf[d3] = {piece}; half4 maxv = cv;
for (int i = 1; i < d3; i++) { for (int i = 0; i < d3; i++) {
buf[i] = read_imageh(input, sampler, int2(z * d3 + i, y)); half4 temp = read_imageh(input, sampler, (int2)(z * d3 + i, y));
maxv = max(maxv, buf[i]); maxv = max(maxv, temp);
} }
float4 sum = 0; half4 sum = (half4)0.0f;
// half4 x = = (half4)0.0f;
for (int i = 0; i < d3; i++) { for (int i = 0; i < d3; i++) {
buf[i] = exp(buf[i] - maxv); half4 temp = read_imageh(input, sampler, (int2)(z * d3 + i, y));
sum += buf[i]; sum += exp(temp - maxv);
} }
half4 r = buf[x] / sum; half4 r = exp(cv - maxv) / sum;
write_imageh(output, int2(z * d3 + x, y), r); write_imageh(output, (int2)(z * d3 + x, y), r);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册