cross.cpp 2.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include <algorithm>
#include <numeric>

namespace megdnn {

// Deduce the output layout C of the cross product of A and B.
//
// Requirements (each enforced with megdnn_assert):
//   * param().axisa / param().axisb / param().axisc, after mapping negative
//     values by adding ndim (-1 is the last dimension), lie in [0, A.ndim)
//     (axisb is checked against B.ndim);
//   * A[axisa] == 3 and B[axisb] == 3;
//   * A.ndim == B.ndim and the remaining (non-axis) extents of A and B match
//     pairwise, in order.
// Result: C has A.ndim dimensions, extent 3 (== A[axisa]) at axisc, the
// non-axis extents of A in the other positions (in order), and A's dtype.
void Cross::deduce_layout(
        const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
    // Map a possibly-negative axis onto a non-negative index; the result is
    // validated by axis_in_range below before it is used for indexing.
    auto calibrated_axis = [](int ndim, int axis) {
        return axis < 0 ? (axis + ndim) : axis;
    };
    auto axis_in_range = [](int axis, size_t ndim) {
        return axis >= 0 && axis < static_cast<int>(ndim);
    };

    int axis_a = calibrated_axis(A.ndim, param().axisa);
    int axis_b = calibrated_axis(B.ndim, param().axisb);
    int axis_c = calibrated_axis(A.ndim, param().axisc);

    // Reject bad axes up front: the shape arrays indexed below are fixed-size,
    // so an out-of-range axis would read (A/B) or write (shp) out of bounds.
    megdnn_assert(
            axis_in_range(axis_a, A.ndim), "cross op: axisa=%d out of range for %s",
            static_cast<int>(param().axisa), A.to_string().c_str());
    megdnn_assert(
            axis_in_range(axis_b, B.ndim), "cross op: axisb=%d out of range for %s",
            static_cast<int>(param().axisb), B.to_string().c_str());
    // C has the same ndim as A, so axisc is validated against A.ndim.
    megdnn_assert(
            axis_in_range(axis_c, A.ndim), "cross op: axisc=%d out of range for %s",
            static_cast<int>(param().axisc), A.to_string().c_str());

    megdnn_assert(
            A[axis_a] == 3 && B[axis_b] == 3,
            "incompatible dimensions for cross product (dimension must be 3)");

    bool matched = true;
    TensorShape shp;
    if (A.ndim != B.ndim) {
        matched = false;
    } else {
        // Walk A's dimensions (skipping axis_a) in lockstep with B's (skipping
        // axis_b); each matching extent is placed into the output shape at the
        // next slot that is not axis_c.
        for (int i = 0, j = 0, k = 0; i < static_cast<int>(A.ndim); i++) {
            if (i == axis_a)
                continue;
            if (j == axis_b)
                ++j;
            if (A[i] != B[j]) {
                matched = false;
                break;
            }
            if (k == axis_c)
                ++k;
            shp[k++] = A[i];
            ++j;
        }
    }

    megdnn_assert(
            matched, "cross op shape mismatch: %s vs %s", A.to_string().c_str(),
            B.to_string().c_str());

    shp.ndim = A.ndim;
    // The slot skipped by the loop above receives the vector extent (3).
    shp[axis_c] = A[axis_a];
    C = TensorLayout{shp, A.dtype};
}

// Validate the operands of an exec() call before the kernel runs.
//
// Checks, in this order (the first failing check aborts via the
// megdnn_assert* helpers):
//   1. A, B and C all have the same dtype;
//   2. C's layout equals the layout deduce_layout() derives from A and B
//      (deduce_layout itself asserts that A and B have compatible shapes);
//   3. A, B and C are contiguous;
//   4. the caller-provided workspace is at least get_workspace_in_bytes(A, B, C).
void Cross::check_exec(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
        size_t workspace_in_bytes) {
    megdnn_assert_eq_dtype(A, B);
    megdnn_assert_eq_dtype(B, C);
    // Recompute the expected output layout and require C to match it exactly.
    TensorLayout c_expected;
    deduce_layout(A, B, c_expected);
    megdnn_assert_eq_layout(c_expected, C);

    megdnn_assert_contiguous(A);
    megdnn_assert_contiguous(B);
    megdnn_assert_contiguous(C);
    auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void Cross::get_ABC(
        const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis) {
    auto shape_arr = shape.shape;
    auto ndim = shape.ndim;
    if (axis < 0)
        axis += ndim;
    A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies<size_t>());
    B = shape_arr[axis];
    C = std::accumulate(
            shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies<size_t>());
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen