提交 8e06da6c 编写于 作者: D dolphin8

softmax

上级 630acb7a
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void softmax(__read_only image2d_t input,
__write_only image2d_t output,
__private const int d0,
......@@ -24,18 +26,19 @@ __kernel void softmax(__read_only image2d_t input,
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 maxv = read_imageh(input, sampler, int2(z * d3, y));
half4 buf[d3] = {piece};
for (int i = 1; i < d3; i++) {
buf[i] = read_imageh(input, sampler, int2(z * d3 + i, y));
maxv = max(maxv, buf[i]);
half4 cv = read_imageh(input, sampler, (int2)(x, y));
half4 maxv = cv;
for (int i = 0; i < d3; i++) {
half4 temp = read_imageh(input, sampler, (int2)(z * d3 + i, y));
maxv = max(maxv, temp);
}
float4 sum = 0;
half4 sum = (half4)0.0f;
// half4 x = = (half4)0.0f;
for (int i = 0; i < d3; i++) {
buf[i] = exp(buf[i] - maxv);
sum += buf[i];
half4 temp = read_imageh(input, sampler, (int2)(z * d3 + i, y));
sum += exp(temp - maxv);
}
half4 r = buf[x] / sum;
half4 r = exp(cv - maxv) / sum;
write_imageh(output, int2(z * d3 + x, y), r);
write_imageh(output, (int2)(z * d3 + x, y), r);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册