提交 2d7d1749 编写于 作者: M Molly Smith

Add alibi/mask check

上级 fd803088
......@@ -92,22 +92,79 @@ __global__ void attn_softmax_v2(__half* vals,
bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + ( __half2float(alibi[data_id + alibi_offset])) + ( __half2float(mask[data_id + mask_offset]))
if (mask && alibi){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset])) + (__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + ( __half2float(alibi[data_id + alibi_offset + reduceWidth])) + ( __half2float(mask[data_id + mask_offset + reduceWidth]))
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + (__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2])) + (__half2float(mask[data_id + mask_offset + reduceWidth*2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3])) + (__half2float(mask[data_id + mask_offset + reduceWidth*3]))
: minus_infinity;
b.sync();
}
else if (mask){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + ( __half2float(alibi[data_id + alibi_offset + reduceWidth*2])) + ( __half2float(mask[data_id + mask_offset + reduceWidth*2]))
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*3]))
: minus_infinity;
b.sync();
}
else if (alibi){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + ( __half2float(alibi[data_id + alibi_offset + reduceWidth*3])) + ( __half2float(mask[data_id + mask_offset + reduceWidth*3]))
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3]))
: minus_infinity;
b.sync();
}
else {
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
b.sync();
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale
: minus_infinity;
b.sync();
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册