main.es6 2.0 KB
Newer Older
W
wangqun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* eslint-disable */
/**
 * @file Main function of the shader
 * @author chenhaoze
 */
export default `
    // Shader entry point.
    //
    // Computes a single output value: for every filter tap (fy, fx) whose
    // sampled coordinate is non-negative, aligned to the stride and inside the
    // origin tensor, it accumulates filter * origin over the channel_origin loop
    // and writes the sum with setOutput().
    //
    // Names used but not declared here (getOutputTensorPosLIMIT_OUT,
    // getValueFromTensorPosLIMIT_ORIGIN_origin,
    // getValueFromTensorPosLIMIT_FILTER_filter, setOutput, the *_shape_* sizes,
    // stride_*, dilation_*, channel_origin) must be supplied by the code this
    // string is combined with.
    void main(void) {
        // Output position of this invocation. The components are read as
        // r -> b, g -> c, b -> y, a -> x (presumably batch, channel, row,
        // column -- confirm against getOutputTensorPosLIMIT_OUT).
        ivec4 oPos = getOutputTensorPosLIMIT_OUT();
        int b = oPos.r;
        int c = oPos.g;
        int y = oPos.b;
        int x = oPos.a;

        // Odd coordinates are moved back by two before the window is walked.
        // NOTE(review): the constants here are hard-coded and look tied to one
        // specific stride/padding configuration -- confirm before reuse.
        if (x % 2 == 1) x = x - 2;
        if (y % 2 == 1) y = y - 2;

        // Disabled experiment: reorder the traversal.
        // int sumVal = oPos.g + oPos.a * channel_out + oPos.b * channel_out * width_shape_out;
        // int new_a = sumVal % width_shape_out;
        // int new_b = int((sumVal - new_a) / width_shape_out) % height_shape_out;
        // int new_g = int((((sumVal - new_a) / width_shape_out) - new_b) / height_shape_out);
        // int x = new_a;
        // int c = new_g;
        // int y = new_b;

        // NOTE(review): an unused local derived from groups / channel_out used
        // to be computed here; it never influenced the result and was removed.
        // The channel indexing below does not take groups into account --
        // confirm whether grouped operation is expected to work.

        float res = 0.0;

        for (int fy = 0; fy < height_shape_filter; fy++) {
            // Row addressed by this filter tap.
            int oy = y + fy * dilation_v;

            // Negative rows and rows not aligned to the vertical stride
            // contribute nothing. `||` short-circuits, so `%` is only ever
            // applied to a non-negative operand.
            if (oy < 0 || oy % stride_v != 0) {
                continue;
            }

            // oy is a non-negative multiple of stride_v here, so integer
            // division is exact (no float rounding involved).
            int srcY = oy / stride_v;
            if (srcY >= height_shape_origin) {
                continue;
            }

            for (int fx = 0; fx < width_shape_filter; fx++) {
                // Column addressed by this filter tap.
                int ox = x + fx * dilation_h;

                if (ox < 0 || ox % stride_h != 0) {
                    continue;
                }

                int srcX = ox / stride_h;
                if (srcX >= width_shape_origin) {
                    continue;
                }

                // Channel accumulation; every position test above is
                // independent of j and therefore done once per tap.
                for (int j = 0; j < channel_origin; j++) {
                    float o = getValueFromTensorPosLIMIT_ORIGIN_origin(b, j, srcY, srcX);
                    float f = getValueFromTensorPosLIMIT_FILTER_filter(j, c, fy, fx);
                    res += f * o;
                }
            }
        }

        setOutput(res);
    }
`;