diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 9e25664c0df6ae3821368ee06873f158d3fa4e91..62a65f083b2558cd52e7bc75746a2ea4dbffbf97 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -410,6 +410,24 @@ protected: size_t workspace_in_bytes); }; +class NonZero : public OperatorBase { + DEF_OPR_IMPL(NonZero, OperatorBase, 1, 1); + DEF_OPR_PARAM(Empty); + +public: + DType infer_type(DType Data); + virtual size_t get_workspace_in_bytes(const TensorLayout& src) = 0; + + virtual TensorND exec( + _megdnn_tensor_in src, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); +}; + class TransposeForward : public OperatorBase { DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1); DEF_OPR_PARAM(Empty); @@ -1535,6 +1553,74 @@ public: }; using Norm = NormForward; +class WhereBase : public OperatorBase { + DEF_OPR_PARAM(Empty); + DEF_OPR_IMPL(WhereBase, OperatorBase, 3, 1); + +protected: + void deduce_layout_fwd( + const TensorLayout& mask, const TensorLayout& data1, + const TensorLayout& data2, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& mask, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& dst); + void deduce_layout_bwd( + const TensorLayout& diff, const TensorLayout& mask, + TensorLayout& grad_data1, TensorLayout& grad_data2); + void check_layout_bwd( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2); +}; + +class WhereForward : public WhereBase { + DEF_OPR_IMPL(WhereForward, WhereBase, 3, 1); + +public: + virtual void exec( + _megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& mask, const TensorLayout& data1, + const TensorLayout& data2, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& mask, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& mask, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& dst, + size_t get_workspace_in_bytes); +}; +using Where = WhereForward; + +class WhereBackward : public WhereBase { + DEF_OPR_IMPL(WhereBackward, WhereBase, 2, 2); + +public: + /** + * \param[in] diff the backpropagated gradient wrt. dst + * \param[in] mask the `mask' parameter in WhereForward::exec + * \param[out] grad1 the backpropagated gradient wrt. 
data1 + */ + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in mask, + _megdnn_tensor_out grad_data1, _megdnn_tensor_out grad_data2, + _megdnn_workspace workspace) = 0; + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& diff, const TensorLayout& mask, + TensorLayout& grad_data1, TensorLayout& grad_data2); + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2, + size_t get_workspace_in_bytes); +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 88bbdbc56b191000d4b9b7fae1b5fb049702213b..84f2966980bcaa39b04c36df96d3f56998ec464a 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -225,7 +225,10 @@ private: cb(MaskedFill) \ cb(MultiHeadAttnForward)\ cb(MultiHeadAttnBackward) \ - cb(Cross) + cb(Cross) \ + cb(WhereForward) \ + cb(WhereBackward) \ + cb(NonZero) // clang-format on /*! diff --git a/dnn/src/common/non_zero.cpp b/dnn/src/common/non_zero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b43f4ebd4c435aaa32f06b4867f75b81c52edaeb --- /dev/null +++ b/dnn/src/common/non_zero.cpp @@ -0,0 +1,18 @@ +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { +void NonZero::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { + dst.dtype.assert_is(infer_type(src.dtype)); + + if (!src.is_empty()) + megdnn_assert(src.is_physical_contiguous()); + auto require_workspace_in_bytes = get_workspace_in_bytes(src); + megdnn_assert(workspace_in_bytes >= require_workspace_in_bytes); +} + +DType NonZero::infer_type(DType /*input*/) { + return dtype::Int32(); +} +}; // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index b4fd26b65d08ffac455b09968245764fa8101a7e..458a73446107385492e3ce45d09253f6bfb4535a 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -17,6 +17,9 @@ struct OprTrait {}; static const bool can_deduce_layout = CanDeduceLayout; \ } +DEF(NonZero, 2, true, false); +DEF(Where, 4, true, true); +DEF(WhereBackward, 4, true, true); DEF(Norm, 2, true, true); DEF(Padding, 2, false, true); DEF(PaddingBackward, 2, false, false); diff --git a/dnn/src/common/where.cpp b/dnn/src/common/where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f69e6ccb4b9ffe93431207fd2010da15fb2a165d --- /dev/null +++ b/dnn/src/common/where.cpp @@ -0,0 +1,124 @@ +#include "megdnn/oprs.h" + +#include "src/common/utils.h" + +namespace megdnn { + +void WhereBase::deduce_layout_fwd( + const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2, + TensorLayout& dst) { + if (!mask.is_empty()) + megdnn_assert(mask.is_physical_contiguous()); + if (!data1.is_empty()) + megdnn_assert(data1.is_physical_contiguous()); + if (!data2.is_empty()) + megdnn_assert(data2.is_physical_contiguous()); + if (!dst.is_empty()) + megdnn_assert(dst.is_physical_contiguous()); + + auto errmsg = [&]() { + return megdnn_layout_msg(mask) + ", " + megdnn_layout_msg(data1) + ", " + + megdnn_layout_msg(data2) + ", " + megdnn_layout_msg(dst); + }; + auto mask_dtype = mask.dtype, data1_dtype = data1.dtype, data2_dtype = data2.dtype; + 
megdnn_assert(mask_dtype.category() == DTypeCategory::BOOL); + megdnn_assert( + data1_dtype == data2_dtype && + (data1_dtype.category() == DTypeCategory::INT || + data1_dtype.category() == DTypeCategory::FLOAT || + data1_dtype.category() == DTypeCategory::BOOL)); + megdnn_assert(data1.ndim == data2.ndim, "%s", errmsg().c_str()); + megdnn_assert(data1.ndim == mask.ndim, "%s", errmsg().c_str()); + + dst = TensorLayout{data1}; +} + +void WhereBase::check_layout_fwd( + const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst) { + TensorLayout dst_expected; + megdnn_assert_eq_shape(mask, data1); + megdnn_assert_eq_dtype(data1, dst); + megdnn_assert_eq_shape(data1, data2); + deduce_layout_fwd(mask, data1, data2, dst_expected); + megdnn_assert_eq_shape(dst_expected, dst); +} + +void WhereForward::deduce_layout( + const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2, + TensorLayout& dst) { + deduce_layout_fwd(mask, data1, data2, dst); +} + +void WhereForward::check_exec( + const TensorLayout& mask, const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst, size_t workspace_in_bytes) { + check_layout_fwd(mask, data1, data2, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(mask, data1, data2, dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void WhereBase::deduce_layout_bwd( + const TensorLayout& diff, const TensorLayout& mask, TensorLayout& grad_data1, + TensorLayout& grad_data2) { + if (!diff.is_empty()) + megdnn_assert(diff.is_physical_contiguous()); + if (!mask.is_empty()) + megdnn_assert(mask.is_physical_contiguous()); + if (!grad_data1.is_empty()) + megdnn_assert(grad_data1.is_physical_contiguous()); + if (!grad_data2.is_empty()) + megdnn_assert(grad_data2.is_physical_contiguous()); + + auto errmsg = [&]() { + return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(mask) + ", " + + megdnn_layout_msg(grad_data1) + megdnn_layout_msg(grad_data2); + }; + auto diff_dtype = diff.dtype, mask_dtype = mask.dtype; + megdnn_assert(mask_dtype.category() == DTypeCategory::BOOL); + megdnn_assert( + diff_dtype.category() == DTypeCategory::INT || + diff_dtype.category() == DTypeCategory::FLOAT); + megdnn_assert(diff.ndim == mask.ndim, "%s", errmsg().c_str()); + + grad_data1 = TensorLayout{diff}; + grad_data2 = TensorLayout{diff}; +} + +void WhereBase::check_layout_bwd( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2) { + TensorLayout grad_expected1; + TensorLayout grad_expected2; + + megdnn_assert_eq_shape(diff, mask); + megdnn_assert_eq_shape(diff, grad_data1); + megdnn_assert_eq_dtype(diff, grad_data1); + megdnn_assert_eq_shape(diff, grad_data2); + megdnn_assert_eq_dtype(diff, grad_data2); + + deduce_layout_bwd(diff, mask, grad_expected1, grad_expected2); + + megdnn_assert_eq_shape(grad_expected1, grad_data1); + megdnn_assert_eq_dtype(grad_expected1, grad_data1); + megdnn_assert_eq_shape(grad_expected2, grad_data2); + megdnn_assert_eq_dtype(grad_expected2, grad_data2); +} + +void WhereBackward::deduce_layout( + const TensorLayout& diff, const TensorLayout& mask, TensorLayout& grad_data1, + TensorLayout& grad_data2) { + deduce_layout_bwd(diff, mask, grad_data1, grad_data2); +} + +void WhereBackward::check_exec( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2, + size_t workspace_in_bytes) { + check_layout_bwd(diff, mask, 
grad_data1, grad_data2); + auto required_workspace_in_bytes = + get_workspace_in_bytes(diff, mask, grad_data1, grad_data2); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index cf1c79b75a5d29842f39d48ff69f4173ef0091b2..5b0f9297b7297feb6e05d4a5ac8128f5bc270b0f 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -52,6 +52,7 @@ #include "src/cuda/max_tensor_diff/opr_impl.h" #include "src/cuda/mesh_indexing/opr_impl.h" #include "src/cuda/multi_head_attn/opr_impl.h" +#include "src/cuda/non_zero/opr_impl.h" #include "src/cuda/norm/opr_impl.h" #include "src/cuda/padding/opr_impl.h" #include "src/cuda/param_pack/opr_impl.h" @@ -84,6 +85,7 @@ #include "src/cuda/type_cvt/opr_impl.h" #include "src/cuda/warp_affine/opr_impl.h" #include "src/cuda/warp_perspective/opr_impl.h" +#include "src/cuda/where/opr_impl.h" namespace megdnn { namespace cuda { @@ -122,6 +124,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(SVDForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(CondTake); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(NonZero); MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumsumForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); @@ -236,6 +239,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Cross); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(WhereForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(WhereBackward); template std::unique_ptr HandleImpl::create_operator() { diff --git a/dnn/src/cuda/non_zero/kernel.cu b/dnn/src/cuda/non_zero/kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9f099fa8a399de2f69ff423bf452d20618b908f7 --- /dev/null +++ b/dnn/src/cuda/non_zero/kernel.cu @@ -0,0 +1,79 @@ +#include "./kernel.cuh" +#include "src/cuda/query_blocksize.cuh" +#include "src/cuda/utils.cuh" + +using namespace megdnn; +using namespace megdnn::cuda; +using namespace megdnn::cuda::non_zero; + +namespace { + +__global__ void multi_div_size( + dt_int32* div_size, const size_t* shape_arr, size_t ndim) { + div_size[0] = 1; + for (int div_index = 1; div_index < ndim; div_index++) { + div_size[div_index] = shape_arr[ndim - div_index] * div_size[div_index - 1]; + } +} +__global__ void expansion( + dt_int32* pt, dt_int32* div_size, const size_t* shape_arr, int loop_count, + int index_size, int ndim) { + int dim_idx = blockIdx.x; + int offset_from_each_dim = (blockDim.x * threadIdx.y + threadIdx.x) + + (loop_count * (blockDim.x * blockDim.y)); + if (offset_from_each_dim >= index_size) + return; + int offset = dim_idx * (index_size) + offset_from_each_dim; + dt_int32* target_pt = pt + offset; + int dim_pos_of_ele = *target_pt / div_size[ndim - 1 - dim_idx]; + dt_int32 dim_index_of_ele = dim_pos_of_ele % shape_arr[dim_idx]; + target_pt[0] = dim_index_of_ele; +} + +__global__ void copy_kern(dt_int32* dest_idx, const dt_int32* src_idx, uint32_t size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < size && src_idx[tid] > src_idx[tid - 1]) { + uint32_t v = src_idx[tid] - 1; + dest_idx[v] = tid; + } +} +__global__ void set_zero(dt_int32* dest) { + dest[0] = 0; +} +} // namespace + +void megdnn::cuda::non_zero::copy_idx( + 
dt_int32* dest_idx, dt_int32* src_idx, uint32_t size, cudaStream_t stream) { + int nr_thread = query_blocksize_for_kernel(copy_kern); + int nr_block = DIVUP(size, nr_thread); + // Todo Set block and thread to 1 to set an element to 0. Need to consider + // optimization + set_zero<<<1, 1, 0, stream>>>(src_idx); + after_kernel_launch(); + copy_kern<<>>(dest_idx, src_idx + 1, size); + after_kernel_launch(); +} + +void megdnn::cuda::non_zero::expansion_index( + dt_int32* dst_pt, size_t index_size, const size_t* src_shape, + size_t* src_shape_workspace_pt, size_t src_ndim, dt_int32* div_workspace_pt, + cudaStream_t stream) { + cuda_check(cudaMemcpyAsync( + src_shape_workspace_pt, src_shape, sizeof(size_t) * 7, + cudaMemcpyHostToDevice, stream)); + // Todo change the cuda kernel to cpu for loop, or make the number of threads and + // blocks more reasonable + multi_div_size<<<1, 1, 0, stream>>>( + div_workspace_pt, src_shape_workspace_pt, src_ndim); + after_kernel_launch(); + dim3 threadsPerBlock( + std::min(NR_THREADS_X, index_size), + std::min(NR_THREADS_Y, DIVUP(index_size, NR_THREADS_X))); + int loop_size = DIVUP(index_size, (NR_THREADS_X * NR_THREADS_Y)); + for (int loop_idx = 0; loop_idx < loop_size; loop_idx++) { + expansion<<>>( + dst_pt, div_workspace_pt, src_shape_workspace_pt, loop_idx, index_size, + src_ndim); + after_kernel_launch(); + } +} \ No newline at end of file diff --git a/dnn/src/cuda/non_zero/kernel.cuh b/dnn/src/cuda/non_zero/kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..54147dc6f7d978590ffbc9bdfcd0862fb26b221c --- /dev/null +++ b/dnn/src/cuda/non_zero/kernel.cuh @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { +namespace non_zero { +void expansion_index( + dt_int32* dst_pt, size_t index_size, const size_t* src_shape, + size_t* src_shape_workspace_pt, size_t src_ndim, dt_int32* div_workspace_pt, + cudaStream_t stream); + +void copy_idx( + dt_int32* dest_idx, dt_int32* src_idx, uint32_t size, cudaStream_t stream); +} // namespace non_zero +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/non_zero/opr_impl.cpp b/dnn/src/cuda/non_zero/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..facfd6be14f849084ecd675cd2856a697735f7be --- /dev/null +++ b/dnn/src/cuda/non_zero/opr_impl.cpp @@ -0,0 +1,100 @@ + +#include "megdnn/dtype.h" + +#include "../cond_take/opr_impl.h" +#include "./kernel.cuh" +#include "./opr_impl.h" +#include "src/common/utils.h" +#include "src/cuda/cond_take/kern.cuh" +#include "src/cuda/handle.h" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace megdnn::cuda; +using namespace megdnn::cuda::non_zero; + +WorkspaceBundle NonZeroImpl::make_bundle(const TensorLayout& data) { + size_t nr_item = data.total_nr_elems(); + cuda_check(cudaSetDevice(concrete_handle(handle())->device_id())); + auto gen_idx_wk_size = cuda::cond_take::gen_idx_get_workspace_size(nr_item); + SmallVector sizes_in_bytes; + sizes_in_bytes.push_back((nr_item + 1) * sizeof(megdnn::cuda::cond_take::IdxType)); + sizes_in_bytes.push_back(gen_idx_wk_size); + // the two ele is the shape of arr and the reverse multiply arr of the shape + sizes_in_bytes.push_back(sizeof(TensorLayout::shape)); + sizes_in_bytes.push_back(sizeof(TensorLayout::shape)); + return {nullptr, sizes_in_bytes, handle()->alignment_requirement()}; +} + +size_t NonZeroImpl::get_workspace_in_bytes(const 
TensorLayout& data) { + return make_bundle(data).total_size_in_bytes(); +} + +TensorND NonZeroImpl::exec( + _megdnn_tensor_in data, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy) { + size_t size = data.layout.total_nr_elems(); + if (size == 0) { + TensorShape target_shape({data.layout.ndim, 0}); + TensorND rst = malloc_policy.alloc_output(0, dtype::Int32(), target_shape); + return rst; + } + auto wk_bundle = make_bundle(data.layout); + wk_bundle.set(workspace.raw_ptr); + + auto idx_tmp = static_cast(wk_bundle.get(0)); + + CondTake::Param param; + param.mode = CondTake::Param::Mode::NEQ; + param.val = 0; + param.eps = 1e-06; + megdnn::cond_take::KParam kparam(param); + + auto stream = cuda_stream(handle()); + size_t out_size; + switch (data.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + out_size = megdnn::cuda::cond_take::gen_idx( \ + wk_bundle.get(1), wk_bundle.get_size(1), idx_tmp, data.ptr(), \ + size, static_cast(param.mode), kparam, stream); \ + break; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + default : { + std::string data_type = data.layout.dtype.name(); + megdnn_throw( + "bad mask dtype,support_types is support types: [Float32, Float16, " + "BFloat16, Int32, Int16, Int8, Uint8, Bool]" + + std::string("but the data type is ") + data_type); + } + } + + TensorShape dst_shape({data.layout.ndim, out_size}); + TensorND out_idx = malloc_policy.alloc_output(0, dtype::Int32(), dst_shape); + dt_int32* out_idx_ptr = out_idx.ptr(); + + switch (data.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + for (size_t idx = 0; idx < data.layout.ndim; idx++) { \ + dt_int32* copy_idx_ptr = out_idx_ptr + idx * out_size; \ + copy_idx(copy_idx_ptr, idx_tmp, size, stream); \ + } \ + break; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + default : megdnn_throw("bad data dtype"); + } + expansion_index( + out_idx_ptr, out_size, data.layout.shape, + static_cast(wk_bundle.get(2)), data.layout.ndim, + static_cast(wk_bundle.get(3)), stream); + + return out_idx; +} \ No newline at end of file diff --git a/dnn/src/cuda/non_zero/opr_impl.h b/dnn/src/cuda/non_zero/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..6c74439d44a31c77e5ba8f57461dd5df6bf461a7 --- /dev/null +++ b/dnn/src/cuda/non_zero/opr_impl.h @@ -0,0 +1,20 @@ +#pragma once + +#include "megdnn/oprs/general.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace cuda { +class NonZeroImpl final : public NonZero { + WorkspaceBundle make_bundle(const TensorLayout& data); + +public: + using NonZero::NonZero; + virtual TensorND exec( + _megdnn_tensor_in src, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy); + size_t get_workspace_in_bytes(const TensorLayout& data) override; +}; + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/where/common.cuh b/dnn/src/cuda/where/common.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4eaaa6f001a264259e870873558f21729c441410 --- /dev/null +++ b/dnn/src/cuda/where/common.cuh @@ -0,0 +1,26 @@ +#pragma once +#include +#include + +namespace megdnn { +namespace cuda { +namespace where { + +template +void forward_proxy( + const bool* __restrict mask, const T* __restrict data1, + const T* __restrict data2, T* __restrict dst, size_t n, cudaStream_t stream); + +} // namespace where + +namespace where_backward { + +template +void 
backward_proxy( + const T* __restrict diff, const bool* mask, T* __restrict grad_data1, + T* __restrict grad_data2, size_t n, cudaStream_t stream); + +} // namespace where_backward + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/where/opr_impl.h b/dnn/src/cuda/where/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..0231d303bb2fbf3fa40a4c2a917449f747256172 --- /dev/null +++ b/dnn/src/cuda/where/opr_impl.h @@ -0,0 +1,35 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class WhereForwardImpl : public WhereForward { +public: + using WhereForward::WhereForward; + void exec( + _megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& mask, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& dst) override { + return 0; + } +}; + +class WhereBackwardImpl : public WhereBackward { +public: + using WhereBackward::WhereBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in mask, + _megdnn_tensor_out grad_data1, _megdnn_tensor_out grad_data2, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2) override { + return 0; + } +}; + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/where/where_backward.cpp b/dnn/src/cuda/where/where_backward.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4dc0f00fd7a15673081641ba79ddd2555f59f142 --- /dev/null +++ b/dnn/src/cuda/where/where_backward.cpp @@ -0,0 +1,30 @@ +#include "src/cuda/where/common.cuh" +#include "src/cuda/where/opr_impl.h" + +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void WhereBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in mask, _megdnn_tensor_out grad_data1, + _megdnn_tensor_out grad_data2, _megdnn_workspace workspace) { + check_exec( + diff.layout, mask.layout, grad_data1.layout, grad_data2.layout, + workspace.size); + auto stream = cuda_stream(this->handle()); + auto n = diff.layout.total_nr_elems(); +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + where_backward::backward_proxy( \ + diff.ptr(), mask.ptr(), grad_data1.ptr(), \ + grad_data2.ptr(), n, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb +} + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/where/where_backward.cu b/dnn/src/cuda/where/where_backward.cu new file mode 100644 index 0000000000000000000000000000000000000000..aad4446b8fd253baf068e63792b11f9275e44827 --- /dev/null +++ b/dnn/src/cuda/where/where_backward.cu @@ -0,0 +1,45 @@ +#include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" +#include "src/cuda/where/common.cuh" + +namespace { + +template +__global__ void backward_kernel( + const T* __restrict diff, const bool* __restrict mask, T* __restrict grad_data1, + T* __restrict grad_data2, size_t n) { + size_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < n) { + grad_data1[i] = mask[i] ? diff[i] : 0; + grad_data2[i] = mask[i] ? 
0 : diff[i]; + } +} +} // anonymous namespace + +namespace megdnn { +namespace cuda { +namespace where_backward { + +template +void backward_proxy( + const T* __restrict diff, const dt_bool* __restrict mask, + T* __restrict grad_data1, T* __restrict grad_data2, size_t n, + cudaStream_t stream) { + if (n == 0) + return; + backward_kernel<<>>( + diff, mask, grad_data1, grad_data2, n); + after_kernel_launch(); +} + +#define INST(T) \ + template void backward_proxy( \ + const T* __restrict, const dt_bool* __restrict, T* __restrict, \ + T* __restrict, size_t, cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +cb(::megdnn::dtype::Bool) + +} // namespace where_backward +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/where/where_forward.cpp b/dnn/src/cuda/where/where_forward.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1715e616d1bb2b43a7db29f1844ebf0dc1e03b4f --- /dev/null +++ b/dnn/src/cuda/where/where_forward.cpp @@ -0,0 +1,28 @@ +#include "src/cuda/where/common.cuh" +#include "src/cuda/where/opr_impl.h" + +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void WhereForwardImpl::exec( + _megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(mask.layout, data1.layout, data2.layout, dst.layout, workspace.size); + auto stream = cuda_stream(this->handle()); + auto n = data1.layout.total_nr_elems(); +#define cb(DType) \ + if (data1.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + where::forward_proxy( \ + mask.ptr(), data1.ptr(), data2.ptr(), \ + dst.ptr(), n, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb +} + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/where/where_forward.cu b/dnn/src/cuda/where/where_forward.cu new file mode 100644 index 0000000000000000000000000000000000000000..99e5528348426342bc6ae2a3615aae122bbd6d73 --- /dev/null +++ b/dnn/src/cuda/where/where_forward.cu @@ -0,0 +1,42 @@ +#include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" +#include "src/cuda/where/common.cuh" + +namespace { + +template +__global__ void forward_kernel( + const bool* __restrict mask, const T* __restrict data1, + const T* __restrict data2, T* __restrict dst, size_t n) { + size_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < n) { + dst[i] = mask[i] ? 
data1[i] : data2[i]; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace cuda { +namespace where { + +template +void forward_proxy( + const dt_bool* __restrict mask, const T* __restrict data1, + const T* __restrict data2, T* __restrict dst, size_t n, cudaStream_t stream) { + forward_kernel<<>>( + mask, data1, data2, dst, n); + after_kernel_launch(); +} + +#define INST(T) \ + template void forward_proxy( \ + const dt_bool* __restrict, const T* __restrict, const T* __restrict, \ + T* __restrict, size_t, cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +cb(::megdnn::dtype::Bool) + +} // namespace where +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 7b4141895d95b0cfc3e0169fbd27d5d7f8353a02..21e3158ceb6a6ad1033d10c0840eb23b7c526d5d 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -56,6 +56,7 @@ #include "src/naive/max_tensor_diff/opr_impl.h" #include "src/naive/mesh_indexing/opr_impl.h" #include "src/naive/multi_head_attn/opr_impl.h" +#include "src/naive/non_zero/opr_impl.h" #include "src/naive/norm/opr_impl.h" #include "src/naive/padding/opr_impl.h" #include "src/naive/param_pack/opr_impl.h" @@ -90,6 +91,7 @@ #include "src/naive/type_cvt/opr_impl.h" #include "src/naive/warp_affine/opr_impl.h" #include "src/naive/warp_perspective/opr_impl.h" +#include "src/naive/where/opr_impl.h" namespace megdnn { namespace naive { diff --git a/dnn/src/naive/non_zero/opr_impl.cpp b/dnn/src/naive/non_zero/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..17beb4f4598a82cbf28863ec95ce9b2de5093f0d --- /dev/null +++ b/dnn/src/naive/non_zero/opr_impl.cpp @@ -0,0 +1,95 @@ +#include "./opr_impl.h" +#include "src/common/cond_take/predicate.cuh" +#include "src/common/utils.h" +#include "src/naive/handle.h" +#include "src/naive/non_zero/opr_impl.h" + +using namespace megdnn; +using namespace naive; + +using Param = NonZero::Param; +size_t NonZeroImpl::get_workspace_in_bytes(const TensorLayout& data) { + // save the size of index array in the last element of workspace + return (data.total_nr_elems() + 1) * sizeof(dt_int32); +} +template +void gen_index(dt_int32* dest, const TensorND& src, cond_take::Pred pred) { + int idx = 0; + ctype* inp = src.ptr(); + size_t number_of_data = src.layout.total_nr_elems(); + for (size_t data_pos = 0; data_pos < number_of_data; ++data_pos) { + if (pred(inp[data_pos])) { + dest[idx++] = data_pos; + } + } + // last element is the size of index array + dest[number_of_data] = idx; +} + +void expansion_index( + const dt_int32* const index_arr, const size_t index_size, const TensorND* rst, + const size_t* shape_arr, const int ndim) { + SmallVector shape_reverse_multiply_reduce_arr({1}); + for (int div_index = 1; div_index < ndim; div_index++) { + shape_reverse_multiply_reduce_arr[div_index] = + shape_arr[ndim - div_index] * + shape_reverse_multiply_reduce_arr[div_index - 1]; + } + + for (int dim_pos = 0; dim_pos < ndim; dim_pos++) { + dt_int32* dim_pt = rst->ptr() + index_size * dim_pos; + for (size_t ele_pos = 0; ele_pos < index_size; ele_pos++) { + int dim_pos_of_ele = index_arr[ele_pos] / + shape_reverse_multiply_reduce_arr[ndim - 1 - dim_pos]; + int dim_index_of_ele = dim_pos_of_ele % shape_arr[dim_pos]; + dim_pt[ele_pos] = dim_index_of_ele; + } + } +} + +TensorND NonZeroImpl::exec( + _megdnn_tensor_in src, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy) { +#if 
!MGE_BUILD_WITHOUT_NAIVE_EXEC + auto idx_tmp = workspace.ptr(); + + switch (src.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + using namespace cond_take; \ + KParam param({}); \ + param.val = 0.0; \ + param.eps = 1e-6; \ + Pred pred(param); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(gen_index(idx_tmp, src, pred)); \ + break; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + default : { + std::string data_type = src.layout.dtype.name(); + megdnn_throw( + "bad mask dtype,support_types is support types: [Float32, Float16, " + "BFloat16, Int32, Int16, Int8, Uint8, Bool]" + + std::string("but the data type is ") + data_type); + } + } + + static_cast(handle())->megcore_dispatcher()->sync(); + size_t index_size_pos = src.layout.total_nr_elems(); + size_t index_size = idx_tmp[index_size_pos]; + TensorND ret; + size_t ndim = src.layout.ndim; + TensorShape dst_shape({ndim, index_size}); + ret = malloc_policy.alloc_output(0, dtype::Int32(), {ndim, index_size}); + MEGDNN_DISPATCH_CPU_KERN_OPR( + expansion_index(idx_tmp, index_size, &ret, src.layout.shape, ndim)); + + return ret; +#else + __builtin_trap(); + return {}; +#endif +} diff --git a/dnn/src/naive/non_zero/opr_impl.h b/dnn/src/naive/non_zero/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..a743adf97b43ecc3fa624fbe1f9a3e76096e6b93 --- /dev/null +++ b/dnn/src/naive/non_zero/opr_impl.h @@ -0,0 +1,15 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { +class NonZeroImpl : public NonZero { +public: + using NonZero::NonZero; + TensorND exec( + _megdnn_tensor_in src, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy); + size_t get_workspace_in_bytes(const TensorLayout& src); +}; +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/where/opr_impl.cpp b/dnn/src/naive/where/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1677a2475f248041fc635e1ba77a80fca795317f --- /dev/null +++ b/dnn/src/naive/where/opr_impl.cpp @@ -0,0 +1,67 @@ +#include "src/naive/where/opr_impl.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +namespace megdnn { +namespace naive { + +template +void WhereForwardImpl::exec_internal( + const dt_bool* __restrict mask, const T* __restrict data1, + const T* __restrict data2, T* __restrict dst, size_t n) { + rep(i, n) { dst[i] = mask[i] ? data1[i] : data2[i]; } +} + +void WhereForwardImpl::exec( + _megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(mask.layout, data1.layout, data2.layout, dst.layout, workspace.size); + auto n = data1.layout.total_nr_elems(); +#define cb(DType) \ + if (data1.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal( \ + mask.ptr(), data1.ptr(), data2.ptr(), \ + dst.ptr(), n)); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + megdnn_assert_internal(0); +} + +template +void WhereBackwardImpl::exec_internal( + const T* __restrict diff, const dt_bool* __restrict mask, + T* __restrict grad_data1, T* __restrict grad_data2, size_t n) { + rep(i, n) { + grad_data1[i] = mask[i] ? diff[i] : 0; + grad_data2[i] = mask[i] ? 
0 : diff[i]; + } +} + +void WhereBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in mask, _megdnn_tensor_out grad_data1, + _megdnn_tensor_out grad_data2, _megdnn_workspace workspace) { + check_exec( + diff.layout, mask.layout, grad_data1.layout, grad_data2.layout, + workspace.size); + auto n = diff.layout.total_nr_elems(); +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal( \ + diff.ptr(), mask.ptr(), grad_data1.ptr(), \ + grad_data2.ptr(), n)); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + megdnn_assert_internal(0); +} + +} // namespace naive +} // namespace megdnn diff --git a/dnn/src/naive/where/opr_impl.h b/dnn/src/naive/where/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..31d357cc440dd578f22c88ff3ef86802dacce011 --- /dev/null +++ b/dnn/src/naive/where/opr_impl.h @@ -0,0 +1,52 @@ +#pragma once +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace naive { + +class WhereForwardImpl : public WhereForward { +public: + using WhereForward::WhereForward; + void exec( + _megdnn_tensor_in mask, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } + +private: + template + void exec_internal( + const dt_bool* __restrict mask, const T* __restrict data1, + const T* __restrict data2, T* __restrict dst, size_t n); +}; + +class WhereBackwardImpl : public WhereBackward { +public: + using WhereBackward::WhereBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in mask, + _megdnn_tensor_out grad_data1, _megdnn_tensor_out grad_data2, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& mask, + const TensorLayout& grad_data1, const TensorLayout& grad_data2) override { + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(mask); + MEGDNN_MARK_USED_VAR(grad_data1); + MEGDNN_MARK_USED_VAR(grad_data2); + return 0; + } + +private: + template + void exec_internal( + const T* __restrict diff, const dt_bool* __restrict mask, + T* __restrict grad_data1, T* __restrict grad_data2, size_t n); +}; + +} // namespace naive +} // namespace megdnn diff --git a/dnn/test/common/non_zero.cpp b/dnn/test/common/non_zero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fcea793729eddbea9d0f3d260171bf38fd6f2f7d --- /dev/null +++ b/dnn/test/common/non_zero.cpp @@ -0,0 +1,105 @@ +#include "./non_zero.h" +#include "./rng.h" +#include "./tensor.h" +#include "./utils.h" +#include "test/common/checker.h" + +using namespace megdnn; +using namespace test; + +using Param = NonZero::Param; + +std::vector NonZeroTestcase::make() { + Param param; + std::vector ret; + + TensorShape shape1{4}; + ret.push_back({param, TensorLayout{shape1, dtype::Int8()}}); + NonZeroTestcase& case1 = ret.back(); + case1.m_mem.reset(new uint8_t[case1.m_data.layout.span().dist_byte()]); + memset(case1.m_mem.get(), 0, case1.m_data.layout.span().dist_byte()); + case1.m_data.reset_ptr(case1.m_mem.get()); + dt_int8* pt_1 = reinterpret_cast(case1.m_mem.get()); + pt_1[3] = 1; + case1.correct_answer.push_back(3); + + TensorShape shape2{1, 1, 1, 1, 1, 1, 1}; + ret.push_back({param, TensorLayout{shape2, dtype::Float32()}}); + 
NonZeroTestcase& case2 = ret.back(); + case2.m_mem.reset(new uint8_t[case2.m_data.layout.span().dist_byte()]); + memset(case2.m_mem.get(), 0, case2.m_data.layout.span().dist_byte()); + case2.m_data.reset_ptr(case2.m_mem.get()); + dt_float32* pt_2 = reinterpret_cast(case2.m_mem.get()); + pt_2[0] = 1.0; + case2.correct_answer = {0, 0, 0, 0, 0, 0, 0}; + + TensorShape shape3{0}; + ret.push_back({param, TensorLayout{shape3, dtype::Float32()}}); + NonZeroTestcase& case3 = ret.back(); + case3.m_mem.reset(new uint8_t[case3.m_data.layout.span().dist_byte()]); + memset(case3.m_mem.get(), 0, case3.m_data.layout.span().dist_byte()); + case3.m_data.reset_ptr(case3.m_mem.get()); + case3.correct_answer = {}; + + TensorShape shape4{1, 2, 3, 4, 5, 6, 7}; + ret.push_back({param, TensorLayout{shape4, dtype::Float32()}}); + NonZeroTestcase& case4 = ret.back(); + case4.m_mem.reset(new uint8_t[case4.m_data.layout.span().dist_byte()]); + memset(case4.m_mem.get(), 0, case4.m_data.layout.span().dist_byte()); + case4.m_data.reset_ptr(case4.m_mem.get()); + dt_float32* pt_4 = reinterpret_cast(case4.m_mem.get()); + pt_4[shape4.total_nr_elems() - 1] = 1.0; + case4.correct_answer = {0, 1, 2, 3, 4, 5, 6}; + + TensorShape shape5{2, 2, 2, 2, 2, 2, 2}; + ret.push_back({param, TensorLayout{shape5, dtype::Float32()}}); + NonZeroTestcase& case5 = ret.back(); + case5.m_mem.reset(new uint8_t[case5.m_data.layout.span().dist_byte()]); + memset(case5.m_mem.get(), 0, case5.m_data.layout.span().dist_byte()); + case5.m_data.reset_ptr(case5.m_mem.get()); + dt_float32* pt_5 = reinterpret_cast(case5.m_mem.get()); + pt_5[63] = 1.0; + case5.correct_answer = { + 0, 1, 1, 1, 1, 1, 1, + }; + + return ret; +} + +NonZeroTestcase::Result NonZeroTestcase::run_naive(NonZero* opr) { + auto handle = opr->handle(); + DynOutMallocPolicyImpl malloc_policy(handle); + opr->param() = m_param; + + auto workspace_size = opr->get_workspace_in_bytes(m_data.layout); + auto workspace_ptr = malloc_policy.alloc_workspace(workspace_size, nullptr); + auto result = opr->exec( + m_data, {(dt_byte*)workspace_ptr, workspace_size}, &malloc_policy); + malloc_policy.free_workspace(workspace_ptr, nullptr); + + return result; +} +NonZeroTestcase::CUDAResult NonZeroTestcase::run_cuda(NonZero* opr) { + auto handle = opr->handle(); + DynOutMallocPolicyImpl malloc_policy(handle); + opr->param() = m_param; + auto data = make_tensor_h2d(handle, m_data); + + auto workspace_size = opr->get_workspace_in_bytes(m_data.layout); + auto workspace_ptr = malloc_policy.alloc_workspace(workspace_size, nullptr); + auto result = + opr->exec(*data, {(dt_byte*)workspace_ptr, workspace_size}, &malloc_policy); + malloc_policy.free_workspace(workspace_ptr, nullptr); + + return {make_tensor_d2h(handle, result)}; +} + +void NonZeroTestcase::Assert( + std::vector& correct_answer, int ndim, NonZeroTestcase::Result result) { + dt_int32* data_pt = result.ptr(); + ASSERT_EQ(result.layout.total_nr_elems(), correct_answer.size()); + ASSERT_EQ(ndim, result.layout.shape[0]); + for (size_t ele_idx = 0; ele_idx < result.layout.total_nr_elems(); ele_idx++) { + ASSERT_EQ(data_pt[ele_idx], correct_answer[ele_idx]); + } +} \ No newline at end of file diff --git a/dnn/test/common/non_zero.h b/dnn/test/common/non_zero.h new file mode 100644 index 0000000000000000000000000000000000000000..c25192f8b14f1f32e12ddb14a4f9f4f380356c27 --- /dev/null +++ b/dnn/test/common/non_zero.h @@ -0,0 +1,27 @@ +#pragma once + +#include "./checker.h" +#include "megdnn/oprs.h" + +namespace megdnn { +namespace test { +class NonZeroTestcase 
{ +public: + std::unique_ptr m_mem; + NonZero::Param m_param; + TensorND m_data; + std::vector correct_answer; + + NonZeroTestcase(NonZero::Param param, const TensorLayout& data) + : m_param(param), m_data(nullptr, data) {} + using Result = TensorND; + using CUDAResult = std::shared_ptr; + Result run_naive(NonZero* opr); + CUDAResult run_cuda(NonZero* opr); + static std::vector make(); + static void Assert( + std::vector& correct_answer, int ndim, NonZeroTestcase::Result result); +}; + +} // namespace test +} // namespace megdnn \ No newline at end of file diff --git a/dnn/test/cuda/non_zero.cpp b/dnn/test/cuda/non_zero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8aeb3fdd65c0757c4ca2a9b2ba058d1243fc9aae --- /dev/null +++ b/dnn/test/cuda/non_zero.cpp @@ -0,0 +1,20 @@ +#include "test/common/non_zero.h" +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/cuda/fixture.h" + +using namespace megdnn; +using namespace test; +TEST_F(CUDA, NONZERO) { + std::vector test_cases = NonZeroTestcase::make(); + auto opr_cuda = handle_cuda()->create_operator(); + auto opr_naive = handle_naive()->create_operator(); + for (NonZeroTestcase& test_case : test_cases) { + NonZeroTestcase::CUDAResult data = test_case.run_cuda(opr_cuda.get()); + NonZeroTestcase::CUDAResult data_naive = test_case.run_cuda(opr_naive.get()); + + std::vector result = test_case.correct_answer; + MEGDNN_ASSERT_TENSOR_EQ(*data, *data_naive); + } +} \ No newline at end of file diff --git a/dnn/test/cuda/where.cpp b/dnn/test/cuda/where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d1f3f18368d66d74e95220530a21b154a015fdb3 --- /dev/null +++ b/dnn/test/cuda/where.cpp @@ -0,0 +1,41 @@ +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/cuda/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, WHERE) { + Checker checker(handle_cuda()); + + checker.exect( + Testcase{ + TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}), + TensorValue({1, 2, 2}, dtype::Float32(), {1, 2, 3, 4}), + TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}), + {}}, + Testcase{ + {}, + {}, + {}, + TensorValue({1, 2, 2}, dtype::Float32(), {1, 6, 7, 4})}); +} +TEST_F(CUDA, WHEREBACKWARD) { + Checker checker(handle_cuda()); + + checker.exect( + Testcase{ + TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}), + TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}), + {}, + {}}, + Testcase{ + {}, + {}, + TensorValue({1, 2, 2}, dtype::Float32(), {5, 0, 0, 8}), + TensorValue({1, 2, 2}, dtype::Float32(), {0, 6, 7, 0})}); +} + +} // namespace test +} // namespace megdnn \ No newline at end of file diff --git a/dnn/test/naive/non_zero.cpp b/dnn/test/naive/non_zero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..053ef4346bb1e76e08b0931abcb9a90cdcb7a2b1 --- /dev/null +++ b/dnn/test/naive/non_zero.cpp @@ -0,0 +1,18 @@ +#include "test/common/non_zero.h" +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/naive/fixture.h" + +using namespace megdnn; +using namespace test; +TEST_F(NAIVE, NONZERO) { + std::vector test_cases = NonZeroTestcase::make(); + auto opr_naive = handle()->create_operator(); + for (NonZeroTestcase& test_case : test_cases) { + NonZeroTestcase::Result data = test_case.run_naive(opr_naive.get()); + int ndim = test_case.m_data.layout.ndim; + std::vector result = test_case.correct_answer; + 
NonZeroTestcase::Assert(result, ndim, data); + } +} diff --git a/dnn/test/naive/where.cpp b/dnn/test/naive/where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0925b3213c458081892db044b1292ff1c5f4efa8 --- /dev/null +++ b/dnn/test/naive/where.cpp @@ -0,0 +1,42 @@ +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/naive/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(NAIVE, WHERE) { + Checker checker(handle()); + + checker.exect( + Testcase{ + TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}), + TensorValue({1, 2, 2}, dtype::Float32(), {1, 2, 3, 4}), + TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}), + {}}, + Testcase{ + {}, + {}, + {}, + TensorValue({1, 2, 2}, dtype::Float32(), {1, 6, 7, 4})}); +} + +TEST_F(NAIVE, WHEREBACKWARD) { + Checker checker(handle()); + + checker.exect( + Testcase{ + TensorValue({1, 2, 2}, dtype::Float32(), {5, 6, 7, 8}), + TensorValue({1, 2, 2}, dtype::Bool(), {true, false, false, true}), + {}, + {}}, + Testcase{ + {}, + {}, + TensorValue({1, 2, 2}, dtype::Float32(), {5, 0, 0, 8}), + TensorValue({1, 2, 2}, dtype::Float32(), {0, 6, 7, 0})}); +} + +} // namespace test +} // namespace megdnn diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py old mode 100755 new mode 100644 index cb9133f81a06bfc6c67e8adcbc4d40b6e0a42974..dc0f9b0caf522c67ccc1d13c1dc1b49e3924ae0c --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -30,6 +30,7 @@ __all__ = [ "broadcast_to", "concat", "cond_take", + "non_zero", "copy", "cumsum", "diag", @@ -821,30 +822,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: return inp -@lru_cache(maxsize=None) -def _get_where_op(dtype=None, device=None): - @subgraph_fn( - "Where", - dtype=dtype, - device=device, - nr_inputs=3, - jit_fusion=True, - custom_grad=True, - ) - def where(inputs, f, c): - (mask, x, y) = inputs[0:3] - oup = f("switch_gt0", mask, x) - ksam = f("-", c(1), mask) - oup = f("+", oup, f("switch_gt0", ksam, y)) - (oup_grad,) = yield (oup,) - x_grad = f("switch_gt0", mask, oup_grad) - y_grad = f("switch_gt0", ksam, oup_grad) - yield (None, x_grad, y_grad) - - return where - - -def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: +def where(mask: Tensor, x: Tensor = None, y: Tensor = None) -> Tensor: r"""Selects elements either from Tensor x or Tensor y, according to mask. .. 
math:: @@ -870,6 +848,8 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: array([[1., 6.], [7., 4.]], dtype=float32) """ + if x is None and y is None: + return non_zero(mask, as_tuple=True) if not isinstance(x, Tensor): raise TypeError("input x must be a tensor") @@ -882,18 +862,8 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: if x.device != mask.device: raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device)) - dtype = dtype_promotion(x, y) - device = x.device - - if x.dtype != dtype: - x = x.astype(dtype) - if y.dtype != dtype: - y = y.astype(dtype) - mask = mask.astype(dtype) - - where = _get_where_op(dtype=dtype, device=device) - (oup,) = where(mask, x, y) - return oup + where = builtin.Where() + return apply(where, mask, x, y)[0] def cond_take(mask: Tensor, x: Tensor) -> Tensor: @@ -961,6 +931,42 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: return inp.transpose(pattern) +def non_zero(condition: Tensor, as_tuple=False): + r"""When as_tuple is False (default): + Returns a tensor including the indices of all non-zero elements of Tensor condition. + Every row in the result including the indices of a non-zero element in input. + The result is sorted in lexicography order, with the last index changing the fastest (C-style). + When as_tuple is True: + Returns a tuple of 1-D tensors, one for each dimension in input, + each containing the indices (in that dimension) of all non-zero elements of condition. + Args: + condition(Tensor) - the input tensor + Returns: + one tuple of 1-D tensors or one tensor + + Examples: + >>> import numpy as np + >>> condition = Tensor(np.array([1,1,0,1])) + >>> index = F.non_zero(condition,as_tuple=True) + >>> print(index) + (Tensor([0 1 3], dtype=int32, device=xpux:0),) + """ + + if not isinstance(condition, Tensor): + raise TypeError("input must be a tensor") + op = builtin.NonZero() + (index,) = apply(op, condition) + ret = None + if as_tuple == True: + arr = [] + for index_ele in range(0, condition.ndim): + arr.append(index[index_ele, :]) + ret = tuple(arr) + else: + ret = transpose(index, (1, 0)) + return ret + + def swapaxes(inp: Tensor, axis1: int, axis2: int) -> Tensor: r"""Interchange two axes of a tensor. 
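
For readers of this patch, a short usage sketch of the Python API added above. This is illustrative only; the values are adapted from the docstring and the naive/CUDA test cases elsewhere in this diff, and the expected outputs are what the new semantics should give, not captured output:

import numpy as np
import megengine.functional as F
from megengine import Tensor

mask = Tensor(np.array([[True, False], [False, True]]))
x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
y = Tensor(np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32))

# Three-argument form: element-wise select, now backed by the new Where operator.
out = F.where(mask, x, y)      # expected: [[1., 6.], [7., 4.]]

# Single-argument form forwards to non_zero(mask, as_tuple=True):
# one 1-D index tensor per input dimension.
rows, cols = F.where(mask)     # expected: (Tensor([0 1]), Tensor([0 1]))

# as_tuple=False returns the (N, ndim) coordinate matrix, i.e. the transpose
# of the (ndim, N) index tensor produced by the NonZero kernel.
coords = F.non_zero(mask)      # expected: [[0, 0], [1, 1]]
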
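As a side note on the NonZero kernels earlier in the patch (dnn/src/naive/non_zero/opr_impl.cpp and dnn/src/cuda/non_zero/kernel.cu): both backends first gather the flat offsets of the non-zero elements and then expand each offset into per-dimension coordinates by dividing by reverse-multiplied shape products and taking the remainder against each dimension size. Below is a minimal Python sketch of that expansion step, for illustration only; expand_indices is a hypothetical helper mirroring expansion_index, not part of the patch:

def expand_indices(flat_indices, shape):
    """Turn flat element offsets into an (ndim, N) coordinate table,
    analogous to the output layout produced by the NonZero kernels."""
    ndim = len(shape)
    # div[k] is the product of the k trailing dimension sizes:
    # div[0] = 1, div[1] = shape[-1], div[2] = shape[-1] * shape[-2], ...
    div = [1] * ndim
    for k in range(1, ndim):
        div[k] = shape[ndim - k] * div[k - 1]
    coords = [[0] * len(flat_indices) for _ in range(ndim)]
    for dim in range(ndim):
        for n, idx in enumerate(flat_indices):
            # divide out the trailing dimensions, then reduce modulo this
            # dimension's size to get the coordinate along `dim`.
            coords[dim][n] = (idx // div[ndim - 1 - dim]) % shape[dim]
    return coords

# e.g. for shape (2, 3), the element at flat offset 4 sits at coordinate (1, 1)
print(expand_indices([4], (2, 3)))  # [[1], [1]]
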
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 23d2f36bf23de7046dd08a0f4185628bbee0e5f1..63bd97f068115b9182a399ddc93a8c05ba6e13f1 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -708,6 +708,47 @@ std::optional warp_affine_grad_rule( return imperative::apply(ApplyOp(op), inputs); } +std::optional where_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& where = op.cast_final_safe(); + auto&& param = where.param(); + mgb_assert(inputs.size() == 3); + SmallVector inps; + if (inputs_require_grad[1] || inputs_require_grad[2]) { + inps.push_back(inputs[0]); + } + bool data1_requires_grad = inputs_require_grad[1], + data2_requires_grad = inputs_require_grad[2]; + auto maker = CustomGradMaker(backward, inputs.size()); + maker.output_size(1).output_captured(0, false); + maker.backward([inputs = std::move(inps), data1_requires_grad, data2_requires_grad, + param](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + SmallVector ret(3); + if (!grad) { + return ret; + } + if (data1_requires_grad == false && data2_requires_grad == false) { + return ret; + } + + auto&& grad_op = WhereBackward::make(); + ValueRefList args_(2); + args_[0] = grads[0]; + args_[1] = inputs[0]; + auto back_grad = imperative::apply(*grad_op, args_); + if (data1_requires_grad) + ret[1] = back_grad[0]; + if (data2_requires_grad) + ret[2] = back_grad[1]; + return ret; + }); + maker.finalize(); + return imperative::apply(op, inputs); +} + struct Init { Init() { CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); @@ -733,6 +774,7 @@ struct Init { BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule); CustomBackward::register_grad_rule( WarpAffine::typeinfo(), warp_affine_grad_rule); + CustomBackward::register_grad_rule(Where::typeinfo(), where_grad_rule); } } _; diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index a4a81e3f9e986a9cb578d3d05fd3d495d1709fc5..ffa356c01a9b1ffb216259d0436c952d9652b733 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -128,6 +128,26 @@ def test_condtake(is_varnode): np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) +@pytest.mark.parametrize("as_tuple", [True, False]) +def test_nonzero(as_tuple): + def test_impl(np_condition): + tensor_condition = make_tensor(np_condition, None) + megengine_result = F.non_zero(tensor_condition, as_tuple=as_tuple) + np_result = np.nonzero(np_condition) + if as_tuple == False: + np_result = np.transpose(np_result, (1, 0)) + + for pos in range(len(megengine_result)): + np.testing.assert_equal(megengine_result[pos].numpy(), np_result[pos]) + + test_impl( + np.array([[True, False, True, False, False], [False, True, True, False, False]]) + ) + test_impl(np.random.randint(1, 10, size=[0, 3, 0])) + test_impl(np.random.randint(1, 10, size=[1, 2, 3])) + test_impl(np.random.randint(1, 10, size=[1, 2, 3, 4, 5, 6, 7])) + + @pytest.mark.parametrize("is_varnode", [True, False]) def test_concat_stack_device(is_varnode): if is_varnode: diff --git a/imperative/src/impl/ops/non_zero.cpp b/imperative/src/impl/ops/non_zero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f697f8f07a2d2169dc9065331ed10c2796b6186d --- /dev/null +++ b/imperative/src/impl/ops/non_zero.cpp @@ -0,0 +1,55 @@ 
+#include + +#include "../dnn_op_helper.h" +#include "../op_trait.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/misc.h" + +using namespace megdnn; + +namespace mgb::imperative { + +namespace { + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + + OperatorNodeConfig config{op.make_name()}; + return opr::NonZero::make(inputs[0], {}, config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + mgb_assert(inputs.size() == 1, "NonZero take 1 inputs, got %lu", inputs.size()); + + auto&& condition = inputs[0]; + if (condition->layout().is_empty()) { + // empty tensor + return {Tensor::make( + TensorLayout{{condition->layout().ndim, 0}, dtype::Int32()}, + condition->comp_node())}; + } else { + megdnn::NonZero::Param param; + DnnOprCaller dnn_op(condition->comp_node(), param); + auto&& [out] = dnn_op.exec_dynout<1>(condition); + return {out}; + } +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + LogicalTensorDesc input_0 = inputs[0]; + auto cn = inputs[0].comp_node; + return {{{TensorLayout(dtype::Int32()), cn}}, false}; +} + +OP_TRAIT_REG(NonZero, NonZero, opr::NonZero) + .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .fallback(); + +} // namespace + +} // namespace mgb::imperative diff --git a/imperative/src/impl/ops/where.cpp b/imperative/src/impl/ops/where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72581d536467da4a43bd705422589d7f2b39fbcd --- /dev/null +++ b/imperative/src/impl/ops/where.cpp @@ -0,0 +1,145 @@ +#include "../dnn_op_helper.h" +#include "megbrain/imperative/ops/autogen.h" + +#include "../op_trait.h" + +#include "megbrain/opr/misc.h" +#include "megdnn/oprs/general.h" + +namespace mgb { +namespace imperative { + +namespace { +namespace where { + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& input_descs) { + mgb_assert(input_descs.size() == 3, "Where expects three inputs"); + auto comp_node = input_descs[0].comp_node; + TensorLayout mask = input_descs[0].layout, data1 = input_descs[1].layout, + data2 = input_descs[2].layout; + + mgb_assert(mask.dtype == dtype::Bool(), "mask dtype must be boolean"); + mgb_assert( + data1.dtype == dtype::Float32() || data1.dtype == dtype::Int32() || + data1.dtype == dtype::Bool(), + "data1 dtype must be float32 or int32"); + mgb_assert( + data2.dtype == dtype::Float32() || data2.dtype == dtype::Int32() || + data2.dtype == dtype::Bool(), + "data2 dtype must be float32 or int32"); + + if (!mask.ndim || !data1.ndim || !data2.ndim) { + return {{{TensorLayout{data1.dtype}, comp_node, {}}}, false}; + } + + if (!mask.is_empty()) + mgb_assert(mask.is_contiguous(), "mask should be contiguous"); + if (!data1.is_empty()) + mgb_assert(data1.is_contiguous(), "data1 should be contiguous"); + if (!data2.is_empty()) + mgb_assert(data2.is_contiguous(), "data2 should be contiguous"); + + mgb_assert(mask.eq_shape(data1), "mask shape doesn't match data1"); + mgb_assert(mask.eq_shape(data2), "mask shape doesn't match data2"); + mgb_assert(data1.eq_layout(data2), "data1 layout doesn't match data2"); + + TensorLayout dst = data1; + dst.init_contiguous_stride(); + return {{{dst, comp_node}}, true}; +} + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& 
inputs) { + auto&& op = def.cast_final_safe(); + mgb_assert(inputs.size() == 3); + OperatorNodeConfig config{op.make_name()}; + return opr::Where::make(inputs[0], inputs[1], inputs[2], {}, config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, SmallVector inputs, + SmallVector& output_descs, const bool& validatad) { + auto&& mask = inputs[0]; + auto&& data1 = inputs[1]; + auto&& data2 = inputs[2]; + auto&& mask_layout = mask->layout(); + auto&& data1_layout = data1->layout(); + auto&& data2_layout = data2->layout(); + DnnOprCaller dnn_op(mask->comp_node()); + auto tlayout = dnn_op.deduce_layout(mask_layout, data1_layout, data2_layout); + auto out = Tensor::make(tlayout, mask->comp_node()); + if (!mask_layout.is_empty()) + dnn_op.exec_with_ws(mask, data1, data2, out); + return {out}; +} + +OP_TRAIT_REG(Where, Where) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); +} // namespace where + +namespace where_backward { + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& input_descs) { + mgb_assert(input_descs.size() == 2, "WhereBackward expects two inputs"); + auto comp_node = input_descs[0].comp_node; + TensorLayout diff = input_descs[0].layout, mask = input_descs[1].layout; + + mgb_assert( + diff.dtype == dtype::Float32() || diff.dtype == dtype::Int32(), + "diff dtype must be float32 or int32"); + mgb_assert(mask.dtype == dtype::Bool(), "mask dtype must be boolean"); + + if (!diff.ndim || !mask.ndim) { + return {{{diff, comp_node}}, false}; + } + + if (!diff.is_empty()) + mgb_assert(diff.is_contiguous(), "diff should be contiguous"); + if (!mask.is_empty()) + mgb_assert(mask.is_contiguous(), "mask should be contiguous"); + + mgb_assert(diff.eq_shape(mask), "diff shape doesn't match mask"); + + TensorLayout data1 = diff; + data1.init_contiguous_stride(); + TensorLayout data2 = diff; + data2.init_contiguous_stride(); + return {{{data1, comp_node}, {data2, comp_node}}, true}; +} + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{op.make_name()}; + return opr::WhereBackward::make(inputs[0], inputs[1], {}, config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, SmallVector inputs, + SmallVector& output_descs, const bool& validated) { + auto&& diff = inputs[0]; + auto&& mask = inputs[1]; + auto&& diff_layout = diff->layout(); + auto&& mask_layout = mask->layout(); + DnnOprCaller dnn_op(diff->comp_node()); + auto tlayouts = dnn_op.deduce_layouts<2>(diff_layout, mask_layout); + auto grad1 = Tensor::make(tlayouts.at(0), diff->comp_node()); + auto grad2 = Tensor::make(tlayouts.at(1), diff->comp_node()); + dnn_op.exec_with_ws(diff, mask, grad1, grad2); + return {grad1, grad2}; +} + +OP_TRAIT_REG(WhereBackward, WhereBackward) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); +} // namespace where_backward + +} // anonymous namespace +} // namespace imperative +} // namespace mgb \ No newline at end of file diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 26e3b379a2af726711b84a1c96642b5c1848d984..a5ef4a08a19b77e07f038847c155d6702b0f3b91 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp 
+++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -404,6 +404,23 @@ ValueRefList setsubtensor_rule(const OpDef& op, Span inputs) { return imperative::apply(op, converted); } +ValueRefList where_rule(const OpDef& op, Span inputs) { + SmallVector dtypes = get_value_dtypes({inputs.begin() + 1, inputs.end()}); + mgb::DType target_dtype = get_promoted_dtype(dtypes); + + ValueRefList converted(inputs.size()); + converted[0] = inputs[0]; + for (int idx = 1; idx < inputs.size(); idx++) { + if (*(inputs[idx].dtype()) != target_dtype) { + converted[idx] = imperative::apply( + ApplyOp(*TypeCvt::make(target_dtype)), inputs[idx])[0]; + } else { + converted[idx] = inputs[idx]; + } + } + + return imperative::apply(op, converted); +} struct DTypePromoteRuleRegistry { DTypePromoteRuleRegistry() { @@ -424,6 +441,7 @@ struct DTypePromoteRuleRegistry { register_dtype_promote_rule(norm_rule); register_dtype_promote_rule(setsubtensor_rule); register_dtype_promote_rule(setsubtensor_rule); + register_dtype_promote_rule(where_rule); } } register_helper; diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp index e0fd5cf799818e47c64e0f20dffddc896c4fbb12..1a52e5d93caec7754fbb07c6d8bf906413c89d2b 100644 --- a/imperative/src/impl/transformations/scalar.cpp +++ b/imperative/src/impl/transformations/scalar.cpp @@ -304,6 +304,7 @@ struct ScalarRuleRegistry { register_scalar_rule>(); register_scalar_rule>(); register_scalar_rule>(); + register_scalar_rule>(); } } _; } // namespace diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index f301093af061d48f23b2ee84a1db2876b1f89fb0..60d5508c16c4b8fb4c63267b254b77c580d05e86 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ e4035bfefce3a2cc0e8cc6ec7fcac227 ../../dnn/scripts/opr_param_defs.py -13ab898fce3749ebbcabf7c145876147 ../../src/core/include/megbrain/ir/ops.td -9dda6e2db75279373ec6809b297a2370 generated/opdef.h.inl -aabc2d8146742faacabf56e376177e7b generated/opdef.cpp.inl -8a5dffac1df3286178b3fd304c39b5da generated/opdef.py.inl -04322b642bba8f684034fcc5dc27efcf generated/opdef.cpy.inl +5ed571605e6f376c8801611c573707e3 ../../src/core/include/megbrain/ir/ops.td +28acc4b2c91ecbe3e0e27c1c45e5bc0c generated/opdef.h.inl +d554e79f054cb152251e88a41621b524 generated/opdef.cpp.inl +132afb96c40ae64cbab73b89aa00a844 generated/opdef.py.inl +87da68c8a60f965088e1c32c38939195 generated/opdef.cpy.inl 911001ef0dd771024919f7a1a3a009db generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index d20aaa57d80f5ab5cfcd98c82b97e187c9d7b55d..259bb508a4dc974710ad422400190feffc1f44d6 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -5426,6 +5426,40 @@ OP_TRAIT_REG(NMSKeep, NMSKeep) .props(NMSKeep_props_impl) .make_name(NMSKeep_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(NonZero); + +namespace { +size_t NonZero_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + return val; +} +bool NonZero_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + return true; +} +std::vector> NonZero_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + return 
props_; +} +std::string NonZero_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "NonZero"; +} +} // anonymous namespace +OP_TRAIT_REG(NonZero, NonZero) + .hash(NonZero_hash_impl) + .is_same_st(NonZero_is_same_st_impl) + .props(NonZero_props_impl) + .make_name(NonZero_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf); namespace { @@ -8333,4 +8367,72 @@ OP_TRAIT_REG(WarpPerspectiveBackwardMat, WarpPerspectiveBackwardMat) .props(WarpPerspectiveBackwardMat_props_impl) .make_name(WarpPerspectiveBackwardMat_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(Where); + +namespace { +size_t Where_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + return val; +} +bool Where_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + return true; +} +std::vector> Where_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + return props_; +} +std::string Where_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "Where"; +} +} // anonymous namespace +OP_TRAIT_REG(Where, Where) + .hash(Where_hash_impl) + .is_same_st(Where_is_same_st_impl) + .props(Where_props_impl) + .make_name(Where_make_name_impl); + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(WhereBackward); + +namespace { +size_t WhereBackward_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + return val; +} +bool WhereBackward_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + return true; +} +std::vector> WhereBackward_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + return props_; +} +std::string WhereBackward_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "WhereBackward"; +} +} // anonymous namespace +OP_TRAIT_REG(WhereBackward, WhereBackward) + .hash(WhereBackward_hash_impl) + .is_same_st(WhereBackward_is_same_st_impl) + .props(WhereBackward_props_impl) + .make_name(WhereBackward_make_name_impl); + // clang-format on diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 3889dcb8b2bf21080c977edacab8ca30a57de47d..698a9adb07c06cdd51f13bfd7d09ae50a45090c9 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -16010,6 +16010,88 @@ void _init_py_NMSKeep(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(NMSKeep::typeinfo(), &py_type).second); } +PyOpDefBegin(NonZero) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject 
*kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(NonZero) + +int PyOp(NonZero)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + + return 0; +} + +PyGetSetDef PyOp(NonZero)::py_getsetters[] = { + + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(NonZero)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(NonZero)::getstate, METH_NOARGS, "NonZero getstate"}, + {const_cast("__setstate__"), PyOp(NonZero)::setstate, METH_VARARGS, "NonZero setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject *PyOp(NonZero)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(NonZero)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(NonZero)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(NonZero)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self) -> None\n" +}; + +void _init_py_NonZero(py::module m) { + using py_op = PyOp(NonZero); + auto& py_type = PyOpType(NonZero); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.NonZero"; + py_type.tp_basicsize = sizeof(PyOp(NonZero)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "NonZero"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(NonZero), &PyOp(NonZero)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("NonZero", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(NonZero::typeinfo(), &py_type).second); +} + PyOpDefBegin(NvOf) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -23534,6 +23616,170 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { m.add_object("WarpPerspectiveBackwardMat", reinterpret_cast(&py_type)); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspectiveBackwardMat::typeinfo(), &py_type).second); } + +PyOpDefBegin(Where) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(Where) + +int PyOp(Where)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + + return 0; +} + +PyGetSetDef PyOp(Where)::py_getsetters[] = { + + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(Where)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(Where)::getstate, METH_NOARGS, "Where getstate"}, + {const_cast("__setstate__"), PyOp(Where)::setstate, METH_VARARGS, "Where setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject 
*PyOp(Where)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(Where)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(Where)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(Where)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self) -> None\n" +}; + +void _init_py_Where(py::module m) { + using py_op = PyOp(Where); + auto& py_type = PyOpType(Where); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.Where"; + py_type.tp_basicsize = sizeof(PyOp(Where)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "Where"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(Where), &PyOp(Where)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("Where", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Where::typeinfo(), &py_type).second); +} + +PyOpDefBegin(WhereBackward) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(WhereBackward) + +int PyOp(WhereBackward)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + + return 0; +} + +PyGetSetDef PyOp(WhereBackward)::py_getsetters[] = { + + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(WhereBackward)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(WhereBackward)::getstate, METH_NOARGS, "WhereBackward getstate"}, + {const_cast("__setstate__"), PyOp(WhereBackward)::setstate, METH_VARARGS, "WhereBackward setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject *PyOp(WhereBackward)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(WhereBackward)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(WhereBackward)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(WhereBackward)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self) -> None\n" +}; + +void _init_py_WhereBackward(py::module m) { + using py_op = PyOp(WhereBackward); + auto& py_type = PyOpType(WhereBackward); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.WhereBackward"; + py_type.tp_basicsize = sizeof(PyOp(WhereBackward)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "WhereBackward"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + 
py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(WhereBackward), &PyOp(WhereBackward)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("WhereBackward", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WhereBackward::typeinfo(), &py_type).second); +} #define INIT_ALL_OP(m) \ _init_py_AdaptivePooling(m); \ _init_py_AddAxis(m); \ @@ -23614,6 +23860,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_MeshIndexing(m); \ _init_py_MultiHeadAttn(m); \ _init_py_NMSKeep(m); \ + _init_py_NonZero(m); \ _init_py_NvOf(m); \ _init_py_Padding(m); \ _init_py_ParamPackConcat(m); \ @@ -23654,5 +23901,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_WarpAffine(m); \ _init_py_WarpPerspective(m); \ _init_py_WarpPerspectiveBackwardData(m); \ - _init_py_WarpPerspectiveBackwardMat(m); + _init_py_WarpPerspectiveBackwardMat(m); \ + _init_py_Where(m); \ + _init_py_WhereBackward(m); // clang-format on diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index 53885921f66d1b5ead6e5a88c371aee0b421419f..1501f2bcc70979cae26df643e8c38de560dcda7d 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1457,6 +1457,17 @@ public: NMSKeep(float iou_thresh_, uint32_t max_output_, std::string scope_ = {}): iou_thresh(iou_thresh_), max_output(max_output_) { set_scope(scope_); } }; +class NonZero : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + NonZero() = default; + NonZero(::megdnn::param::Empty) {} + ::megdnn::param::Empty param() const { + return {}; + } +}; + class NvOf : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; @@ -2093,4 +2104,26 @@ public: } }; +class Where : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + Where() = default; + Where(::megdnn::param::Empty) {} + ::megdnn::param::Empty param() const { + return {}; + } +}; + +class WhereBackward : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + WhereBackward() = default; + WhereBackward(::megdnn::param::Empty) {} + ::megdnn::param::Empty param() const { + return {}; + } +}; + // clang-format on diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 39fc0e8dd3a1ba23cd9d93fceef0fada2bd6c19e..69d29ef42323fa084e30309ce6a7ba92d0b3f568 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1553,6 +1553,11 @@ NMSKeepInst .def_readwrite("iou_thresh", &NMSKeep::iou_thresh) .def_readwrite("max_output", &NMSKeep::max_output); +py::class_, OpDef> NonZeroInst(m, "NonZero"); + +NonZeroInst + .def(py::init<>()); + py::class_, OpDef> NvOfInst(m, "NvOf"); NvOfInst @@ -2128,4 +2133,14 @@ WarpPerspectiveBackwardMatInst .def_readwrite("format", &WarpPerspectiveBackwardMat::format) .def_readwrite("border_val", &WarpPerspectiveBackwardMat::border_val); +py::class_, OpDef> WhereInst(m, "Where"); + +WhereInst + .def(py::init<>()); + +py::class_, OpDef> WhereBackwardInst(m, "WhereBackward"); + +WhereBackwardInst + .def(py::init<>()); + // clang-format on diff --git a/src/core/impl/comp_node_api.cpp b/src/core/impl/comp_node_api.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..f383a5cb460dc31f983959d4d404a492a96b2562 --- /dev/null +++ b/src/core/impl/comp_node_api.cpp @@ -0,0 +1,95 @@ +#include "megbrain/comp_node_api.h" +#include +#include "megbrain/comp_node.h" +#include "megbrain/comp_node_env.h" +#include "megbrain_build_config.h" + +namespace mgb { +namespace pubapi { + +std::unordered_map& cn_cache() { + static std::unordered_map cn_map; + return cn_map; +} + +class CompNodeDepedentObjectInst final : public CompNodeDepedentObject { + std::shared_ptr on_comp_node_finalize() override { return {}; } + +public: + bool is_finalized() const { return CompNodeDepedentObject::is_finalized(); } +}; + +bool is_finalize() { + static CompNodeDepedentObjectInst* obj = new CompNodeDepedentObjectInst; + return obj->is_finalized(); +} + +void sync(mgbComputeNode_t cn) { + if (!is_finalize()) { + auto* s = reinterpret_cast(cn); + if (s->valid()) + s->sync(); + } +} + +mgbComputeNode_t load_cuda_cn(int device_id, int stream) { + std::string loc = ssprintf("gpu%i:%i", device_id, stream); + mgb_assert(!is_finalize()); + auto& cache = cn_cache(); + if (cache.find(loc) == cache.end()) { + auto* cn = new mgb::CompNode; + (*cn) = mgb::CompNode::load(loc); + mgb_assert(cn->to_string_physical() == loc); + cache[loc] = cn; + cn->activate(); + } + return reinterpret_cast(cache[loc]); +} + +void unload_cuda_cn(mgbComputeNode_t cn) { + auto* device = reinterpret_cast(cn); + auto& cache = cn_cache(); + mgb_assert( + cache.find(device->to_string_physical()) != cache.end() && + device == cache[device->to_string_physical()]); + cache.erase(device->to_string_physical()); + delete device; +} + +void* alloc(mgbComputeNode_t device, size_t s) { + if (s == 0) + return nullptr; + auto* cn = reinterpret_cast(device); + mgb_assert(!is_finalize()); + return cn->alloc_device(s); +} + +void dealloc(mgbComputeNode_t device, void* addr) { + if (addr != nullptr) { + auto* cn = reinterpret_cast(device); + if (!is_finalize()) { + cn->free_device(addr); + } + } +} + +void* get_cuda_stream(mgbComputeNode_t device) { + void* rst = nullptr; +#if MGB_CUDA + auto* cn = reinterpret_cast(device); + MGB_TRY { rst = CompNodeEnv::from_comp_node(*cn).cuda_env().stream; } + MGB_CATCH(MegBrainError & exc, { + mgb_log_error("failed to get stream: %s", exc.what()); + }) +#else + mgb_log_error("megbrain compiled without cuda support!"); +#endif + return rst; +} + +MGB_API DeviceLocator get_physical_location(mgbComputeNode_t device) { + auto location = reinterpret_cast(device)->locator().to_physical(); + return {location.device, location.stream}; +} +} // namespace pubapi +} // namespace mgb diff --git a/src/core/include/megbrain/comp_node_api.h b/src/core/include/megbrain/comp_node_api.h new file mode 100644 index 0000000000000000000000000000000000000000..c88ff6155d3b521be750f0ecd3a8cba407e30b04 --- /dev/null +++ b/src/core/include/megbrain/comp_node_api.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#if defined(_WIN32) +#define MGB_API __declspec(dllexport) +#else +#define MGB_API __attribute__((visibility("default"))) +#endif +namespace mgb { +namespace pubapi { + +typedef struct _MgbComputeNode* mgbComputeNode_t; +struct DeviceLocator { + int device = -1; + int stream = -1; +}; + +MGB_API mgbComputeNode_t load_cuda_cn(int device_id, int stream_id); +MGB_API void unload_cuda_cn(mgbComputeNode_t); +MGB_API void* alloc(mgbComputeNode_t cn, size_t); +MGB_API void dealloc(mgbComputeNode_t cn, void* addr); +MGB_API void* get_cuda_stream(mgbComputeNode_t cn); +MGB_API DeviceLocator 
get_physical_location(mgbComputeNode_t); +MGB_API void sync(mgbComputeNode_t cn); +MGB_API bool is_finalize(); +} // namespace pubapi +} // namespace mgb diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index d32e3537f2ad45929f2893aac0d2c3aa3d03b70b..8acbb10b34caa3ee9d2a4a1cb725557bf12f31b5 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -153,6 +153,8 @@ def Argmin : MgbHashableOp<"Argmin", [AxisParam]>; def CondTake : MgbHashableOp<"CondTake">; +def NonZero: MgbHashableOp<"NonZero",[EmptyParam]>; + def TopK: MgbHashableOp<"TopK", [TopKParam]>; def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>; @@ -640,4 +642,9 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> { def Cross: MgbHashableOp<"Cross", [CrossParam]>; +def Where: MgbHashableOp<"Where", [EmptyParam]>; + +def WhereBackward: MgbHashableOp<"WhereBackward", [EmptyParam]>; + + #endif // MGB_OPS diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index c95854144c40bbb8b9fd6527e2bc7ab2be1e6fcb..ac4a8f54728a9e6b9312064adf2fe34bdc67ba2e 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -385,6 +385,71 @@ void CondTake::scn_do_execute() { } } +/* ================= NonZero ================= */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(NonZero); + +NonZero::NonZero(VarNode* data, const Param& param, const OperatorNodeConfig& config) + : Super(data->owner_graph(), config, "NonZero", {data}) { + init_megdnn_opr(*this, param); + add_input({data}); + auto dtype = megdnn_opr()->infer_type(data->dtype()); + output(0) + ->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) + .dtype(dtype); +} + +NonZero::NodeProp* NonZero::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; +} + +SymbolVar NonZero::make( + SymbolVar data, const Param& param, const OperatorNodeConfig& config) { + auto ret = data.insert_single_output_opr(data.node(), param, config); + return ret; +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(NonZero) { + MGB_MARK_USED_VAR(out_grad); + MGB_MARK_USED_VAR(opr); + mgb_assert(!wrt_idx); + return nullptr; +} +#endif + +void NonZero::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto infer_workspace_shape = [this](TensorShape& dest, const InpVal& iv) { + auto dtype = this->input(0)->dtype(); + TensorLayout ily(iv.val[0].shape(), dtype); + dest.ndim = 1; + dest.shape[0] = this->megdnn_opr()->get_workspace_in_bytes(ily); + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(1), + {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace_shape}); +} + +void NonZero::add_input_layout_constraint() { + mixin::megdnn_utils::add_input_layout_constraint_contig(*this); +} + +void NonZero::scn_do_execute() { + auto&& data = input(0)->dev_tensor(); + intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()}; + if (data.layout().is_empty()) { + dyn_malloc.alloc_output(0, dtype::Int32(), {data.layout().ndim, 0}, nullptr); + } else { + megdnn_opr()->exec( + data.as_megdnn(), intl::get_megdnn_workspace_from_var(output().back()), + &dyn_malloc); + } +} + /* ================= TopK ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(TopK); @@ -605,4 +670,158 @@ void CheckNonFinite::add_input_layout_constraint() { i->add_layout_constraint_contiguous(); } } + +/* ================= Where ================= */ + 
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(WhereForward); + +WhereForward::WhereForward( + VarNode* mask, VarNode* data1, VarNode* data2, const Param& param, + const OperatorNodeConfig& config) + : Super(mask->owner_graph(), config, "where", {mask, data1, data2}) { + init_megdnn_opr(*this, param); + mgb_assert(mask->shape().eq_shape(data1->shape())); + mgb_assert(mask->shape().eq_shape(data2->shape())); + mgb_assert(data1->shape().eq_shape(data2->shape())); + add_input({mask, data1, data2}); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE).dtype(data1->dtype()); +} + +SymbolVar WhereForward::make( + SymbolVar mask, SymbolVar data1, SymbolVar data2, const Param& param, + const OperatorNodeConfig& config) { + return mask.insert_single_output_opr( + mask.node(), data1.node(), data2.node(), param, config); +} + +void WhereForward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto infer_shape = [](TensorShape& dest, const InpVal& iv) { + auto ishp = iv.val[0].shape(); + dest = ishp; + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape}); + + auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) { + auto mask_dtype = input(0)->dtype(); + auto mask_shape = iv.val[0].shape(); + TensorLayout mask_layout(mask_shape, mask_dtype); + auto data_dtype = input(1)->dtype(); + auto data_shape = iv.val[1].shape(); + TensorLayout data_layout(data_shape, data_dtype); + dest.ndim = 1; + dest[0] = megdnn_opr()->get_workspace_in_bytes( + mask_layout, data_layout, data_layout, data_layout); + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(1), {SourceType::DEP, + {{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}, + infer_workspace}); +} + +void WhereForward::add_input_layout_constraint() { + mixin::megdnn_utils::add_input_layout_constraint_contig(*this); +} + +void WhereForward::scn_do_execute() { + if (input(0)->dev_tensor().empty()) { + mgb_assert(output(0)->dev_tensor().empty()); + return; + } + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); +} +MAKE_NODE_PROP_WITH_ZERO_SHAPE_3(WhereForward, 0, 1, 2); + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(WhereForward) { + mgb_assert(out_grad.size() == 2 && !out_grad[1]); + VarNodeArray ret; + SymbolVarArray grad; + grad = WhereBackward::make(out_grad[0], opr.input(0), opr.param()); + if (wrt_idx == 0) + return nullptr; + else if (wrt_idx == 1) + return grad[0].node(); + else if (wrt_idx == 2) + return grad[1].node(); + return nullptr; +} +#endif + +/* ================= WhereBackward ================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(WhereBackward); + +WhereBackward::WhereBackward( + VarNode* diff, VarNode* mask, const Param& param, + const OperatorNodeConfig& config) + : Super(diff->owner_graph(), config, "WhereBackward", {diff, mask}) { + init_megdnn_opr(*this, param); + mgb_assert(diff->shape().eq_shape(mask->shape())); + add_input({diff, mask}); + output(0)->dtype(diff->dtype()); + output(1)->dtype(diff->dtype()); +} + +SymbolVarArray WhereBackward::make( + SymbolVar diff, SymbolVar mask, const Param& param, + const OperatorNodeConfig& config) { + auto outs = diff.node() + ->owner_graph() + ->insert_opr(std::make_unique( + diff.node(), mask.node(), param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& 
out : outs) + ret.emplace_back(out); + return ret; +} + +void WhereBackward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto infer_shape = [](TensorShape& dest, const InpVal& iv) { + auto ishp = iv.val[0].shape(); + dest = ishp; + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape}); + owner_graph()->static_infer_manager().register_shape_infer( + output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape}); + + auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) { + auto diff_dtype = input(0)->dtype(); + auto diff_shape = iv.val[0].shape(); + TensorLayout diff_layout(diff_shape, diff_dtype); + auto mask_dtype = input(1)->dtype(); + auto mask_shape = iv.val[1].shape(); + TensorLayout mask_layout(mask_shape, mask_dtype); + dest.ndim = 1; + dest[0] = megdnn_opr()->get_workspace_in_bytes( + diff_layout, mask_layout, diff_layout, diff_layout); + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(2), {SourceType::DEP, + {{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}, + infer_workspace}); +} + +void WhereBackward::add_input_layout_constraint() { + mixin::megdnn_utils::add_input_layout_constraint_contig(*this); +} + +void WhereBackward::scn_do_execute() { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h index b06beecb260c73e7629bb0460eea8a9f6b4c5673..819c1fc8444bd3219296728415f7b193f09fa867 100644 --- a/src/opr/impl/misc.sereg.h +++ b/src/opr/impl/misc.sereg.h @@ -1,3 +1,4 @@ +#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/misc.h" #include "megbrain/serialization/sereg.h" @@ -31,6 +32,19 @@ struct OprMaker { } }; +template <> +struct OprMaker { + using Opr = opr::NonZero; + using Param = Opr::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + auto out = Opr::make(inputs[0], param, config); + return out.node()->owner_opr(); + } +}; + template <> struct OprMaker { using Opr = opr::TopK; @@ -57,8 +71,66 @@ struct OprMaker { } }; +template <> +struct OprMaker { + using Opr = opr::Where; + using Param = Opr::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + auto out = Opr::make(inputs[0], inputs[1], inputs[2], param, config); + return out.node()->owner_opr(); + } +}; + +template <> +struct OprMaker { + using Opr = opr::WhereBackward; + using Param = opr::WhereBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + auto out = Opr::make(inputs[0], inputs[1], param, config); + return out[0].node()->owner_opr(); + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::Where; + using Mode = opr::Elemwise::Mode; + using PersisParam = opr::Where::Param; + using PersisWhereParam = opr::Where::Param; + static void dump(OprDumpContext& ctx, const 
cg::OperatorNodeBase& opr) { + ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + auto mask = SymbolVar(inputs[0]); + auto x = SymbolVar(inputs[1]); + mask = opr::TypeCvt::make(mask, x.dtype()); + auto y = SymbolVar(inputs[2]); + auto oup = opr::Elemwise::make({mask, x}, Mode::SWITCH_GT0); + auto ksam = 1.0f - mask; + oup = oup + opr::Elemwise::make({ksam, y}, Mode::SWITCH_GT0); + return oup.node()->owner_opr(); + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + return OprMaker<opr::Where, 3>::make( + ctx.read_param<PersisParam>(), inputs, ctx.graph(), config); + } +}; + } // namespace serialization +// OprMaker in MGB_SEREG_OPR only supports oprs with a unique output + namespace opr { MGB_SEREG_OPR(Argmax, 1); @@ -66,7 +138,14 @@ MGB_SEREG_OPR(Argmin, 1); MGB_SEREG_OPR(Argsort, 1); MGB_SEREG_OPR(ArgsortBackward, 3); MGB_SEREG_OPR(CondTake, 2); +MGB_SEREG_OPR(NonZero, 1); MGB_SEREG_OPR(TopK, 2); +MGB_SEREG_OPR_CONDITION(Where, 3, false); +MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( + Where, 3, (mgb::serialization::OprLoadDumpImplV2<opr::Where, 3>::replace_opr), + VERSION_1, VERSION_1); +MGB_SEREG_OPR(WhereBackward, 4); + //! current cumsum version using CumsumV1 = opr::Cumsum; MGB_SEREG_OPR(CumsumV1, 1); diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index 66a76824250d1e30dcfcbc84b3b89683014ebdb8..231c89bfca039545da56f9a2cddfbafcc666ec4d 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -142,6 +142,9 @@ using TopKBase = cg::SingleCNOperatorNode< cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl>; using CheckNonFiniteBase = cg::SingleCNOperatorNode< cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl>; + +using NonZeroBase = cg::SingleCNOperatorNode< + cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl>; +} // namespace intl + /*!
@@ -162,6 +165,18 @@ public: SymbolVar data, SymbolVar mask, const Param& param, const OperatorNodeConfig& config = {}); }; +MGB_DEFINE_OPR_CLASS_WITH_EXPORT(NonZero, intl::NonZeroBase) // { + void init_output_static_infer_desc() override; + void scn_do_execute() override; + void add_input_layout_constraint() override; + NodeProp* do_make_node_prop() const override; + +public: + MGE_WIN_DECLSPEC_FUC NonZero( + VarNode* data, const Param& param, const OperatorNodeConfig& config); + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar data, const Param& param, const OperatorNodeConfig& config = {}); +}; MGB_DEFINE_OPR_CLASS_WITH_EXPORT(TopK, intl::TopKBase) // { void init_output_dtype() override; @@ -196,6 +211,40 @@ public: const OperatorNodeConfig& config = {}); }; +MGB_DEFINE_OPR_CLASS( + WhereForward, intl::MegDNNOprWrapperFwd) // { + void scn_do_execute() override; + void init_output_static_infer_desc() override; + void add_input_layout_constraint() override; + NodeProp* do_make_node_prop() const override; + +public: + MGE_WIN_DECLSPEC_FUC WhereForward( + VarNode* mask, VarNode* data1, VarNode* data2, const Param& param, + const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar mask, SymbolVar data1, SymbolVar data2, const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; +using Where = WhereForward; + +MGB_DEFINE_OPR_CLASS( + WhereBackward, intl::MegDNNOprWrapperFwd) // { + void scn_do_execute() override; + void init_output_static_infer_desc() override; + void add_input_layout_constraint() override; + +public: + MGE_WIN_DECLSPEC_FUC WhereBackward( + VarNode* diff, VarNode* mask, const Param& param, + const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( + SymbolVar diff, SymbolVar mask, const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; + } // namespace opr } // namespace mgb diff --git a/src/opr/test/misc.cpp b/src/opr/test/misc.cpp index e5964108b4ec529a303a0f69c50d56470efc9091..38c428ac4927843f1a669b2e5ee89c3fb532eb49 100644 --- a/src/opr/test/misc.cpp +++ b/src/opr/test/misc.cpp @@ -262,6 +262,25 @@ TEST(TestOprMisc, CondTakeEmptyIO) { check({1, 0}); } +TEST(TestOprMisc, NonZero) { + using Param = opr::NonZero::Param; + Param param; + HostTensorGenerator<> gen; + auto check = [&](const TensorShape& shp) { + auto host_data = gen(shp); + auto graph = ComputingGraph::make(); + auto data = opr::Host2DeviceCopy::make(*graph, host_data); + auto out = opr::NonZero::make(data, param); + HostTensorND host_out; + auto func = graph->compile({make_callback_copy(out, host_out)}); + func->execute(); + }; + check({1}); + check({0}); + check({1, 0}); + check({1, 2, 3, 4, 5, 6, 7}); +} + TEST(TestOprMisc, TopKValueOnly) { auto run = [](bool dyn_k, bool non_contig) { using Checker = AutoOprChecker<1, 1>; @@ -414,4 +433,57 @@ TEST(TestOprMisc, TopKGrad) { EXPECT_TRUE(gk == nullptr); } +TEST(TestOprMisc, Where) { + auto graph = ComputingGraph::make(); + HostTensorGenerator gen_mask; + auto host_mask = gen_mask({2, 2, 2}); + HostTensorGenerator<> gen_data{0, 1000}; + auto host_data1 = gen_data({2, 2, 2}), host_data2 = gen_data({2, 2, 2}); + auto mask = opr::Host2DeviceCopy::make(*graph, host_mask), + data1 = opr::Host2DeviceCopy::make(*graph, host_data1), + data2 = opr::Host2DeviceCopy::make(*graph, host_data2), + dst = opr::Where::make(mask, data1, data2, {}); + + HostTensorND host_dst; + auto func = graph->compile({make_callback_copy(dst, host_dst)}); + func->execute(); + + auto pmask = 
host_mask->ptr(); + auto pdata1 = host_data1->ptr(); + auto pdata2 = host_data2->ptr(); + auto pdst = host_dst.ptr(); + for (size_t i = 0; i < host_mask->layout().total_nr_elems(); ++i) { + ASSERT_EQ(pmask[i] ? pdata1[i] : pdata2[i], pdst[i]); + } +} + +TEST(TestOprMisc, WhereBackward) { + auto graph = ComputingGraph::make(); + HostTensorGenerator<> gen_diff{0, 1000}; + auto host_diff = gen_diff({2, 2, 2}); + HostTensorGenerator gen_mask; + auto host_mask = gen_mask({2, 2, 2}); + auto diff = opr::Host2DeviceCopy::make(*graph, host_diff), + mask = opr::Host2DeviceCopy::make(*graph, host_mask); + auto grads = opr::WhereBackward::make(diff, mask, {}); + auto grad1 = grads[0]; + auto grad2 = grads[1]; + + HostTensorND host_grad1; + HostTensorND host_grad2; + auto func = graph->compile( + {make_callback_copy(grad1, host_grad1), + make_callback_copy(grad2, host_grad2)}); + func->execute(); + + auto pdiff = host_diff->ptr(); + auto pmask = host_mask->ptr(); + auto pgrad1 = host_grad1.ptr(); + auto pgrad2 = host_grad2.ptr(); + for (size_t i = 0; i < host_diff->layout().total_nr_elems(); ++i) { + ASSERT_EQ(pmask[i] ? pdiff[i] : 0, pgrad1[i]); + ASSERT_EQ(pmask[i] ? 0 : pdiff[i], pgrad2[i]); + } +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
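
For reference, the element-wise contract that WhereForward/WhereBackward implement, and that TestOprMisc.Where / TestOprMisc.WhereBackward verify above, is dst[i] = mask[i] ? data1[i] : data2[i], with the incoming gradient routed to grad1 where the mask is set and to grad2 elsewhere. A minimal host-side sketch (the ref_* helpers are illustrative only and are not part of the patch):

#include <cstddef>
#include <vector>

// dst[i] = mask[i] ? data1[i] : data2[i]
static void ref_where_fwd(
        const std::vector<bool>& mask, const std::vector<float>& data1,
        const std::vector<float>& data2, std::vector<float>& dst) {
    dst.resize(mask.size());
    for (size_t i = 0; i < mask.size(); ++i)
        dst[i] = mask[i] ? data1[i] : data2[i];
}

// grad1 gets diff where the mask is true, grad2 gets it where the mask is false.
static void ref_where_bwd(
        const std::vector<float>& diff, const std::vector<bool>& mask,
        std::vector<float>& grad1, std::vector<float>& grad2) {
    grad1.assign(diff.size(), 0.f);
    grad2.assign(diff.size(), 0.f);
    for (size_t i = 0; i < diff.size(); ++i) {
        if (mask[i])
            grad1[i] = diff[i];
        else
            grad2[i] = diff[i];
    }
}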
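
The serialization fallback registered through OprLoadDumpImplV2 for Where rewrites the opr into Elemwise nodes so that loaders built before this patch can still run dumped graphs. Assuming SWITCH_GT0(a, b) evaluates to b where a > 0 and to 0 elsewhere, switch_gt0(mask, x) + switch_gt0(1 - mask, y) reproduces mask ? x : y once the mask has been cast to x's dtype, which is exactly what replace_opr emits. A scalar sketch of that identity (standalone, not MegEngine API):

#include <cassert>
#include <initializer_list>

// Assumed semantics of Elemwise::Mode::SWITCH_GT0: a > 0 ? b : 0.
static float switch_gt0(float a, float b) {
    return a > 0.f ? b : 0.f;
}

int main() {
    const float x = 3.f, y = 7.f;
    for (float m : {0.f, 1.f}) {
        // Decomposition used by replace_opr with a 0/1 mask.
        float decomposed = switch_gt0(m, x) + switch_gt0(1.f - m, y);
        float expected = m > 0.f ? x : y;
        assert(decomposed == expected);
    }
    return 0;
}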
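
NonZero fixes its output dtype to Int32 (NonZero::infer_type) and, on the empty-input path of scn_do_execute, allocates a {ndim, 0} result, i.e. the opr reports the coordinates of non-zero elements as an ndim x nr_nonzero matrix whose column count is only known at runtime (hence the NO_SYS_MEM_ALLOC flag and the dyn-malloc output). A host-side sketch of that convention; the row-per-axis, row-major packing here is an assumption for illustration, the device kernels define the actual layout:

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns a flattened {ndim, nr_nonzero} Int32 coordinate matrix for a
// row-major float tensor, mirroring the output convention implied by
// NonZero::infer_type and the {ndim, 0} empty-input allocation.
static std::vector<int32_t> ref_non_zero(
        const std::vector<float>& data, const std::vector<size_t>& shape) {
    const size_t ndim = shape.size();
    std::vector<std::vector<int32_t>> coords(ndim);
    for (size_t linear = 0; linear < data.size(); ++linear) {
        if (data[linear] == 0.f)
            continue;
        size_t rest = linear;
        // Decompose the linear index into one coordinate per axis.
        for (size_t axis = ndim; axis-- > 0;) {
            coords[axis].push_back(static_cast<int32_t>(rest % shape[axis]));
            rest /= shape[axis];
        }
    }
    std::vector<int32_t> out;  // row axis = input dimension, column = hit index
    for (size_t axis = 0; axis < ndim; ++axis)
        out.insert(out.end(), coords[axis].begin(), coords[axis].end());
    return out;
}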