where_backward.cpp 1.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
#include "src/cuda/where/common.cuh"
#include "src/cuda/where/opr_impl.h"

#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

/*!
 * \brief CUDA implementation of the Where backward operator.
 *
 * Validates the arguments via check_exec, then dispatches on the dtype of
 * \p diff and forwards raw pointers to where_backward::backward_proxy<ctype>,
 * together with the CUDA stream obtained from this operator's handle.
 *
 * NOTE(review): presumably grad_data1 receives diff where mask is true and
 * grad_data2 receives diff where mask is false (zeros elsewhere) — confirm
 * against where_backward::backward_proxy, which is defined outside this file.
 *
 * \param diff       incoming gradient; its dtype selects the kernel instantiation
 * \param mask       condition tensor, accessed as dt_bool
 * \param grad_data1 output gradient buffer, accessed with diff's ctype
 * \param grad_data2 output gradient buffer, accessed with diff's ctype
 * \param workspace  only its size is used here (passed to check_exec)
 *
 * Raises a megdnn error if diff's dtype is neither one of
 * MEGDNN_FOREACH_COMPUTING_DTYPE nor dtype::Bool.
 */
void WhereBackwardImpl::exec(
        _megdnn_tensor_in diff, _megdnn_tensor_in mask, _megdnn_tensor_out grad_data1,
        _megdnn_tensor_out grad_data2, _megdnn_workspace workspace) {
    // Argument validation (layouts + workspace size); implemented elsewhere.
    check_exec(
            diff.layout, mask.layout, grad_data1.layout, grad_data2.layout,
            workspace.size);
    auto stream = cuda_stream(this->handle());
    // Work size handed to the proxy: the total element count of diff.
    auto n = diff.layout.total_nr_elems();
    // One branch per supported dtype; return as soon as one matches so the
    // remaining comparisons are skipped and the fallthrough below is only
    // reached for unsupported dtypes.
#define cb(DType)                                                                \
    if (diff.layout.dtype == DType()) {                                          \
        using ctype = typename DTypeTrait<DType>::ctype;                         \
        where_backward::backward_proxy<ctype>(                                   \
                diff.ptr<ctype>(), mask.ptr<dt_bool>(), grad_data1.ptr<ctype>(), \
                grad_data2.ptr<ctype>(), n, stream);                             \
        return;                                                                  \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    cb(::megdnn::dtype::Bool)
#undef cb
    // No branch matched: fail loudly rather than silently returning with
    // grad_data1/grad_data2 never written by this function.
    megdnn_throw("WhereBackward: unsupported dtype of diff");
}

}  // namespace cuda
}  // namespace megdnn