Commit 01040cfc authored by: Megvii Engine Team

feat(mge): add kernel for where operator

GitOrigin-RevId: 6c02a87be62214de8691f29cf4045e7087b0d80a
Parent 66b79160
......@@ -410,6 +410,24 @@ protected:
size_t workspace_in_bytes);
};
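/*!
 * \brief returns the indices of the non-zero elements of the input
 *
 * A sketch of the contract as implemented below: the output is an Int32
 * tensor of shape (ndim, n), where n is the number of non-zero elements and
 * row i holds the indices along dimension i. Since n is only known at
 * runtime, the output is allocated through \p malloc_policy.
 */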
class NonZero : public OperatorBase {
DEF_OPR_IMPL(NonZero, OperatorBase, 1, 1);
DEF_OPR_PARAM(Empty);
public:
DType infer_type(DType Data);
virtual size_t get_workspace_in_bytes(const TensorLayout& src) = 0;
virtual TensorND exec(
_megdnn_tensor_in src, _megdnn_workspace workspace,
DynOutMallocPolicyCall malloc_policy) = 0;
protected:
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
class TransposeForward : public OperatorBase {
DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1);
DEF_OPR_PARAM(Empty);
......@@ -1535,6 +1553,74 @@ public:
};
using Norm = NormForward;
class WhereBase : public OperatorBase {
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(WhereBase, OperatorBase, 3, 1);
protected:
void deduce_layout_fwd(
const TensorLayout& mask, const TensorLayout& data1,
const TensorLayout& data2, TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& mask, const TensorLayout& data1,
const TensorLayout& data2, const TensorLayout& dst);
void deduce_layout_bwd(
const TensorLayout& diff, const TensorLayout& mask,
TensorLayout& grad_data1, TensorLayout& grad_data2);
void check_layout_bwd(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2);
};
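/*!
 * \brief computes dst[i] = mask[i] ? data1[i] : data2[i] elementwise
 *
 * mask, data1, data2 and dst must share the same shape; mask must have bool
 * dtype and data1/data2/dst must share the same dtype.
 */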
class WhereForward : public WhereBase {
DEF_OPR_IMPL(WhereForward, WhereBase, 3, 1);
public:
virtual void exec(
_megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& mask, const TensorLayout& data1,
const TensorLayout& data2, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& mask, const TensorLayout& data1,
const TensorLayout& data2, const TensorLayout& dst) = 0;
protected:
void check_exec(
const TensorLayout& mask, const TensorLayout& data1,
const TensorLayout& data2, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Where = WhereForward;
class WhereBackward : public WhereBase {
DEF_OPR_IMPL(WhereBackward, WhereBase, 2, 2);
public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] mask the `mask' parameter in WhereForward::exec
     * \param[out] grad_data1 the backpropagated gradient wrt. data1
     * \param[out] grad_data2 the backpropagated gradient wrt. data2
     */
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in mask,
_megdnn_tensor_out grad_data1, _megdnn_tensor_out grad_data2,
_megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& diff, const TensorLayout& mask,
TensorLayout& grad_data1, TensorLayout& grad_data2);
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2) = 0;
protected:
void check_exec(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2,
            size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -225,7 +225,10 @@ private:
cb(MaskedFill) \
cb(MultiHeadAttnForward)\
cb(MultiHeadAttnBackward) \
cb(Cross)
cb(Cross) \
cb(WhereForward) \
cb(WhereBackward) \
cb(NonZero)
// clang-format on
/*!
......
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void NonZero::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
dst.dtype.assert_is(infer_type(src.dtype));
if (!src.is_empty())
megdnn_assert(src.is_physical_contiguous());
auto require_workspace_in_bytes = get_workspace_in_bytes(src);
megdnn_assert(workspace_in_bytes >= require_workspace_in_bytes);
}
DType NonZero::infer_type(DType /*input*/) {
return dtype::Int32();
}
}  // namespace megdnn
\ No newline at end of file
......@@ -17,6 +17,9 @@ struct OprTrait {};
static const bool can_deduce_layout = CanDeduceLayout; \
}
DEF(NonZero, 2, true, false);
DEF(Where, 4, true, true);
DEF(WhereBackward, 4, true, true);
DEF(Norm, 2, true, true);
DEF(Padding, 2, false, true);
DEF(PaddingBackward, 2, false, false);
......
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void WhereBase::deduce_layout_fwd(
const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2,
TensorLayout& dst) {
if (!mask.is_empty())
megdnn_assert(mask.is_physical_contiguous());
if (!data1.is_empty())
megdnn_assert(data1.is_physical_contiguous());
if (!data2.is_empty())
megdnn_assert(data2.is_physical_contiguous());
if (!dst.is_empty())
megdnn_assert(dst.is_physical_contiguous());
auto errmsg = [&]() {
return megdnn_layout_msg(mask) + ", " + megdnn_layout_msg(data1) + ", " +
megdnn_layout_msg(data2) + ", " + megdnn_layout_msg(dst);
};
auto mask_dtype = mask.dtype, data1_dtype = data1.dtype, data2_dtype = data2.dtype;
megdnn_assert(mask_dtype.category() == DTypeCategory::BOOL);
megdnn_assert(
data1_dtype == data2_dtype &&
(data1_dtype.category() == DTypeCategory::INT ||
data1_dtype.category() == DTypeCategory::FLOAT ||
data1_dtype.category() == DTypeCategory::BOOL));
megdnn_assert(data1.ndim == data2.ndim, "%s", errmsg().c_str());
megdnn_assert(data1.ndim == mask.ndim, "%s", errmsg().c_str());
dst = TensorLayout{data1};
}
void WhereBase::check_layout_fwd(
const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& dst) {
TensorLayout dst_expected;
megdnn_assert_eq_shape(mask, data1);
megdnn_assert_eq_dtype(data1, dst);
megdnn_assert_eq_shape(data1, data2);
deduce_layout_fwd(mask, data1, data2, dst_expected);
megdnn_assert_eq_shape(dst_expected, dst);
}
void WhereForward::deduce_layout(
const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2,
TensorLayout& dst) {
deduce_layout_fwd(mask, data1, data2, dst);
}
void WhereForward::check_exec(
const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(mask, data1, data2, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(mask, data1, data2, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void WhereBase::deduce_layout_bwd(
const TensorLayout& diff, const TensorLayout& mask, TensorLayout& grad_data1,
TensorLayout& grad_data2) {
if (!diff.is_empty())
megdnn_assert(diff.is_physical_contiguous());
if (!mask.is_empty())
megdnn_assert(mask.is_physical_contiguous());
if (!grad_data1.is_empty())
megdnn_assert(grad_data1.is_physical_contiguous());
if (!grad_data2.is_empty())
megdnn_assert(grad_data2.is_physical_contiguous());
auto errmsg = [&]() {
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(mask) + ", " +
megdnn_layout_msg(grad_data1) + megdnn_layout_msg(grad_data2);
};
auto diff_dtype = diff.dtype, mask_dtype = mask.dtype;
megdnn_assert(mask_dtype.category() == DTypeCategory::BOOL);
megdnn_assert(
diff_dtype.category() == DTypeCategory::INT ||
diff_dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(diff.ndim == mask.ndim, "%s", errmsg().c_str());
grad_data1 = TensorLayout{diff};
grad_data2 = TensorLayout{diff};
}
void WhereBase::check_layout_bwd(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2) {
TensorLayout grad_expected1;
TensorLayout grad_expected2;
megdnn_assert_eq_shape(diff, mask);
megdnn_assert_eq_shape(diff, grad_data1);
megdnn_assert_eq_dtype(diff, grad_data1);
megdnn_assert_eq_shape(diff, grad_data2);
megdnn_assert_eq_dtype(diff, grad_data2);
deduce_layout_bwd(diff, mask, grad_expected1, grad_expected2);
megdnn_assert_eq_shape(grad_expected1, grad_data1);
megdnn_assert_eq_dtype(grad_expected1, grad_data1);
megdnn_assert_eq_shape(grad_expected2, grad_data2);
megdnn_assert_eq_dtype(grad_expected2, grad_data2);
}
void WhereBackward::deduce_layout(
const TensorLayout& diff, const TensorLayout& mask, TensorLayout& grad_data1,
TensorLayout& grad_data2) {
deduce_layout_bwd(diff, mask, grad_data1, grad_data2);
}
void WhereBackward::check_exec(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2,
size_t workspace_in_bytes) {
check_layout_bwd(diff, mask, grad_data1, grad_data2);
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, mask, grad_data1, grad_data2);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn
......@@ -52,6 +52,7 @@
#include "src/cuda/max_tensor_diff/opr_impl.h"
#include "src/cuda/mesh_indexing/opr_impl.h"
#include "src/cuda/multi_head_attn/opr_impl.h"
#include "src/cuda/non_zero/opr_impl.h"
#include "src/cuda/norm/opr_impl.h"
#include "src/cuda/padding/opr_impl.h"
#include "src/cuda/param_pack/opr_impl.h"
......@@ -84,6 +85,7 @@
#include "src/cuda/type_cvt/opr_impl.h"
#include "src/cuda/warp_affine/opr_impl.h"
#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/where/opr_impl.h"
namespace megdnn {
namespace cuda {
......@@ -122,6 +124,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SVDForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CondTake);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(NonZero);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumsumForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
......@@ -236,6 +239,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Cross);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WhereForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WhereBackward);
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
......
#include "./kernel.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/utils.cuh"
using namespace megdnn;
using namespace megdnn::cuda;
using namespace megdnn::cuda::non_zero;
namespace {
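// Computes, on device, the reverse cumulative products of the source shape:
// div_size[0] = 1 and div_size[k] = shape[ndim-1] * ... * shape[ndim-k].
// expansion() later divides a flat element offset by these values to recover
// the index along each dimension.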
__global__ void multi_div_size(
dt_int32* div_size, const size_t* shape_arr, size_t ndim) {
div_size[0] = 1;
for (int div_index = 1; div_index < ndim; div_index++) {
div_size[div_index] = shape_arr[ndim - div_index] * div_size[div_index - 1];
}
}
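// Converts the flat offsets stored in pt (one copy per output dimension) into
// per-dimension indices in place: for dimension d,
// index = (offset / div_size[ndim - 1 - d]) % shape[d].
// E.g. for shape (2, 3, 4), div_size = {1, 4, 12}; flat offset 17 expands to
// indices (1, 1, 1), since 17 / 12 = 1, (17 / 4) % 3 = 1 and 17 % 4 = 1.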
__global__ void expansion(
dt_int32* pt, dt_int32* div_size, const size_t* shape_arr, int loop_count,
int index_size, int ndim) {
int dim_idx = blockIdx.x;
int offset_from_each_dim = (blockDim.x * threadIdx.y + threadIdx.x) +
(loop_count * (blockDim.x * blockDim.y));
if (offset_from_each_dim >= index_size)
return;
int offset = dim_idx * (index_size) + offset_from_each_dim;
dt_int32* target_pt = pt + offset;
int dim_pos_of_ele = *target_pt / div_size[ndim - 1 - dim_idx];
dt_int32 dim_index_of_ele = dim_pos_of_ele % shape_arr[dim_idx];
target_pt[0] = dim_index_of_ele;
}
__global__ void copy_kern(dt_int32* dest_idx, const dt_int32* src_idx, uint32_t size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < size && src_idx[tid] > src_idx[tid - 1]) {
uint32_t v = src_idx[tid] - 1;
dest_idx[v] = tid;
}
}
__global__ void set_zero(dt_int32* dest) {
dest[0] = 0;
}
} // namespace
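// copy_idx compacts the flat offsets of the non-zero elements into dest_idx.
// src_idx is assumed to hold the inclusive prefix sum of the non-zero
// predicate produced by cond_take::gen_idx: wherever the sum increases at
// position tid, element tid is non-zero and its rank is src_idx[tid] - 1, so
// copy_kern writes dest_idx[rank] = tid.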
void megdnn::cuda::non_zero::copy_idx(
dt_int32* dest_idx, dt_int32* src_idx, uint32_t size, cudaStream_t stream) {
int nr_thread = query_blocksize_for_kernel(copy_kern);
int nr_block = DIVUP(size, nr_thread);
    // TODO: a <<<1, 1>>> launch is used here only to zero a single element;
    // consider optimizing this
set_zero<<<1, 1, 0, stream>>>(src_idx);
after_kernel_launch();
copy_kern<<<nr_block, nr_thread, 0, stream>>>(dest_idx, src_idx + 1, size);
after_kernel_launch();
}
void megdnn::cuda::non_zero::expansion_index(
dt_int32* dst_pt, size_t index_size, const size_t* src_shape,
size_t* src_shape_workspace_pt, size_t src_ndim, dt_int32* div_workspace_pt,
cudaStream_t stream) {
cuda_check(cudaMemcpyAsync(
src_shape_workspace_pt, src_shape, sizeof(size_t) * 7,
cudaMemcpyHostToDevice, stream));
    // TODO: replace this single-thread CUDA kernel with a host-side loop, or
    // choose a more reasonable thread/block configuration
multi_div_size<<<1, 1, 0, stream>>>(
div_workspace_pt, src_shape_workspace_pt, src_ndim);
after_kernel_launch();
dim3 threadsPerBlock(
std::min<int>(NR_THREADS_X, index_size),
std::min<int>(NR_THREADS_Y, DIVUP(index_size, NR_THREADS_X)));
int loop_size = DIVUP(index_size, (NR_THREADS_X * NR_THREADS_Y));
for (int loop_idx = 0; loop_idx < loop_size; loop_idx++) {
expansion<<<src_ndim, threadsPerBlock, 0, stream>>>(
dst_pt, div_workspace_pt, src_shape_workspace_pt, loop_idx, index_size,
src_ndim);
after_kernel_launch();
}
}
\ No newline at end of file
#pragma once
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace non_zero {
void expansion_index(
dt_int32* dst_pt, size_t index_size, const size_t* src_shape,
size_t* src_shape_workspace_pt, size_t src_ndim, dt_int32* div_workspace_pt,
cudaStream_t stream);
void copy_idx(
dt_int32* dest_idx, dt_int32* src_idx, uint32_t size, cudaStream_t stream);
} // namespace non_zero
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
#include "megdnn/dtype.h"
#include "../cond_take/opr_impl.h"
#include "./kernel.cuh"
#include "./opr_impl.h"
#include "src/common/utils.h"
#include "src/cuda/cond_take/kern.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace megdnn::cuda;
using namespace megdnn::cuda::non_zero;
WorkspaceBundle NonZeroImpl::make_bundle(const TensorLayout& data) {
size_t nr_item = data.total_nr_elems();
cuda_check(cudaSetDevice(concrete_handle(handle())->device_id()));
auto gen_idx_wk_size = cuda::cond_take::gen_idx_get_workspace_size(nr_item);
SmallVector<size_t> sizes_in_bytes;
sizes_in_bytes.push_back((nr_item + 1) * sizeof(megdnn::cuda::cond_take::IdxType));
sizes_in_bytes.push_back(gen_idx_wk_size);
    // the last two workspace items hold the source shape array and the
    // reverse cumulative products of that shape
sizes_in_bytes.push_back(sizeof(TensorLayout::shape));
sizes_in_bytes.push_back(sizeof(TensorLayout::shape));
return {nullptr, sizes_in_bytes, handle()->alignment_requirement()};
}
size_t NonZeroImpl::get_workspace_in_bytes(const TensorLayout& data) {
return make_bundle(data).total_size_in_bytes();
}
TensorND NonZeroImpl::exec(
_megdnn_tensor_in data, _megdnn_workspace workspace,
DynOutMallocPolicyCall malloc_policy) {
size_t size = data.layout.total_nr_elems();
if (size == 0) {
TensorShape target_shape({data.layout.ndim, 0});
TensorND rst = malloc_policy.alloc_output(0, dtype::Int32(), target_shape);
return rst;
}
auto wk_bundle = make_bundle(data.layout);
wk_bundle.set(workspace.raw_ptr);
auto idx_tmp = static_cast<megdnn::cuda::cond_take::IdxType*>(wk_bundle.get(0));
CondTake::Param param;
param.mode = CondTake::Param::Mode::NEQ;
param.val = 0;
param.eps = 1e-06;
megdnn::cond_take::KParam kparam(param);
auto stream = cuda_stream(handle());
size_t out_size;
switch (data.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: { \
using ctype = DTypeTrait<_dt>::ctype; \
out_size = megdnn::cuda::cond_take::gen_idx( \
wk_bundle.get(1), wk_bundle.get_size(1), idx_tmp, data.ptr<ctype>(), \
size, static_cast<uint32_t>(param.mode), kparam, stream); \
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default : {
std::string data_type = data.layout.dtype.name();
megdnn_throw(
"bad mask dtype,support_types is support types: [Float32, Float16, "
"BFloat16, Int32, Int16, Int8, Uint8, Bool]" +
std::string("but the data type is ") + data_type);
}
}
TensorShape dst_shape({data.layout.ndim, out_size});
TensorND out_idx = malloc_policy.alloc_output(0, dtype::Int32(), dst_shape);
dt_int32* out_idx_ptr = out_idx.ptr<dt_int32>();
switch (data.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: { \
for (size_t idx = 0; idx < data.layout.ndim; idx++) { \
dt_int32* copy_idx_ptr = out_idx_ptr + idx * out_size; \
copy_idx(copy_idx_ptr, idx_tmp, size, stream); \
} \
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default : megdnn_throw("bad data dtype");
}
expansion_index(
out_idx_ptr, out_size, data.layout.shape,
static_cast<size_t*>(wk_bundle.get(2)), data.layout.ndim,
static_cast<dt_int32*>(wk_bundle.get(3)), stream);
return out_idx;
}
\ No newline at end of file
#pragma once
#include "megdnn/oprs/general.h"
#include "src/common/utils.h"
namespace megdnn {
namespace cuda {
class NonZeroImpl final : public NonZero {
WorkspaceBundle make_bundle(const TensorLayout& data);
public:
using NonZero::NonZero;
virtual TensorND exec(
_megdnn_tensor_in src, _megdnn_workspace workspace,
DynOutMallocPolicyCall malloc_policy);
size_t get_workspace_in_bytes(const TensorLayout& data) override;
};
} // namespace cuda
} // namespace megdnn
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>
namespace megdnn {
namespace cuda {
namespace where {
template <typename T>
void forward_proxy(
const bool* __restrict mask, const T* __restrict data1,
const T* __restrict data2, T* __restrict dst, size_t n, cudaStream_t stream);
} // namespace where
namespace where_backward {
template <typename T>
void backward_proxy(
const T* __restrict diff, const bool* mask, T* __restrict grad_data1,
T* __restrict grad_data2, size_t n, cudaStream_t stream);
} // namespace where_backward
} // namespace cuda
} // namespace megdnn
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class WhereForwardImpl : public WhereForward {
public:
using WhereForward::WhereForward;
void exec(
_megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& mask, const TensorLayout& data1,
const TensorLayout& data2, const TensorLayout& dst) override {
return 0;
}
};
class WhereBackwardImpl : public WhereBackward {
public:
using WhereBackward::WhereBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in mask,
_megdnn_tensor_out grad_data1, _megdnn_tensor_out grad_data2,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
#include "src/cuda/where/common.cuh"
#include "src/cuda/where/opr_impl.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void WhereBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in mask, _megdnn_tensor_out grad_data1,
_megdnn_tensor_out grad_data2, _megdnn_workspace workspace) {
check_exec(
diff.layout, mask.layout, grad_data1.layout, grad_data2.layout,
workspace.size);
auto stream = cuda_stream(this->handle());
auto n = diff.layout.total_nr_elems();
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
where_backward::backward_proxy<ctype>( \
diff.ptr<ctype>(), mask.ptr<dt_bool>(), grad_data1.ptr<ctype>(), \
grad_data2.ptr<ctype>(), n, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
} // namespace cuda
} // namespace megdnn
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/where/common.cuh"
namespace {
template <typename T>
__global__ void backward_kernel(
const T* __restrict diff, const bool* __restrict mask, T* __restrict grad_data1,
T* __restrict grad_data2, size_t n) {
size_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < n) {
grad_data1[i] = mask[i] ? diff[i] : 0;
grad_data2[i] = mask[i] ? 0 : diff[i];
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace where_backward {
template <typename T>
void backward_proxy(
const T* __restrict diff, const dt_bool* __restrict mask,
T* __restrict grad_data1, T* __restrict grad_data2, size_t n,
cudaStream_t stream) {
if (n == 0)
return;
backward_kernel<T><<<DIVUP(n, NR_THREADS), NR_THREADS, 0, stream>>>(
diff, mask, grad_data1, grad_data2, n);
after_kernel_launch();
}
#define INST(T) \
template void backward_proxy<T>( \
const T* __restrict, const dt_bool* __restrict, T* __restrict, \
T* __restrict, size_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace where_backward
} // namespace cuda
} // namespace megdnn
#include "src/cuda/where/common.cuh"
#include "src/cuda/where/opr_impl.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void WhereForwardImpl::exec(
_megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(mask.layout, data1.layout, data2.layout, dst.layout, workspace.size);
auto stream = cuda_stream(this->handle());
auto n = data1.layout.total_nr_elems();
#define cb(DType) \
if (data1.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
where::forward_proxy<ctype>( \
mask.ptr<dt_bool>(), data1.ptr<ctype>(), data2.ptr<ctype>(), \
dst.ptr<ctype>(), n, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
} // namespace cuda
} // namespace megdnn
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/where/common.cuh"
namespace {
template <typename T>
__global__ void forward_kernel(
const bool* __restrict mask, const T* __restrict data1,
const T* __restrict data2, T* __restrict dst, size_t n) {
size_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < n) {
dst[i] = mask[i] ? data1[i] : data2[i];
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace where {
template <typename T>
void forward_proxy(
const dt_bool* __restrict mask, const T* __restrict data1,
const T* __restrict data2, T* __restrict dst, size_t n, cudaStream_t stream) {
    // guard against empty tensors (mirrors backward_proxy); launching a
    // kernel with zero blocks is an invalid configuration
    if (n == 0)
        return;
    forward_kernel<T><<<DIVUP(n, NR_THREADS), NR_THREADS, 0, stream>>>(
            mask, data1, data2, dst, n);
after_kernel_launch();
}
#define INST(T) \
template void forward_proxy<T>( \
const dt_bool* __restrict, const T* __restrict, const T* __restrict, \
T* __restrict, size_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace where
} // namespace cuda
} // namespace megdnn
......@@ -56,6 +56,7 @@
#include "src/naive/max_tensor_diff/opr_impl.h"
#include "src/naive/mesh_indexing/opr_impl.h"
#include "src/naive/multi_head_attn/opr_impl.h"
#include "src/naive/non_zero/opr_impl.h"
#include "src/naive/norm/opr_impl.h"
#include "src/naive/padding/opr_impl.h"
#include "src/naive/param_pack/opr_impl.h"
......@@ -90,6 +91,7 @@
#include "src/naive/type_cvt/opr_impl.h"
#include "src/naive/warp_affine/opr_impl.h"
#include "src/naive/warp_perspective/opr_impl.h"
#include "src/naive/where/opr_impl.h"
namespace megdnn {
namespace naive {
......
#include "./opr_impl.h"
#include "src/common/cond_take/predicate.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/non_zero/opr_impl.h"
using namespace megdnn;
using namespace naive;
using Param = NonZero::Param;
size_t NonZeroImpl::get_workspace_in_bytes(const TensorLayout& data) {
// save the size of index array in the last element of workspace
return (data.total_nr_elems() + 1) * sizeof(dt_int32);
}
template <uint32_t mode, typename ctype>
void gen_index(dt_int32* dest, const TensorND& src, cond_take::Pred<mode, ctype> pred) {
int idx = 0;
ctype* inp = src.ptr<ctype>();
size_t number_of_data = src.layout.total_nr_elems();
for (size_t data_pos = 0; data_pos < number_of_data; ++data_pos) {
if (pred(inp[data_pos])) {
dest[idx++] = data_pos;
}
}
// last element is the size of index array
dest[number_of_data] = idx;
}
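// Expands the flat offsets in index_arr into per-dimension indices, writing
// one row of length index_size per dimension into rst. As a small worked
// example: for shape (2, 3, 4) the reverse cumulative products are
// {1, 4, 12}, and flat offset 17 expands to (17 / 12) % 2 = 1,
// (17 / 4) % 3 = 1, 17 % 4 = 1, i.e. the multi-index (1, 1, 1).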
void expansion_index(
const dt_int32* const index_arr, const size_t index_size, const TensorND* rst,
const size_t* shape_arr, const int ndim) {
SmallVector<int, 8> shape_reverse_multiply_reduce_arr({1});
for (int div_index = 1; div_index < ndim; div_index++) {
shape_reverse_multiply_reduce_arr[div_index] =
shape_arr[ndim - div_index] *
shape_reverse_multiply_reduce_arr[div_index - 1];
}
for (int dim_pos = 0; dim_pos < ndim; dim_pos++) {
dt_int32* dim_pt = rst->ptr<dt_int32>() + index_size * dim_pos;
for (size_t ele_pos = 0; ele_pos < index_size; ele_pos++) {
int dim_pos_of_ele = index_arr[ele_pos] /
shape_reverse_multiply_reduce_arr[ndim - 1 - dim_pos];
int dim_index_of_ele = dim_pos_of_ele % shape_arr[dim_pos];
dim_pt[ele_pos] = dim_index_of_ele;
}
}
}
TensorND NonZeroImpl::exec(
_megdnn_tensor_in src, _megdnn_workspace workspace,
DynOutMallocPolicyCall malloc_policy) {
#if !MGE_BUILD_WITHOUT_NAIVE_EXEC
auto idx_tmp = workspace.ptr<dt_int32>();
switch (src.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: { \
using ctype = DTypeTrait<_dt>::ctype; \
using namespace cond_take; \
KParam param({}); \
param.val = 0.0; \
param.eps = 1e-6; \
Pred<PEnum::NEQ, ctype> pred(param); \
MEGDNN_DISPATCH_CPU_KERN_OPR(gen_index(idx_tmp, src, pred)); \
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default : {
std::string data_type = src.layout.dtype.name();
megdnn_throw(
"bad mask dtype,support_types is support types: [Float32, Float16, "
"BFloat16, Int32, Int16, Int8, Uint8, Bool]" +
std::string("but the data type is ") + data_type);
}
}
static_cast<HandleImpl*>(handle())->megcore_dispatcher()->sync();
size_t index_size_pos = src.layout.total_nr_elems();
size_t index_size = idx_tmp[index_size_pos];
TensorND ret;
size_t ndim = src.layout.ndim;
TensorShape dst_shape({ndim, index_size});
ret = malloc_policy.alloc_output(0, dtype::Int32(), {ndim, index_size});
MEGDNN_DISPATCH_CPU_KERN_OPR(
expansion_index(idx_tmp, index_size, &ret, src.layout.shape, ndim));
return ret;
#else
__builtin_trap();
return {};
#endif
}
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class NonZeroImpl : public NonZero {
public:
using NonZero::NonZero;
TensorND exec(
_megdnn_tensor_in src, _megdnn_workspace workspace,
DynOutMallocPolicyCall malloc_policy);
size_t get_workspace_in_bytes(const TensorLayout& src);
};
} // namespace naive
} // namespace megdnn
\ No newline at end of file
#include "src/naive/where/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
template <typename T>
void WhereForwardImpl::exec_internal(
const dt_bool* __restrict mask, const T* __restrict data1,
const T* __restrict data2, T* __restrict dst, size_t n) {
rep(i, n) { dst[i] = mask[i] ? data1[i] : data2[i]; }
}
void WhereForwardImpl::exec(
_megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(mask.layout, data1.layout, data2.layout, dst.layout, workspace.size);
auto n = data1.layout.total_nr_elems();
#define cb(DType) \
if (data1.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>( \
mask.ptr<dt_bool>(), data1.ptr<ctype>(), data2.ptr<ctype>(), \
dst.ptr<ctype>(), n)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
megdnn_assert_internal(0);
}
template <typename T>
void WhereBackwardImpl::exec_internal(
const T* __restrict diff, const dt_bool* __restrict mask,
T* __restrict grad_data1, T* __restrict grad_data2, size_t n) {
rep(i, n) {
grad_data1[i] = mask[i] ? diff[i] : 0;
grad_data2[i] = mask[i] ? 0 : diff[i];
}
}
void WhereBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in mask, _megdnn_tensor_out grad_data1,
_megdnn_tensor_out grad_data2, _megdnn_workspace workspace) {
check_exec(
diff.layout, mask.layout, grad_data1.layout, grad_data2.layout,
workspace.size);
auto n = diff.layout.total_nr_elems();
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>( \
diff.ptr<ctype>(), mask.ptr<dt_bool>(), grad_data1.ptr<ctype>(), \
grad_data2.ptr<ctype>(), n)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
megdnn_assert_internal(0);
}
} // namespace naive
} // namespace megdnn
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
namespace naive {
class WhereForwardImpl : public WhereForward {
public:
using WhereForward::WhereForward;
void exec(
_megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
template <typename T>
void exec_internal(
const dt_bool* __restrict mask, const T* __restrict data1,
const T* __restrict data2, T* __restrict dst, size_t n);
};
class WhereBackwardImpl : public WhereBackward {
public:
using WhereBackward::WhereBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in mask,
_megdnn_tensor_out grad_data1, _megdnn_tensor_out grad_data2,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& mask,
const TensorLayout& grad_data1, const TensorLayout& grad_data2) override {
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(mask);
MEGDNN_MARK_USED_VAR(grad_data1);
MEGDNN_MARK_USED_VAR(grad_data2);
return 0;
}
private:
template <typename T>
void exec_internal(
const T* __restrict diff, const dt_bool* __restrict mask,
T* __restrict grad_data1, T* __restrict grad_data2, size_t n);
};
} // namespace naive
} // namespace megdnn
#include "./non_zero.h"
#include "./rng.h"
#include "./tensor.h"
#include "./utils.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
using Param = NonZero::Param;
std::vector<NonZeroTestcase> NonZeroTestcase::make() {
Param param;
std::vector<NonZeroTestcase> ret;
TensorShape shape1{4};
ret.push_back({param, TensorLayout{shape1, dtype::Int8()}});
NonZeroTestcase& case1 = ret.back();
case1.m_mem.reset(new uint8_t[case1.m_data.layout.span().dist_byte()]);
memset(case1.m_mem.get(), 0, case1.m_data.layout.span().dist_byte());
case1.m_data.reset_ptr(case1.m_mem.get());
dt_int8* pt_1 = reinterpret_cast<dt_int8*>(case1.m_mem.get());
pt_1[3] = 1;
case1.correct_answer.push_back(3);
TensorShape shape2{1, 1, 1, 1, 1, 1, 1};
ret.push_back({param, TensorLayout{shape2, dtype::Float32()}});
NonZeroTestcase& case2 = ret.back();
case2.m_mem.reset(new uint8_t[case2.m_data.layout.span().dist_byte()]);
memset(case2.m_mem.get(), 0, case2.m_data.layout.span().dist_byte());
case2.m_data.reset_ptr(case2.m_mem.get());
dt_float32* pt_2 = reinterpret_cast<dt_float32*>(case2.m_mem.get());
pt_2[0] = 1.0;
case2.correct_answer = {0, 0, 0, 0, 0, 0, 0};
TensorShape shape3{0};
ret.push_back({param, TensorLayout{shape3, dtype::Float32()}});
NonZeroTestcase& case3 = ret.back();
case3.m_mem.reset(new uint8_t[case3.m_data.layout.span().dist_byte()]);
memset(case3.m_mem.get(), 0, case3.m_data.layout.span().dist_byte());
case3.m_data.reset_ptr(case3.m_mem.get());
case3.correct_answer = {};
TensorShape shape4{1, 2, 3, 4, 5, 6, 7};
ret.push_back({param, TensorLayout{shape4, dtype::Float32()}});
NonZeroTestcase& case4 = ret.back();
case4.m_mem.reset(new uint8_t[case4.m_data.layout.span().dist_byte()]);
memset(case4.m_mem.get(), 0, case4.m_data.layout.span().dist_byte());
case4.m_data.reset_ptr(case4.m_mem.get());
dt_float32* pt_4 = reinterpret_cast<dt_float32*>(case4.m_mem.get());
pt_4[shape4.total_nr_elems() - 1] = 1.0;
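    // the only non-zero element is the last one; in a {1, 2, 3, 4, 5, 6, 7}
    // tensor its multi-index is (0, 1, 2, 3, 4, 5, 6)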
case4.correct_answer = {0, 1, 2, 3, 4, 5, 6};
TensorShape shape5{2, 2, 2, 2, 2, 2, 2};
ret.push_back({param, TensorLayout{shape5, dtype::Float32()}});
NonZeroTestcase& case5 = ret.back();
case5.m_mem.reset(new uint8_t[case5.m_data.layout.span().dist_byte()]);
memset(case5.m_mem.get(), 0, case5.m_data.layout.span().dist_byte());
case5.m_data.reset_ptr(case5.m_mem.get());
dt_float32* pt_5 = reinterpret_cast<dt_float32*>(case5.m_mem.get());
pt_5[63] = 1.0;
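    // flat offset 63 in a 7-dimensional {2, 2, 2, 2, 2, 2, 2} tensor is
    // 0111111 in binary, i.e. multi-index (0, 1, 1, 1, 1, 1, 1)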
case5.correct_answer = {
0, 1, 1, 1, 1, 1, 1,
};
return ret;
}
NonZeroTestcase::Result NonZeroTestcase::run_naive(NonZero* opr) {
auto handle = opr->handle();
DynOutMallocPolicyImpl malloc_policy(handle);
opr->param() = m_param;
auto workspace_size = opr->get_workspace_in_bytes(m_data.layout);
auto workspace_ptr = malloc_policy.alloc_workspace(workspace_size, nullptr);
auto result = opr->exec(
m_data, {(dt_byte*)workspace_ptr, workspace_size}, &malloc_policy);
malloc_policy.free_workspace(workspace_ptr, nullptr);
return result;
}
NonZeroTestcase::CUDAResult NonZeroTestcase::run_cuda(NonZero* opr) {
auto handle = opr->handle();
DynOutMallocPolicyImpl malloc_policy(handle);
opr->param() = m_param;
auto data = make_tensor_h2d(handle, m_data);
auto workspace_size = opr->get_workspace_in_bytes(m_data.layout);
auto workspace_ptr = malloc_policy.alloc_workspace(workspace_size, nullptr);
auto result =
opr->exec(*data, {(dt_byte*)workspace_ptr, workspace_size}, &malloc_policy);
malloc_policy.free_workspace(workspace_ptr, nullptr);
return {make_tensor_d2h(handle, result)};
}
void NonZeroTestcase::Assert(
std::vector<int>& correct_answer, int ndim, NonZeroTestcase::Result result) {
dt_int32* data_pt = result.ptr<dt_int32>();
ASSERT_EQ(result.layout.total_nr_elems(), correct_answer.size());
ASSERT_EQ(ndim, result.layout.shape[0]);
for (size_t ele_idx = 0; ele_idx < result.layout.total_nr_elems(); ele_idx++) {
ASSERT_EQ(data_pt[ele_idx], correct_answer[ele_idx]);
}
}
\ No newline at end of file
#pragma once
#include "./checker.h"
#include "megdnn/oprs.h"
namespace megdnn {
namespace test {
class NonZeroTestcase {
public:
    std::unique_ptr<uint8_t[]> m_mem;
NonZero::Param m_param;
TensorND m_data;
std::vector<int> correct_answer;
NonZeroTestcase(NonZero::Param param, const TensorLayout& data)
: m_param(param), m_data(nullptr, data) {}
using Result = TensorND;
using CUDAResult = std::shared_ptr<TensorND>;
Result run_naive(NonZero* opr);
CUDAResult run_cuda(NonZero* opr);
static std::vector<NonZeroTestcase> make();
static void Assert(
std::vector<int>& correct_answer, int ndim, NonZeroTestcase::Result result);
};
} // namespace test
} // namespace megdnn
\ No newline at end of file
#include "test/common/non_zero.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
using namespace megdnn;
using namespace test;
TEST_F(CUDA, NONZERO) {
std::vector<NonZeroTestcase> test_cases = NonZeroTestcase::make();
auto opr_cuda = handle_cuda()->create_operator<NonZero>();
auto opr_naive = handle_naive()->create_operator<NonZero>();
for (NonZeroTestcase& test_case : test_cases) {
NonZeroTestcase::CUDAResult data = test_case.run_cuda(opr_cuda.get());
NonZeroTestcase::CUDAResult data_naive = test_case.run_cuda(opr_naive.get());
std::vector<int> result = test_case.correct_answer;
MEGDNN_ASSERT_TENSOR_EQ(*data, *data_naive);
}
}
\ No newline at end of file
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, WHERE) {
Checker<Where> checker(handle_cuda());
checker.exect(
Testcase{
TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}),
TensorValue({1, 2, 2}, dtype::Float32(), {1, 2, 3, 4}),
TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}),
{}},
Testcase{
{},
{},
{},
TensorValue({1, 2, 2}, dtype::Float32(), {1, 6, 7, 4})});
}
TEST_F(CUDA, WHEREBACKWARD) {
Checker<WhereBackward> checker(handle_cuda());
checker.exect(
Testcase{
TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}),
TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}),
{},
{}},
Testcase{
{},
{},
TensorValue({1, 2, 2}, dtype::Float32(), {5, 0, 0, 8}),
TensorValue({1, 2, 2}, dtype::Float32(), {0, 6, 7, 0})});
}
} // namespace test
} // namespace megdnn
\ No newline at end of file
#include "test/common/non_zero.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
using namespace megdnn;
using namespace test;
TEST_F(NAIVE, NONZERO) {
std::vector<NonZeroTestcase> test_cases = NonZeroTestcase::make();
auto opr_naive = handle()->create_operator<NonZero>();
for (NonZeroTestcase& test_case : test_cases) {
NonZeroTestcase::Result data = test_case.run_naive(opr_naive.get());
int ndim = test_case.m_data.layout.ndim;
std::vector<int> result = test_case.correct_answer;
NonZeroTestcase::Assert(result, ndim, data);
}
}
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, WHERE) {
Checker<Where> checker(handle());
checker.exect(
Testcase{
TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}),
TensorValue({1, 2, 2}, dtype::Float32(), {1, 2, 3, 4}),
TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}),
{}},
Testcase{
{},
{},
{},
TensorValue({1, 2, 2}, dtype::Float32(), {1, 6, 7, 4})});
}
TEST_F(NAIVE, WHEREBACKWARD) {
Checker<WhereBackward> checker(handle());
checker.exect(
Testcase{
TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}),
TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}),
{},
{}},
Testcase{
{},
{},
TensorValue({1, 2, 2}, dtype::Float32(), {5, 0, 0, 8}),
TensorValue({1, 2, 2}, dtype::Float32(), {0, 6, 7, 0})});
}
} // namespace test
} // namespace megdnn
......@@ -30,6 +30,7 @@ __all__ = [
"broadcast_to",
"concat",
"cond_take",
"non_zero",
"copy",
"cumsum",
"diag",
......@@ -821,30 +822,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
return inp
@lru_cache(maxsize=None)
def _get_where_op(dtype=None, device=None):
@subgraph_fn(
"Where",
dtype=dtype,
device=device,
nr_inputs=3,
jit_fusion=True,
custom_grad=True,
)
def where(inputs, f, c):
(mask, x, y) = inputs[0:3]
oup = f("switch_gt0", mask, x)
ksam = f("-", c(1), mask)
oup = f("+", oup, f("switch_gt0", ksam, y))
(oup_grad,) = yield (oup,)
x_grad = f("switch_gt0", mask, oup_grad)
y_grad = f("switch_gt0", ksam, oup_grad)
yield (None, x_grad, y_grad)
return where
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
def where(mask: Tensor, x: Tensor = None, y: Tensor = None) -> Tensor:
r"""Selects elements either from Tensor x or Tensor y, according to mask.
.. math::
......@@ -870,6 +848,8 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
array([[1., 6.],
[7., 4.]], dtype=float32)
"""
if x is None and y is None:
return non_zero(mask, as_tuple=True)
if not isinstance(x, Tensor):
raise TypeError("input x must be a tensor")
......@@ -882,18 +862,8 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
if x.device != mask.device:
raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
dtype = dtype_promotion(x, y)
device = x.device
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
y = y.astype(dtype)
mask = mask.astype(dtype)
where = _get_where_op(dtype=dtype, device=device)
(oup,) = where(mask, x, y)
return oup
where = builtin.Where()
return apply(where, mask, x, y)[0]
def cond_take(mask: Tensor, x: Tensor) -> Tensor:
......@@ -961,6 +931,42 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
return inp.transpose(pattern)
def non_zero(condition: Tensor, as_tuple=False):
r"""When as_tuple is False (default):
Returns a tensor containing the indices of all non-zero elements of the input condition.
Each row of the result holds the indices of one non-zero element in the input.
The result is sorted in lexicographical order, with the last index changing the fastest (C-style).
When as_tuple is True:
Returns a tuple of 1-D tensors, one for each dimension of the input,
each containing the indices (in that dimension) of all non-zero elements of condition.
Args:
condition (Tensor): the input tensor
Returns:
a tuple of 1-D tensors or a single tensor
Examples:
>>> import numpy as np
>>> condition = Tensor(np.array([1,1,0,1]))
>>> index = F.non_zero(condition,as_tuple=True)
>>> print(index)
(Tensor([0 1 3], dtype=int32, device=xpux:0),)
"""
if not isinstance(condition, Tensor):
raise TypeError("input must be a tensor")
op = builtin.NonZero()
(index,) = apply(op, condition)
ret = None
if as_tuple == True:
arr = []
for index_ele in range(0, condition.ndim):
arr.append(index[index_ele, :])
ret = tuple(arr)
else:
ret = transpose(index, (1, 0))
return ret
def swapaxes(inp: Tensor, axis1: int, axis2: int) -> Tensor:
r"""Interchange two axes of a tensor.
......
......@@ -708,6 +708,47 @@ std::optional<ValueRefList> warp_affine_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
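// Gradient rule for Where: since dst = mask ? data1 : data2, the partials are
// d(dst)/d(data1) = mask and d(dst)/d(data2) = !mask, so the backward pass
// only needs the incoming gradient and the mask (neither data1 nor data2 has
// to be captured). WhereBackward produces both gradients in a single call.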
std::optional<ValueRefList> where_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& where = op.cast_final_safe<Where>();
auto&& param = where.param();
mgb_assert(inputs.size() == 3);
SmallVector<ValueRef> inps;
if (inputs_require_grad[1] || inputs_require_grad[2]) {
inps.push_back(inputs[0]);
}
bool data1_requires_grad = inputs_require_grad[1],
data2_requires_grad = inputs_require_grad[2];
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inps), data1_requires_grad, data2_requires_grad,
param](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(3);
if (!grad) {
return ret;
}
if (data1_requires_grad == false && data2_requires_grad == false) {
return ret;
}
auto&& grad_op = WhereBackward::make();
ValueRefList args_(2);
args_[0] = grads[0];
args_[1] = inputs[0];
auto back_grad = imperative::apply(*grad_op, args_);
if (data1_requires_grad)
ret[1] = back_grad[0];
if (data2_requires_grad)
ret[2] = back_grad[1];
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
struct Init {
Init() {
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
......@@ -733,6 +774,7 @@ struct Init {
BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
CustomBackward::register_grad_rule(
WarpAffine::typeinfo(), warp_affine_grad_rule);
CustomBackward::register_grad_rule(Where::typeinfo(), where_grad_rule);
}
} _;
......
......@@ -128,6 +128,26 @@ def test_condtake(is_varnode):
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
@pytest.mark.parametrize("as_tuple", [True, False])
def test_nonzero(as_tuple):
def test_impl(np_condition):
tensor_condition = make_tensor(np_condition, None)
megengine_result = F.non_zero(tensor_condition, as_tuple=as_tuple)
np_result = np.nonzero(np_condition)
if as_tuple == False:
np_result = np.transpose(np_result, (1, 0))
for pos in range(len(megengine_result)):
np.testing.assert_equal(megengine_result[pos].numpy(), np_result[pos])
test_impl(
np.array([[True, False, True, False, False], [False, True, True, False, False]])
)
test_impl(np.random.randint(1, 10, size=[0, 3, 0]))
test_impl(np.random.randint(1, 10, size=[1, 2, 3]))
test_impl(np.random.randint(1, 10, size=[1, 2, 3, 4, 5, 6, 7]))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_stack_device(is_varnode):
if is_varnode:
......
#include <utility>
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/misc.h"
using namespace megdnn;
namespace mgb::imperative {
namespace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<NonZero>();
OperatorNodeConfig config{op.make_name()};
return opr::NonZero::make(inputs[0], {}, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(inputs.size() == 1, "NonZero take 1 inputs, got %lu", inputs.size());
auto&& condition = inputs[0];
if (condition->layout().is_empty()) {
// empty tensor
return {Tensor::make(
TensorLayout{{condition->layout().ndim, 0}, dtype::Int32()},
condition->comp_node())};
} else {
megdnn::NonZero::Param param;
DnnOprCaller<megdnn::NonZero> dnn_op(condition->comp_node(), param);
auto&& [out] = dnn_op.exec_dynout<1>(condition);
return {out};
}
}
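// The number of non-zero elements is data dependent, so only the output dtype
// (Int32) can be inferred statically; the shape stays unknown and the second
// tuple element (validated) is returned as false, leaving the shape to be
// determined when the kernel actually runs.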
std::tuple<SmallVector<LogicalTensorDesc, 1>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
LogicalTensorDesc input_0 = inputs[0];
auto cn = inputs[0].comp_node;
return {{{TensorLayout(dtype::Int32()), cn}}, false};
}
OP_TRAIT_REG(NonZero, NonZero, opr::NonZero)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // namespace
} // namespace mgb::imperative
#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h"
#include "megbrain/opr/misc.h"
#include "megdnn/oprs/general.h"
namespace mgb {
namespace imperative {
namespace {
namespace where {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
mgb_assert(input_descs.size() == 3, "Where expects three inputs");
auto comp_node = input_descs[0].comp_node;
TensorLayout mask = input_descs[0].layout, data1 = input_descs[1].layout,
data2 = input_descs[2].layout;
mgb_assert(mask.dtype == dtype::Bool(), "mask dtype must be boolean");
mgb_assert(
data1.dtype == dtype::Float32() || data1.dtype == dtype::Int32() ||
data1.dtype == dtype::Bool(),
"data1 dtype must be float32 or int32");
mgb_assert(
data2.dtype == dtype::Float32() || data2.dtype == dtype::Int32() ||
data2.dtype == dtype::Bool(),
"data2 dtype must be float32 or int32");
if (!mask.ndim || !data1.ndim || !data2.ndim) {
return {{{TensorLayout{data1.dtype}, comp_node, {}}}, false};
}
if (!mask.is_empty())
mgb_assert(mask.is_contiguous(), "mask should be contiguous");
if (!data1.is_empty())
mgb_assert(data1.is_contiguous(), "data1 should be contiguous");
if (!data2.is_empty())
mgb_assert(data2.is_contiguous(), "data2 should be contiguous");
mgb_assert(mask.eq_shape(data1), "mask shape doesn't match data1");
mgb_assert(mask.eq_shape(data2), "mask shape doesn't match data2");
mgb_assert(data1.eq_layout(data2), "data1 layout doesn't match data2");
TensorLayout dst = data1;
dst.init_contiguous_stride();
return {{{dst, comp_node}}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Where>();
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
return opr::Where::make(inputs[0], inputs[1], inputs[2], {}, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& mask = inputs[0];
auto&& data1 = inputs[1];
auto&& data2 = inputs[2];
auto&& mask_layout = mask->layout();
auto&& data1_layout = data1->layout();
auto&& data2_layout = data2->layout();
DnnOprCaller<megdnn::Where> dnn_op(mask->comp_node());
auto tlayout = dnn_op.deduce_layout(mask_layout, data1_layout, data2_layout);
auto out = Tensor::make(tlayout, mask->comp_node());
if (!mask_layout.is_empty())
dnn_op.exec_with_ws(mask, data1, data2, out);
return {out};
}
OP_TRAIT_REG(Where, Where)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace where
namespace where_backward {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
mgb_assert(input_descs.size() == 2, "WhereBackward expects two inputs");
auto comp_node = input_descs[0].comp_node;
TensorLayout diff = input_descs[0].layout, mask = input_descs[1].layout;
mgb_assert(
diff.dtype == dtype::Float32() || diff.dtype == dtype::Int32(),
"diff dtype must be float32 or int32");
mgb_assert(mask.dtype == dtype::Bool(), "mask dtype must be boolean");
if (!diff.ndim || !mask.ndim) {
return {{{diff, comp_node}}, false};
}
if (!diff.is_empty())
mgb_assert(diff.is_contiguous(), "diff should be contiguous");
if (!mask.is_empty())
mgb_assert(mask.is_contiguous(), "mask should be contiguous");
mgb_assert(diff.eq_shape(mask), "diff shape doesn't match mask");
TensorLayout data1 = diff;
data1.init_contiguous_stride();
TensorLayout data2 = diff;
data2.init_contiguous_stride();
return {{{data1, comp_node}, {data2, comp_node}}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<WhereBackward>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::WhereBackward::make(inputs[0], inputs[1], {}, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& diff = inputs[0];
auto&& mask = inputs[1];
auto&& diff_layout = diff->layout();
auto&& mask_layout = mask->layout();
DnnOprCaller<megdnn::WhereBackward> dnn_op(diff->comp_node());
auto tlayouts = dnn_op.deduce_layouts<2>(diff_layout, mask_layout);
auto grad1 = Tensor::make(tlayouts.at(0), diff->comp_node());
auto grad2 = Tensor::make(tlayouts.at(1), diff->comp_node());
dnn_op.exec_with_ws(diff, mask, grad1, grad2);
return {grad1, grad2};
}
OP_TRAIT_REG(WhereBackward, WhereBackward)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace where_backward
} // anonymous namespace
} // namespace imperative
} // namespace mgb
\ No newline at end of file
......@@ -404,6 +404,23 @@ ValueRefList setsubtensor_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, converted);
}
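// DType promotion rule for Where: inputs[0] is the boolean mask and is left
// untouched; data1/data2 (inputs[1..]) are cast to their promoted common
// dtype before the op is applied, mirroring what the removed Python subgraph
// implementation did with dtype_promotion.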
ValueRefList where_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes({inputs.begin() + 1, inputs.end()});
mgb::DType target_dtype = get_promoted_dtype(dtypes);
ValueRefList converted(inputs.size());
converted[0] = inputs[0];
for (int idx = 1; idx < inputs.size(); idx++) {
if (*(inputs[idx].dtype()) != target_dtype) {
converted[idx] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[idx])[0];
} else {
converted[idx] = inputs[idx];
}
}
return imperative::apply(op, converted);
}
struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
......@@ -424,6 +441,7 @@ struct DTypePromoteRuleRegistry {
register_dtype_promote_rule<GroupNorm>(norm_rule);
register_dtype_promote_rule<SetSubtensor>(setsubtensor_rule);
register_dtype_promote_rule<IndexingSetMultiAxisVec>(setsubtensor_rule);
register_dtype_promote_rule<Where>(where_rule);
}
} register_helper;
......
......@@ -304,6 +304,7 @@ struct ScalarRuleRegistry {
register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>();
register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>();
register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>();
register_scalar_rule<Where, elemwise_rule<Where, 3>>();
}
} _;
} // namespace
......
e4035bfefce3a2cc0e8cc6ec7fcac227 ../../dnn/scripts/opr_param_defs.py
13ab898fce3749ebbcabf7c145876147 ../../src/core/include/megbrain/ir/ops.td
9dda6e2db75279373ec6809b297a2370 generated/opdef.h.inl
aabc2d8146742faacabf56e376177e7b generated/opdef.cpp.inl
8a5dffac1df3286178b3fd304c39b5da generated/opdef.py.inl
04322b642bba8f684034fcc5dc27efcf generated/opdef.cpy.inl
5ed571605e6f376c8801611c573707e3 ../../src/core/include/megbrain/ir/ops.td
28acc4b2c91ecbe3e0e27c1c45e5bc0c generated/opdef.h.inl
d554e79f054cb152251e88a41621b524 generated/opdef.cpp.inl
132afb96c40ae64cbab73b89aa00a844 generated/opdef.py.inl
87da68c8a60f965088e1c32c38939195 generated/opdef.cpy.inl
911001ef0dd771024919f7a1a3a009db generated/enum_macro.h
......@@ -5426,6 +5426,40 @@ OP_TRAIT_REG(NMSKeep, NMSKeep)
.props(NMSKeep_props_impl)
.make_name(NMSKeep_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NonZero);
namespace {
size_t NonZero_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<NonZero>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
return val;
}
bool NonZero_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<NonZero>(),
&&b_ = rhs_.cast_final_safe<NonZero>();
static_cast<void>(a_);
static_cast<void>(b_);
return true;
}
std::vector<std::pair<const char*, std::string>> NonZero_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<NonZero>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
return props_;
}
std::string NonZero_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<NonZero>();
static_cast<void>(op_);
return "NonZero";
}
} // anonymous namespace
OP_TRAIT_REG(NonZero, NonZero)
.hash(NonZero_hash_impl)
.is_same_st(NonZero_is_same_st_impl)
.props(NonZero_props_impl)
.make_name(NonZero_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);
namespace {
......@@ -8333,4 +8367,72 @@ OP_TRAIT_REG(WarpPerspectiveBackwardMat, WarpPerspectiveBackwardMat)
.props(WarpPerspectiveBackwardMat_props_impl)
.make_name(WarpPerspectiveBackwardMat_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Where);
namespace {
size_t Where_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Where>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
return val;
}
bool Where_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Where>(),
&&b_ = rhs_.cast_final_safe<Where>();
static_cast<void>(a_);
static_cast<void>(b_);
return true;
}
std::vector<std::pair<const char*, std::string>> Where_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Where>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
return props_;
}
std::string Where_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Where>();
static_cast<void>(op_);
return "Where";
}
} // anonymous namespace
OP_TRAIT_REG(Where, Where)
.hash(Where_hash_impl)
.is_same_st(Where_is_same_st_impl)
.props(Where_props_impl)
.make_name(Where_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WhereBackward);
namespace {
size_t WhereBackward_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WhereBackward>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
return val;
}
bool WhereBackward_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<WhereBackward>(),
&&b_ = rhs_.cast_final_safe<WhereBackward>();
static_cast<void>(a_);
static_cast<void>(b_);
return true;
}
std::vector<std::pair<const char*, std::string>> WhereBackward_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WhereBackward>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
return props_;
}
std::string WhereBackward_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WhereBackward>();
static_cast<void>(op_);
return "WhereBackward";
}
} // anonymous namespace
OP_TRAIT_REG(WhereBackward, WhereBackward)
.hash(WhereBackward_hash_impl)
.is_same_st(WhereBackward_is_same_st_impl)
.props(WhereBackward_props_impl)
.make_name(WhereBackward_make_name_impl);
// clang-format on
......@@ -16010,6 +16010,88 @@ void _init_py_NMSKeep(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(NMSKeep::typeinfo(), &py_type).second);
}
PyOpDefBegin(NonZero) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(NonZero)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(NonZero)*>(self)->inst();
static_cast<void>(opdef);
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(NonZero)
int PyOp(NonZero)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
return 0;
}
PyGetSetDef PyOp(NonZero)::py_getsetters[] = {
{NULL} /* Sentinel */
};
PyMethodDef PyOp(NonZero)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(NonZero)::getstate, METH_NOARGS, "NonZero getstate"},
{const_cast<char*>("__setstate__"), PyOp(NonZero)::setstate, METH_VARARGS, "NonZero setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(NonZero)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(NonZero)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(NonZero)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(NonZero)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self) -> None\n"
};
void _init_py_NonZero(py::module m) {
using py_op = PyOp(NonZero);
auto& py_type = PyOpType(NonZero);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.NonZero";
py_type.tp_basicsize = sizeof(PyOp(NonZero));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "NonZero";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(NonZero), &PyOp(NonZero)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("NonZero", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(NonZero::typeinfo(), &py_type).second);
}
PyOpDefBegin(NvOf) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
......@@ -23534,6 +23616,170 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
m.add_object("WarpPerspectiveBackwardMat", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspectiveBackwardMat::typeinfo(), &py_type).second);
}
PyOpDefBegin(Where) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Where)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Where)*>(self)->inst();
static_cast<void>(opdef);
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Where)
int PyOp(Where)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
return 0;
}
PyGetSetDef PyOp(Where)::py_getsetters[] = {
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Where)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Where)::getstate, METH_NOARGS, "Where getstate"},
{const_cast<char*>("__setstate__"), PyOp(Where)::setstate, METH_VARARGS, "Where setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Where)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Where)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Where)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Where)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self) -> None\n"
};
void _init_py_Where(py::module m) {
using py_op = PyOp(Where);
auto& py_type = PyOpType(Where);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Where";
py_type.tp_basicsize = sizeof(PyOp(Where));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Where";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Where), &PyOp(Where)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("Where", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Where::typeinfo(), &py_type).second);
}
PyOpDefBegin(WhereBackward) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(WhereBackward)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(WhereBackward)*>(self)->inst();
static_cast<void>(opdef);
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(WhereBackward)
int PyOp(WhereBackward)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
return 0;
}
PyGetSetDef PyOp(WhereBackward)::py_getsetters[] = {
{NULL} /* Sentinel */
};
PyMethodDef PyOp(WhereBackward)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(WhereBackward)::getstate, METH_NOARGS, "WhereBackward getstate"},
{const_cast<char*>("__setstate__"), PyOp(WhereBackward)::setstate, METH_VARARGS, "WhereBackward setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(WhereBackward)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(WhereBackward)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(WhereBackward)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(WhereBackward)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self) -> None\n"
};
void _init_py_WhereBackward(py::module m) {
using py_op = PyOp(WhereBackward);
auto& py_type = PyOpType(WhereBackward);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.WhereBackward";
py_type.tp_basicsize = sizeof(PyOp(WhereBackward));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "WhereBackward";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(WhereBackward), &PyOp(WhereBackward)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("WhereBackward", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WhereBackward::typeinfo(), &py_type).second);
}
#define INIT_ALL_OP(m) \
_init_py_AdaptivePooling(m); \
_init_py_AddAxis(m); \
......@@ -23614,6 +23860,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_MeshIndexing(m); \
_init_py_MultiHeadAttn(m); \
_init_py_NMSKeep(m); \
_init_py_NonZero(m); \
_init_py_NvOf(m); \
_init_py_Padding(m); \
_init_py_ParamPackConcat(m); \
......@@ -23654,5 +23901,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_WarpAffine(m); \
_init_py_WarpPerspective(m); \
_init_py_WarpPerspectiveBackwardData(m); \
_init_py_WarpPerspectiveBackwardMat(m);
_init_py_WarpPerspectiveBackwardMat(m); \
_init_py_Where(m); \
_init_py_WhereBackward(m);
// clang-format on
......@@ -1457,6 +1457,17 @@ public:
NMSKeep(float iou_thresh_, uint32_t max_output_, std::string scope_ = {}): iou_thresh(iou_thresh_), max_output(max_output_) { set_scope(scope_); }
};
class NonZero : public OpDefImplBase<NonZero> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
NonZero() = default;
NonZero(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
class NvOf : public OpDefImplBase<NvOf> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......@@ -2093,4 +2104,26 @@ public:
}
};
class Where : public OpDefImplBase<Where> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
Where() = default;
Where(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
class WhereBackward : public OpDefImplBase<WhereBackward> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
WhereBackward() = default;
WhereBackward(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
// clang-format on
......@@ -1553,6 +1553,11 @@ NMSKeepInst
.def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
.def_readwrite("max_output", &NMSKeep::max_output);
py::class_<NonZero, std::shared_ptr<NonZero>, OpDef> NonZeroInst(m, "NonZero");
NonZeroInst
.def(py::init<>());
py::class_<NvOf, std::shared_ptr<NvOf>, OpDef> NvOfInst(m, "NvOf");
NvOfInst
......@@ -2128,4 +2133,14 @@ WarpPerspectiveBackwardMatInst
.def_readwrite("format", &WarpPerspectiveBackwardMat::format)
.def_readwrite("border_val", &WarpPerspectiveBackwardMat::border_val);
py::class_<Where, std::shared_ptr<Where>, OpDef> WhereInst(m, "Where");
WhereInst
.def(py::init<>());
py::class_<WhereBackward, std::shared_ptr<WhereBackward>, OpDef> WhereBackwardInst(m, "WhereBackward");
WhereBackwardInst
.def(py::init<>());
// clang-format on
#include "megbrain/comp_node_api.h"
#include <unordered_map>
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain_build_config.h"
namespace mgb {
namespace pubapi {
std::unordered_map<std::string, mgb::CompNode*>& cn_cache() {
static std::unordered_map<std::string, mgb::CompNode*> cn_map;
return cn_map;
}
class CompNodeDepedentObjectInst final : public CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override { return {}; }
public:
bool is_finalized() const { return CompNodeDepedentObject::is_finalized(); }
};
bool is_finalize() {
static CompNodeDepedentObjectInst* obj = new CompNodeDepedentObjectInst;
return obj->is_finalized();
}
void sync(mgbComputeNode_t cn) {
if (!is_finalize()) {
auto* s = reinterpret_cast<mgb::CompNode*>(cn);
if (s->valid())
s->sync();
}
}
mgbComputeNode_t load_cuda_cn(int device_id, int stream) {
std::string loc = ssprintf("gpu%i:%i", device_id, stream);
mgb_assert(!is_finalize());
auto& cache = cn_cache();
if (cache.find(loc) == cache.end()) {
auto* cn = new mgb::CompNode;
(*cn) = mgb::CompNode::load(loc);
mgb_assert(cn->to_string_physical() == loc);
cache[loc] = cn;
cn->activate();
}
return reinterpret_cast<mgbComputeNode_t>(cache[loc]);
}
void unload_cuda_cn(mgbComputeNode_t cn) {
auto* device = reinterpret_cast<mgb::CompNode*>(cn);
auto& cache = cn_cache();
mgb_assert(
cache.find(device->to_string_physical()) != cache.end() &&
device == cache[device->to_string_physical()]);
cache.erase(device->to_string_physical());
delete device;
}
void* alloc(mgbComputeNode_t device, size_t s) {
if (s == 0)
return nullptr;
auto* cn = reinterpret_cast<mgb::CompNode*>(device);
mgb_assert(!is_finalize());
return cn->alloc_device(s);
}
void dealloc(mgbComputeNode_t device, void* addr) {
if (addr != nullptr) {
auto* cn = reinterpret_cast<mgb::CompNode*>(device);
if (!is_finalize()) {
cn->free_device(addr);
}
}
}
void* get_cuda_stream(mgbComputeNode_t device) {
void* rst = nullptr;
#if MGB_CUDA
auto* cn = reinterpret_cast<mgb::CompNode*>(device);
MGB_TRY { rst = CompNodeEnv::from_comp_node(*cn).cuda_env().stream; }
MGB_CATCH(MegBrainError & exc, {
mgb_log_error("failed to get stream: %s", exc.what());
})
#else
mgb_log_error("megbrain compiled without cuda support!");
#endif
return rst;
}
MGB_API DeviceLocator get_physical_location(mgbComputeNode_t device) {
auto location = reinterpret_cast<CompNode*>(device)->locator().to_physical();
return {location.device, location.stream};
}
} // namespace pubapi
} // namespace mgb
#pragma once
#include <cstddef>
#if defined(_WIN32)
#define MGB_API __declspec(dllexport)
#else
#define MGB_API __attribute__((visibility("default")))
#endif
namespace mgb {
namespace pubapi {
typedef struct _MgbComputeNode* mgbComputeNode_t;
struct DeviceLocator {
int device = -1;
int stream = -1;
};
MGB_API mgbComputeNode_t load_cuda_cn(int device_id, int stream_id);
MGB_API void unload_cuda_cn(mgbComputeNode_t);
MGB_API void* alloc(mgbComputeNode_t cn, size_t);
MGB_API void dealloc(mgbComputeNode_t cn, void* addr);
MGB_API void* get_cuda_stream(mgbComputeNode_t cn);
MGB_API DeviceLocator get_physical_location(mgbComputeNode_t);
MGB_API void sync(mgbComputeNode_t cn);
MGB_API bool is_finalize();
} // namespace pubapi
} // namespace mgb
......@@ -153,6 +153,8 @@ def Argmin : MgbHashableOp<"Argmin", [AxisParam]>;
def CondTake : MgbHashableOp<"CondTake">;
def NonZero: MgbHashableOp<"NonZero", [EmptyParam]>;
def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
......@@ -640,4 +642,9 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
def Cross: MgbHashableOp<"Cross", [CrossParam]>;
def Where: MgbHashableOp<"Where", [EmptyParam]>;
def WhereBackward: MgbHashableOp<"WhereBackward", [EmptyParam]>;
#endif // MGB_OPS
......@@ -385,6 +385,71 @@ void CondTake::scn_do_execute() {
}
}
/* ================= NonZero ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NonZero);
NonZero::NonZero(VarNode* data, const Param& param, const OperatorNodeConfig& config)
: Super(data->owner_graph(), config, "NonZero", {data}) {
init_megdnn_opr(*this, param);
add_input({data});
auto dtype = megdnn_opr()->infer_type(data->dtype());
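    // The number of nonzero elements is only known at runtime, so the output
    // is dynamically allocated (NO_SYS_MEM_ALLOC) and may legally be empty.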
output(0)
->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.dtype(dtype);
}
NonZero::NodeProp* NonZero::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
SymbolVar NonZero::make(
SymbolVar data, const Param& param, const OperatorNodeConfig& config) {
auto ret = data.insert_single_output_opr<NonZero>(data.node(), param, config);
return ret;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(NonZero) {
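    // NonZero yields integer indices, so it is not differentiable; no gradient
    // is propagated to the input.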
MGB_MARK_USED_VAR(out_grad);
MGB_MARK_USED_VAR(opr);
mgb_assert(!wrt_idx);
return nullptr;
}
#endif
void NonZero::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_workspace_shape = [this](TensorShape& dest, const InpVal& iv) {
auto dtype = this->input(0)->dtype();
TensorLayout ily(iv.val[0].shape(), dtype);
dest.ndim = 1;
dest.shape[0] = this->megdnn_opr()->get_workspace_in_bytes(ily);
return true;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(1),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace_shape});
}
void NonZero::add_input_layout_constraint() {
mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
void NonZero::scn_do_execute() {
auto&& data = input(0)->dev_tensor();
intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
if (data.layout().is_empty()) {
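        // For an empty input, still allocate an (ndim, 0) Int32 index tensor so
        // downstream consumers see the correct rank.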
dyn_malloc.alloc_output(0, dtype::Int32(), {data.layout().ndim, 0}, nullptr);
} else {
megdnn_opr()->exec(
data.as_megdnn(), intl::get_megdnn_workspace_from_var(output().back()),
&dyn_malloc);
}
}
/* ================= TopK ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TopK);
......@@ -605,4 +670,158 @@ void CheckNonFinite::add_input_layout_constraint() {
i->add_layout_constraint_contiguous();
}
}
/* ================= Where ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WhereForward);
WhereForward::WhereForward(
VarNode* mask, VarNode* data1, VarNode* data2, const Param& param,
const OperatorNodeConfig& config)
: Super(mask->owner_graph(), config, "where", {mask, data1, data2}) {
init_megdnn_opr(*this, param);
mgb_assert(mask->shape().eq_shape(data1->shape()));
mgb_assert(mask->shape().eq_shape(data2->shape()));
mgb_assert(data1->shape().eq_shape(data2->shape()));
add_input({mask, data1, data2});
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE).dtype(data1->dtype());
}
SymbolVar WhereForward::make(
SymbolVar mask, SymbolVar data1, SymbolVar data2, const Param& param,
const OperatorNodeConfig& config) {
return mask.insert_single_output_opr<WhereForward>(
mask.node(), data1.node(), data2.node(), param, config);
}
void WhereForward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
auto ishp = iv.val[0].shape();
dest = ishp;
return true;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
auto mask_dtype = input(0)->dtype();
auto mask_shape = iv.val[0].shape();
TensorLayout mask_layout(mask_shape, mask_dtype);
auto data_dtype = input(1)->dtype();
auto data_shape = iv.val[1].shape();
TensorLayout data_layout(data_shape, data_dtype);
dest.ndim = 1;
dest[0] = megdnn_opr()->get_workspace_in_bytes(
mask_layout, data_layout, data_layout, data_layout);
return true;
};
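    // output(1) is the workspace var added by the megdnn wrapper; its size is
    // queried from WhereForward::get_workspace_in_bytes.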
owner_graph()->static_infer_manager().register_shape_infer(
output(1), {SourceType::DEP,
{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}},
infer_workspace});
}
void WhereForward::add_input_layout_constraint() {
mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
void WhereForward::scn_do_execute() {
if (input(0)->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty());
return;
}
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
}
MAKE_NODE_PROP_WITH_ZERO_SHAPE_3(WhereForward, 0, 1, 2);
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(WhereForward) {
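    // The boolean mask (wrt_idx == 0) receives no gradient; the gradients w.r.t.
    // data1 and data2 are the two outputs of WhereBackward.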
mgb_assert(out_grad.size() == 2 && !out_grad[1]);
VarNodeArray ret;
SymbolVarArray grad;
grad = WhereBackward::make(out_grad[0], opr.input(0), opr.param());
if (wrt_idx == 0)
return nullptr;
else if (wrt_idx == 1)
return grad[0].node();
else if (wrt_idx == 2)
return grad[1].node();
return nullptr;
}
#endif
/* ================= WhereBackward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WhereBackward);
WhereBackward::WhereBackward(
VarNode* diff, VarNode* mask, const Param& param,
const OperatorNodeConfig& config)
: Super(diff->owner_graph(), config, "WhereBackward", {diff, mask}) {
init_megdnn_opr(*this, param);
mgb_assert(diff->shape().eq_shape(mask->shape()));
add_input({diff, mask});
output(0)->dtype(diff->dtype());
output(1)->dtype(diff->dtype());
}
SymbolVarArray WhereBackward::make(
SymbolVar diff, SymbolVar mask, const Param& param,
const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<WhereBackward>(
diff.node(), mask.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs)
ret.emplace_back(out);
return ret;
}
void WhereBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
auto ishp = iv.val[0].shape();
dest = ishp;
return true;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
owner_graph()->static_infer_manager().register_shape_infer(
output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
auto diff_dtype = input(0)->dtype();
auto diff_shape = iv.val[0].shape();
TensorLayout diff_layout(diff_shape, diff_dtype);
auto mask_dtype = input(1)->dtype();
auto mask_shape = iv.val[1].shape();
TensorLayout mask_layout(mask_shape, mask_dtype);
dest.ndim = 1;
dest[0] = megdnn_opr()->get_workspace_in_bytes(
diff_layout, mask_layout, diff_layout, diff_layout);
return true;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(2), {SourceType::DEP,
{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}},
infer_workspace});
}
void WhereBackward::add_input_layout_constraint() {
mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
void WhereBackward::scn_do_execute() {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/misc.h"
#include "megbrain/serialization/sereg.h"
......@@ -31,6 +32,19 @@ struct OprMaker<opr::CondTake, 2> {
}
};
template <>
struct OprMaker<opr::NonZero, 1> {
using Opr = opr::NonZero;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
auto out = Opr::make(inputs[0], param, config);
return out.node()->owner_opr();
}
};
template <>
struct OprMaker<opr::TopK, 2> {
using Opr = opr::TopK;
......@@ -57,8 +71,66 @@ struct OprMaker<opr::CheckNonFinite, 0> {
}
};
template <>
struct OprMaker<opr::Where, 3> {
using Opr = opr::Where;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
auto out = Opr::make(inputs[0], inputs[1], inputs[2], param, config);
return out.node()->owner_opr();
}
};
template <>
struct OprMaker<opr::WhereBackward, 4> {
using Opr = opr::WhereBackward;
using Param = opr::WhereBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
auto out = Opr::make(inputs[0], inputs[1], param, config);
return out[0].node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::Where, 3> {
using Opr = opr::Where;
using Mode = opr::Elemwise::Mode;
using PersisParam = opr::Where::Param;
using PersisWhereParam = opr::Where::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
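        // Compatibility path: rebuild Where from Elemwise ops, assuming
        // SWITCH_GT0(m, v) keeps v where m > 0 and yields 0 elsewhere, so that
        // dst = SWITCH_GT0(mask, x) + SWITCH_GT0(1 - mask, y).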
auto mask = SymbolVar(inputs[0]);
auto x = SymbolVar(inputs[1]);
mask = opr::TypeCvt::make(mask, x.dtype());
auto y = SymbolVar(inputs[2]);
auto oup = opr::Elemwise::make({mask, x}, Mode::SWITCH_GT0);
auto ksam = 1.0f - mask;
oup = oup + opr::Elemwise::make({ksam, y}, Mode::SWITCH_GT0);
return oup.node()->owner_opr();
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
return OprMaker<opr::Where, 3>::make(
ctx.read_param<PersisParam>(), inputs, ctx.graph(), config);
}
};
} // namespace serialization
// OprMaker in MGB_SEREG_OPR only support unique output opr
namespace opr {
MGB_SEREG_OPR(Argmax, 1);
......@@ -66,7 +138,14 @@ MGB_SEREG_OPR(Argmin, 1);
MGB_SEREG_OPR(Argsort, 1);
MGB_SEREG_OPR(ArgsortBackward, 3);
MGB_SEREG_OPR(CondTake, 2);
MGB_SEREG_OPR(NonZero, 1);
MGB_SEREG_OPR(TopK, 2);
MGB_SEREG_OPR_CONDITION(Where, 3, false);
MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0(
Where, 3, (mgb::serialization::OprLoadDumpImplV2<opr::Where, 3>::replace_opr),
VERSION_1, VERSION_1);
MGB_SEREG_OPR(WhereBackward, 4);
//! current cumsum version
using CumsumV1 = opr::Cumsum;
MGB_SEREG_OPR(CumsumV1, 1);
......
......@@ -142,6 +142,9 @@ using TopKBase = cg::SingleCNOperatorNode<
cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::TopK>>;
using CheckNonFiniteBase = cg::SingleCNOperatorNode<
cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CheckNonFinite>>;
using NonZeroBase = cg::SingleCNOperatorNode<
cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::NonZero>>;
} // namespace intl
/*!
......@@ -162,6 +165,18 @@ public:
SymbolVar data, SymbolVar mask, const Param& param,
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(NonZero, intl::NonZeroBase) // {
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void add_input_layout_constraint() override;
NodeProp* do_make_node_prop() const override;
public:
MGE_WIN_DECLSPEC_FUC NonZero(
VarNode* data, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar data, const Param& param, const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(TopK, intl::TopKBase) // {
void init_output_dtype() override;
......@@ -196,6 +211,40 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS(
WhereForward, intl::MegDNNOprWrapperFwd<megdnn::WhereForward>) // {
void scn_do_execute() override;
void init_output_static_infer_desc() override;
void add_input_layout_constraint() override;
NodeProp* do_make_node_prop() const override;
public:
MGE_WIN_DECLSPEC_FUC WhereForward(
VarNode* mask, VarNode* data1, VarNode* data2, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar mask, SymbolVar data1, SymbolVar data2, const Param& param = {},
const OperatorNodeConfig& config = {});
};
using Where = WhereForward;
MGB_DEFINE_OPR_CLASS(
WhereBackward, intl::MegDNNOprWrapperFwd<megdnn::WhereBackward>) // {
void scn_do_execute() override;
void init_output_static_infer_desc() override;
void add_input_layout_constraint() override;
public:
MGE_WIN_DECLSPEC_FUC WhereBackward(
VarNode* diff, VarNode* mask, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar mask, const Param& param = {},
const OperatorNodeConfig& config = {});
};
} // namespace opr
} // namespace mgb
......
......@@ -262,6 +262,25 @@ TEST(TestOprMisc, CondTakeEmptyIO) {
check({1, 0});
}
TEST(TestOprMisc, NonZero) {
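    // Smoke test: only checks that NonZero executes on empty and non-empty
    // inputs; index values are not verified here.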
using Param = opr::NonZero::Param;
Param param;
HostTensorGenerator<> gen;
auto check = [&](const TensorShape& shp) {
auto host_data = gen(shp);
auto graph = ComputingGraph::make();
auto data = opr::Host2DeviceCopy::make(*graph, host_data);
auto out = opr::NonZero::make(data, param);
HostTensorND host_out;
auto func = graph->compile({make_callback_copy(out, host_out)});
func->execute();
};
check({1});
check({0});
check({1, 0});
check({1, 2, 3, 4, 5, 6, 7});
}
TEST(TestOprMisc, TopKValueOnly) {
auto run = [](bool dyn_k, bool non_contig) {
using Checker = AutoOprChecker<1, 1>;
......@@ -414,4 +433,57 @@ TEST(TestOprMisc, TopKGrad) {
EXPECT_TRUE(gk == nullptr);
}
TEST(TestOprMisc, Where) {
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Bool> gen_mask;
auto host_mask = gen_mask({2, 2, 2});
HostTensorGenerator<> gen_data{0, 1000};
auto host_data1 = gen_data({2, 2, 2}), host_data2 = gen_data({2, 2, 2});
auto mask = opr::Host2DeviceCopy::make(*graph, host_mask),
data1 = opr::Host2DeviceCopy::make(*graph, host_data1),
data2 = opr::Host2DeviceCopy::make(*graph, host_data2),
dst = opr::Where::make(mask, data1, data2, {});
HostTensorND host_dst;
auto func = graph->compile({make_callback_copy(dst, host_dst)});
func->execute();
auto pmask = host_mask->ptr<bool>();
auto pdata1 = host_data1->ptr<float>();
auto pdata2 = host_data2->ptr<float>();
auto pdst = host_dst.ptr<float>();
for (size_t i = 0; i < host_mask->layout().total_nr_elems(); ++i) {
ASSERT_EQ(pmask[i] ? pdata1[i] : pdata2[i], pdst[i]);
}
}
TEST(TestOprMisc, WhereBackward) {
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen_diff{0, 1000};
auto host_diff = gen_diff({2, 2, 2});
HostTensorGenerator<dtype::Bool> gen_mask;
auto host_mask = gen_mask({2, 2, 2});
auto diff = opr::Host2DeviceCopy::make(*graph, host_diff),
mask = opr::Host2DeviceCopy::make(*graph, host_mask);
auto grads = opr::WhereBackward::make(diff, mask, {});
auto grad1 = grads[0];
auto grad2 = grads[1];
HostTensorND host_grad1;
HostTensorND host_grad2;
auto func = graph->compile(
{make_callback_copy(grad1, host_grad1),
make_callback_copy(grad2, host_grad2)});
func->execute();
auto pdiff = host_diff->ptr<float>();
auto pmask = host_mask->ptr<bool>();
auto pgrad1 = host_grad1.ptr<float>();
auto pgrad2 = host_grad2.ptr<float>();
for (size_t i = 0; i < host_diff->layout().total_nr_elems(); ++i) {
ASSERT_EQ(pmask[i] ? pdiff[i] : 0, pgrad1[i]);
ASSERT_EQ(pmask[i] ? 0 : pdiff[i], pgrad2[i]);
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}