From 538b0863083bc041543d024418ad9a9c8b7affb3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 8 Nov 2022 15:35:03 +0800 Subject: [PATCH] feat(opr): add masked_fill op GitOrigin-RevId: 47cd068b9e2220448c2c2735001ce6416f06252f --- ci/compatibility/fbs/V2-backup/schema_v2.fbs | 1 + dnn/include/megdnn/oprs/general.h | 25 +++++ dnn/src/common/handle_impl.h | 3 +- dnn/src/common/masked_fill.cpp | 32 ++++++ dnn/src/common/opr_trait.h | 1 + dnn/src/cuda/handle_create.cpp | 2 + dnn/src/cuda/masked_fill/kern.cu | 14 +++ dnn/src/cuda/masked_fill/kern.cuh | 47 ++++++++ dnn/src/cuda/masked_fill/opr_impl.cpp | 31 ++++++ dnn/src/cuda/masked_fill/opr_impl.h | 17 +++ dnn/src/naive/handle.cpp | 1 + dnn/src/naive/masked_fill/opr_impl.cpp | 54 +++++++++ dnn/src/naive/masked_fill/opr_impl.h | 17 +++ dnn/test/cuda/masked_fill.cpp | 41 +++++++ dnn/test/naive/masked_fill.cpp | 50 +++++++++ imperative/python/src/tensor_utils.cpp | 23 ++++ imperative/src/impl/ops/tensor_manip.cpp | 58 ++++++++++ imperative/tablegen/generated/hash.txt | 10 +- imperative/tablegen/generated/opdef.cpp.inl | 37 +++++++ imperative/tablegen/generated/opdef.cpy.inl | 110 +++++++++++++++++++ imperative/tablegen/generated/opdef.h.inl | 13 +++ imperative/tablegen/generated/opdef.py.inl | 6 + src/core/include/megbrain/ir/ops.td | 1 + src/opr/impl/tensor_manip.cpp | 22 ++++ src/opr/impl/tensor_manip.sereg.h | 1 + src/opr/include/megbrain/opr/tensor_manip.h | 12 ++ src/serialization/impl/schema.fbs | 1 + src/serialization/impl/schema_v2.fbs | 1 + 28 files changed, 625 insertions(+), 6 deletions(-) create mode 100644 dnn/src/common/masked_fill.cpp create mode 100644 dnn/src/cuda/masked_fill/kern.cu create mode 100644 dnn/src/cuda/masked_fill/kern.cuh create mode 100644 dnn/src/cuda/masked_fill/opr_impl.cpp create mode 100644 dnn/src/cuda/masked_fill/opr_impl.h create mode 100644 dnn/src/naive/masked_fill/opr_impl.cpp create mode 100644 dnn/src/naive/masked_fill/opr_impl.h create mode 100644 dnn/test/cuda/masked_fill.cpp create mode 100644 dnn/test/naive/masked_fill.cpp diff --git a/ci/compatibility/fbs/V2-backup/schema_v2.fbs b/ci/compatibility/fbs/V2-backup/schema_v2.fbs index f6f0732c6..5e458598e 100644 --- a/ci/compatibility/fbs/V2-backup/schema_v2.fbs +++ b/ci/compatibility/fbs/V2-backup/schema_v2.fbs @@ -141,6 +141,7 @@ union OperatorParam { param.Softmax = 90, param.Diag = 91, param.GroupNorm = 92, + param.Fill = 93, } table Operator { diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index b0301aa66..9e25664c0 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1392,6 +1392,31 @@ protected: void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; +class MaskedFill : public OperatorBase { + DEF_OPR_PARAM(Fill); + DEF_OPR_IMPL(MaskedFill, OperatorBase, 2, 1); + +public: + virtual void exec( + _megdnn_tensor_in origin, _megdnn_tensor_in index, + _megdnn_tensor_out dst) = 0; + void exec( + _megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dst, + _megdnn_workspace /*workspace*/) { + exec(origin, index, dst); + } + virtual size_t get_workspace_in_bytes( + const TensorLayout& origin, const TensorLayout& index, + const TensorLayout& dest) = 0; + void deduce_layout( + const TensorLayout& origin, const TensorLayout& index, TensorLayout& dest); + +protected: + void check_exec( + const TensorLayout& origin, const TensorLayout& index, + const TensorLayout& dest); +}; + /*! * \brief standard padding operator * Inputs must have the same dtype, and the output tensor shape must greater or equal diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 6c5a04da3..14396c36e 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -218,7 +218,8 @@ private: cb(RegionRestrictedConvolutionBackwardData) \ cb(RegionRestrictedConvolutionBackwardFilter) \ cb(GroupNormForward) \ - cb(GroupNormBackward) + cb(GroupNormBackward) \ + cb(MaskedFill) // clang-format on /*! diff --git a/dnn/src/common/masked_fill.cpp b/dnn/src/common/masked_fill.cpp new file mode 100644 index 000000000..9b1c8ce37 --- /dev/null +++ b/dnn/src/common/masked_fill.cpp @@ -0,0 +1,32 @@ +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { +void MaskedFill::deduce_layout( + const TensorLayout& origin, const TensorLayout& /*index*/, TensorLayout& dest) { + dest = TensorLayout(origin, origin.dtype); +} + +void MaskedFill::check_exec( + const TensorLayout& origin, const TensorLayout& index, + const TensorLayout& dest) { + megdnn_assert_contiguous(index); + megdnn_assert_contiguous(dest); + megdnn_assert(index.dtype == dtype::Bool()); + megdnn_assert(origin.ndim >= index.ndim); + bool correct_index_shape = true; + for (size_t i = 0; i < index.ndim; i++) { + correct_index_shape = correct_index_shape && origin.shape[i] == index.shape[i]; + } + megdnn_assert(correct_index_shape, "unsupported index shape"); + bool supported_dtype = false; + +#define cb(Dtype) supported_dtype = supported_dtype || (origin.dtype == Dtype()); + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(megdnn::dtype::Bool) +#undef cb + + megdnn_assert(supported_dtype, "unsupported dtype"); +} + +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 39a28089f..7917507d8 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -144,6 +144,7 @@ DEF(RegionRestrictedConvolutionBackwardData, 5, true, false); DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); DEF(GroupNormForward, 6, true, true); DEF(GroupNormBackward, 8, true, true); +DEF(MaskedFill, 3, false, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index d49a3b9b7..e673467c4 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -44,6 +44,7 @@ #include "src/cuda/lrn/opr_impl.h" #include "src/cuda/lsq/opr_impl.h" #include "src/cuda/mask_conv/opr_impl.h" +#include "src/cuda/masked_fill/opr_impl.h" #include "src/cuda/matrix_inverse/opr_impl.h" #include "src/cuda/matrix_mul/opr_impl.h" #include "src/cuda/max_tensor_diff/opr_impl.h" @@ -178,6 +179,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaxTensorDiff); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskPropagate); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskedFill); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DBackwardData); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DBackwardFilter); diff --git a/dnn/src/cuda/masked_fill/kern.cu b/dnn/src/cuda/masked_fill/kern.cu new file mode 100644 index 000000000..cadec91e4 --- /dev/null +++ b/dnn/src/cuda/masked_fill/kern.cu @@ -0,0 +1,14 @@ +#include "./kern.cuh" + +namespace megdnn { +namespace cuda { +#define cb(_dtype) \ + INST_RUN_ELEMWISE( \ + MaskedFillScalarKernOp::ctype>, \ + DTypeTrait<_dtype>::ctype, 1); +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +cb(::megdnn::dtype::Bool) + +#undef cb +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/masked_fill/kern.cuh b/dnn/src/cuda/masked_fill/kern.cuh new file mode 100644 index 000000000..269ba0b14 --- /dev/null +++ b/dnn/src/cuda/masked_fill/kern.cuh @@ -0,0 +1,47 @@ +#pragma once + +#include "src/cuda/elemwise_helper.cuh" +#include "src/cuda/utils.cuh" + +#if MEGDNN_CC_HOST +#include "megdnn/oprs.h" +#endif + +namespace megdnn { +namespace cuda { + +template +struct MaskedFillScalarKernOp { + using VectTypeTrait = elemwise_intl::VectTypeTrait; + typedef typename VectTypeTrait::vect_type vect_type; + ctype* output; + bool* mask; + ctype value; + uint32_t mask_stride; + + __device__ __forceinline__ void operator()(uint32_t idx, ctype orig) { + output[idx] = mask[idx / mask_stride] + ? value + : orig; //! mask[idx] * orig + mask[idx]* *value; + } + __device__ __forceinline__ void operator()(uint32_t idx, vect_type orig) { + ctype a = mask[(idx) / mask_stride] ? value : orig.x; + ctype b = mask[(idx + 1) / mask_stride] ? value : orig.y; + ctype g = mask[(idx + 2) / mask_stride] ? value : orig.z; + ctype r = mask[(idx + 3) / mask_stride] ? value : orig.w; + *(vect_type*)(&output[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } + +#if MEGDNN_CC_HOST + MaskedFillScalarKernOp( + const TensorND& output, const TensorND& mask, ctype value, + uint32_t mask_stride) + : output{output.ptr()}, + mask{mask.ptr()}, + value{value}, + mask_stride{mask_stride} {} +#endif +}; + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/masked_fill/opr_impl.cpp b/dnn/src/cuda/masked_fill/opr_impl.cpp new file mode 100644 index 000000000..1906eba48 --- /dev/null +++ b/dnn/src/cuda/masked_fill/opr_impl.cpp @@ -0,0 +1,31 @@ +#include "./opr_impl.h" +#include "./kern.cuh" +#include "src/common/utils.h" +namespace megdnn { +namespace cuda { +void MaskedFillImpl::exec( + _megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dest) { + check_exec(origin.layout, index.layout, dest.layout); + + megdnn_assert(index.layout.is_contiguous()); + uint32_t mask_stride = TensorLayout(origin.layout, origin.layout.dtype) + .stride[index.layout.ndim - 1]; + ElemwiseOpParamN<1> ele_param; + ele_param[0] = origin; + ele_param.init_from_given_tensor(); + auto stream = cuda_stream(handle()); + +#define cb(DType) \ + if (origin.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + auto value = static_cast(param().value); \ + run_elemwise, T, 1>( \ + ele_param, stream, {dest, index, value, mask_stride}); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb +} +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/masked_fill/opr_impl.h b/dnn/src/cuda/masked_fill/opr_impl.h new file mode 100644 index 000000000..289c8d3a8 --- /dev/null +++ b/dnn/src/cuda/masked_fill/opr_impl.h @@ -0,0 +1,17 @@ +#pragma once +#include "megdnn/oprs.h" +#include "src/cuda/utils.h" +namespace megdnn { +namespace cuda { +class MaskedFillImpl : public MaskedFill { +public: + using MaskedFill::MaskedFill; + void exec(_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dst) + override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 6c5e50c77..412521153 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -48,6 +48,7 @@ #include "src/naive/lstm/opr_impl.h" #include "src/naive/lstm_cell/opr_impl.h" #include "src/naive/mask_conv/opr_impl.h" +#include "src/naive/masked_fill/opr_impl.h" #include "src/naive/matrix_inverse/opr_impl.h" #include "src/naive/matrix_mul/opr_impl.h" #include "src/naive/max_tensor_diff/opr_impl.h" diff --git a/dnn/src/naive/masked_fill/opr_impl.cpp b/dnn/src/naive/masked_fill/opr_impl.cpp new file mode 100644 index 000000000..883a1b232 --- /dev/null +++ b/dnn/src/naive/masked_fill/opr_impl.cpp @@ -0,0 +1,54 @@ +#include "src/naive/masked_fill/opr_impl.h" +#include +#include "megdnn/tensor_iter.h" +#include "src/common/elemwise_helper.cuh" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +namespace { +using namespace megdnn; +template +void forward_impl(const ElemwiseOpParamN<3> src, const T value) { + auto inp = tensor_iter_valonly(src[0]).begin(); + auto out = tensor_iter_valonly(src[1]).begin(); + auto mask = tensor_iter_valonly(src[2]).begin(); + size_t total = src[0].layout.total_nr_elems(); + for (size_t i = 0; i < total; ++i) { + *out = *mask ? value : *inp; + ++inp; + ++out; + ++mask; + } +} +} // namespace + +namespace megdnn { +namespace naive { +void MaskedFillImpl::exec( + _megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dest) { + check_exec(origin.layout, index.layout, dest.layout); + + megdnn_assert(origin.layout.is_contiguous() && index.layout.is_contiguous()); + ElemwiseOpParamN<3> src; + src[0] = origin; + src[1] = dest; + src[2] = index; + if (src[2].layout.ndim < src[0].layout.ndim) { + for (size_t n = src[2].layout.ndim; n < src[0].layout.ndim; n++) + src[2].layout.add_axis_cont_inplace(n); + } + src[2].layout = src[2].layout.broadcast(origin.layout); + +#define cb(DType) \ + if (origin.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + auto value = static_cast(param().value); \ + forward_impl(src, value); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb +} +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/masked_fill/opr_impl.h b/dnn/src/naive/masked_fill/opr_impl.h new file mode 100644 index 000000000..93cb167ed --- /dev/null +++ b/dnn/src/naive/masked_fill/opr_impl.h @@ -0,0 +1,17 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { +class MaskedFillImpl : public MaskedFill { +public: + using MaskedFill::MaskedFill; + void exec(_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dst) + override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; +} // namespace naive +} // namespace megdnn diff --git a/dnn/test/cuda/masked_fill.cpp b/dnn/test/cuda/masked_fill.cpp new file mode 100644 index 000000000..9c482f1ba --- /dev/null +++ b/dnn/test/cuda/masked_fill.cpp @@ -0,0 +1,41 @@ +#include "test/cuda/fixture.h" + +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, MASKEDFILL) { + using Param = MaskedFill::Param; + Param param; + param.value = 1.0; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-2); + + auto run = [&](DType d) { + for (size_t A : {2, 3}) + for (size_t B : {6, 9}) { + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, dtype::Bool()) + .set_dtype(2, d) + .execs({{A, B, 2, 1}, {A, B}, {A, B, 2, 1}}); + } + for (size_t A : {2, 3}) + for (size_t B : {6, 9}) { + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, dtype::Bool()) + .set_dtype(2, d) + .execs({{A, B, 2, 1}, {A, B, 2, 1}, {A, B, 2, 1}}); + } + }; + + run(dtype::Float32()); + run(dtype::Float16()); + run(dtype::BFloat16()); + run(dtype::Uint8()); +} + +} // namespace test +} // namespace megdnn diff --git a/dnn/test/naive/masked_fill.cpp b/dnn/test/naive/masked_fill.cpp new file mode 100644 index 000000000..43e7d23c2 --- /dev/null +++ b/dnn/test/naive/masked_fill.cpp @@ -0,0 +1,50 @@ +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/naive/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(NAIVE, MASKEDFILL) { + Checker checker(handle(), true); + + MaskedFill::Param param; + param.value = 0.2; + + checker.set_param(param).exect( + Testcase{ + TensorValue( + {2, 3, 2, 1}, dtype::Float32(), + {3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587, + 0.0711, -0.1169, 0.2509, -0.2393, 0.0876}), // input + TensorValue({2}, dtype::Bool(), {false, true}), // hx + {}}, + Testcase{ + {}, + {}, + TensorValue( + {2, 3, 2, 1}, dtype::Float32(), + {3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, 0.2, 0.2, + 0.2, 0.2, 0.2, 0.2}), // output + }); + checker.set_param(param).exect( + Testcase{ + TensorValue( + {1, 3, 1, 2}, dtype::Float32(), + {-2.4348, -1.7948, 0.5223, 0.0932, -0.2955, + -0.0492}), // input + TensorValue({1, 3}, dtype::Bool(), {false, true, false}), // hx + {}, + }, + Testcase{ + {}, + {}, + TensorValue( + {1, 3, 1, 2}, dtype::Float32(), + {-2.4348, -1.7948, 0.2, 0.2, -0.2955, -0.0492}), + }); +} + +} // namespace test +} // namespace megdnn diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 1319348e1..6258dd23f 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -1290,6 +1290,29 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) { py::object org_shape = getattr(inp_hdl, "shape"); py::object val = py::reinterpret_borrow(val_hdl); + bool is_val_scalar = false; + float value; + if (PyLong_Check(val.ptr())) { + is_val_scalar = true; + value = static_cast(PyLong_AsDouble(val.ptr())); + } + if (PyFloat_Check(val.ptr())) { + is_val_scalar = true; + value = static_cast(PyFloat_AsDouble(val.ptr())); + } + if (TensorWrapper::try_cast(idx_hdl.ptr()) && is_bool_dtype(idx_hdl.ptr()) && + is_val_scalar && enable_fastpath(inp_hdl)) { + std::vector q(3); + std::shared_ptr Op = MaskedFill::make(value); + py::object maskedfill = py::cast(Op); + q[0] = maskedfill.ptr(); + q[1] = inp_hdl.ptr(); + q[2] = idx_hdl.ptr(); + py::tuple result = + py::reinterpret_steal(py_apply(NULL, q.data(), q.size())); + py::object res = result[0]; + return res; + } if (!TensorWrapper::try_cast(val.ptr())) { val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device")); } diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 7afc86df3..67adb0c6e 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -292,4 +292,62 @@ OP_TRAIT_REG(Split, Split, opr::Split) } // namespace split +namespace masked_fill { + +cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op_def = def.cast_final_safe(); + OperatorNodeConfig config{op_def.make_name()}; + mgb_assert(inputs.size() == 2); + return opr::MaskedFill::make(inputs[0], inputs[1], op_def.param(), config) + .node() + ->owner_opr(); +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + SmallVector layout_checker(inputs.size()); + layout_checker[0] = [](const TensorLayout& layout) { + return layout.is_contiguous(); + }; + return layout_checker; +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& input_descs) { + return {{{{input_descs[0].layout, input_descs[0].layout.dtype}, + input_descs[0].comp_node}}, + input_descs[0].layout.ndim != 0}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = def.cast_final_safe(); + auto&& inp = inputs[0]; + auto&& mask = inputs[1]; + + TensorLayout outlayout(inp->layout(), inp->layout().dtype); + + auto output = Tensor::make(outlayout, inp->comp_node()); + + DnnOprCaller dnn_opr{inp->comp_node(), op.param()}; + dnn_opr.exec_with_ws(inp, mask, output); + return {output}; +} + +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + return MaskedFill::make(node->param()); +} + +OP_TRAIT_REG(MaskedFill, MaskedFill, mgb::opr::MaskedFill) + .get_input_layout_constraint(get_input_layout_constraint) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .apply_on_var_node(apply_on_var_node) + .make_from_op_node(make_from_op_node) + .fallback(); + +} // namespace masked_fill + } // namespace mgb::imperative diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index f5f9ee6e8..ce0ee0f37 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ 8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py -7d6df1c8e50a22ef2c36b7ea89daa9c5 ../../src/core/include/megbrain/ir/ops.td -f30ae9494b4bf3363cd74d9396acaf49 generated/opdef.h.inl -cb27f486b28a099221f38c6fcaa06a44 generated/opdef.cpp.inl -adb758acd1147f213db7f0cb1b708773 generated/opdef.py.inl -30ad8e75a5994edf9ec46387c6285312 generated/opdef.cpy.inl +4bd0317fd84b5065c8d88a7ca6241908 ../../src/core/include/megbrain/ir/ops.td +cb32cb1ef6b2ef4a7defaeb02ecd36e3 generated/opdef.h.inl +1c0230f60ddf3459de2aa4e16c1e2957 generated/opdef.cpp.inl +f6cbfd25f0d61e7b94c687733f5ae9b9 generated/opdef.py.inl +3a023199c39ea5611975b902a882bbba generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index 994a6d6b4..ab292471d 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -4788,6 +4788,43 @@ OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime) .props(MagicMindRuntime_props_impl) .make_name(MagicMindRuntime_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskedFill); + +namespace { +size_t MaskedFill_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::hash(op_.value)); + return val; +} +bool MaskedFill_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.value != b_.value) return false; + return true; +} +std::vector> MaskedFill_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + props_.emplace_back("value", std::to_string(op_.value)); + return props_; +} +std::string MaskedFill_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "MaskedFill"; +} +} // anonymous namespace +OP_TRAIT_REG(MaskedFill, MaskedFill) + .hash(MaskedFill_hash_impl) + .is_same_st(MaskedFill_is_same_st_impl) + .props(MaskedFill_props_impl) + .make_name(MaskedFill_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse); namespace { diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 29e641bd0..363a78959 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -14037,6 +14037,115 @@ void _init_py_MagicMindRuntime(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MagicMindRuntime::typeinfo(), &py_type).second); } +PyOpDefBegin(MaskedFill) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"value", serialization::dump(opdef.value)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("value"); + if (iter != state.end()) { + opdef.value = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(MaskedFill) + +int PyOp(MaskedFill)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"value", "scope", NULL}; + PyObject *value = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast(kwlist), &value, &scope)) + return -1; + + if (value) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().value = + py::cast(py::handle(value)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(MaskedFill)::py_getsetters[] = { + {const_cast("value"), py_get_generic(MaskedFill, value), py_set_generic(MaskedFill, value), const_cast("value"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(MaskedFill)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(MaskedFill)::getstate, METH_NOARGS, "MaskedFill getstate"}, + {const_cast("__setstate__"), PyOp(MaskedFill)::setstate, METH_VARARGS, "MaskedFill setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject *PyOp(MaskedFill)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(MaskedFill)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(MaskedFill)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(MaskedFill)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self, value: float = ...) -> None\n" +}; + +void _init_py_MaskedFill(py::module m) { + using py_op = PyOp(MaskedFill); + auto& py_type = PyOpType(MaskedFill); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.MaskedFill"; + py_type.tp_basicsize = sizeof(PyOp(MaskedFill)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "MaskedFill"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(MaskedFill), &PyOp(MaskedFill)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("MaskedFill", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MaskedFill::typeinfo(), &py_type).second); +} + PyOpDefBegin(MatrixInverse) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -22157,6 +22266,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_LayerNorm(m); \ _init_py_Linspace(m); \ _init_py_MagicMindRuntime(m); \ + _init_py_MaskedFill(m); \ _init_py_MatrixInverse(m); \ _init_py_MatrixMul(m); \ _init_py_MeshGrid(m); \ diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index b486f9987..a130b9e27 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1288,6 +1288,19 @@ public: MagicMindRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); } }; +class MaskedFill : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + float value = 0; + MaskedFill() = default; + MaskedFill(float value_, std::string scope_ = {}): value(value_) { set_scope(scope_); } + MaskedFill(::megdnn::param::Fill packed_param_0): value(packed_param_0.value) {} + ::megdnn::param::Fill param() const { + return {value}; + } +}; + class MatrixInverse : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index b3ac9df30..b3af8baf4 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1412,6 +1412,12 @@ MagicMindRuntimeInst .def_readwrite("buf", &MagicMindRuntime::buf) .def_readwrite("buf_size", &MagicMindRuntime::buf_size); +py::class_, OpDef> MaskedFillInst(m, "MaskedFill"); + +MaskedFillInst + .def(py::init(), py::arg("value") = 0, py::arg("scope") = {}) + .def_readwrite("value", &MaskedFill::value); + py::class_, OpDef> MatrixInverseInst(m, "MatrixInverse"); MatrixInverseInst diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index ed16963f6..eb56061a4 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -553,5 +553,6 @@ def MeshGrid: MgbHashableOp<"MeshGrid"> { def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [ConvolutionParam]>; def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>; +def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>; #endif // MGB_OPS diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index 8932104fd..d25da45a7 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -1631,4 +1631,26 @@ MEGDNN_OPR_INIT2(PaddingBackward, "padding_backward", 1, false); // f}}} +/* f{{{ ======================= MaskedFill ======================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskedFill); +MEGDNN_OPR_INIT2(MaskedFill, "masked_fill"); + +void MaskedFill::init_output_dtype() { + output(0)->dtype(input(0)->dtype()); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(MaskedFill) { + mgb_assert(opr.input().size() == 2); + if (wrt_idx == 0) { + SymbolVar grad = MaskedFill::make(out_grad[0], opr.input(1), {.0}); + return grad.node(); + } else + return InvalidGrad::make(opr, wrt_idx); +} +#endif + +// f}}} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/tensor_manip.sereg.h b/src/opr/impl/tensor_manip.sereg.h index b333ad4d5..a88f0854d 100644 --- a/src/opr/impl/tensor_manip.sereg.h +++ b/src/opr/impl/tensor_manip.sereg.h @@ -214,6 +214,7 @@ MGB_SEREG_OPR(RelayoutFormatV1, 1); MGB_SEREG_OPR(Padding, 1); MGB_SEREG_OPR(PaddingBackward, 2); +MGB_SEREG_OPR(MaskedFill, 2); } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index c8817f449..62290f9f9 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ b/src/opr/include/megbrain/opr/tensor_manip.h @@ -630,6 +630,18 @@ public: const OperatorNodeConfig& config = {}); }; +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( + MaskedFill, intl::MegDNNOprWrapperFwd) // { +public: + MGE_WIN_DECLSPEC_FUC MaskedFill( + VarNode* src, VarNode* index, const Param& param, + const OperatorNodeConfig& config); + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar src, SymbolVar index, const Param& param = {}, + const OperatorNodeConfig& config = {}); + MGE_WIN_DECLSPEC_FUC void init_output_dtype() override; +}; + } // namespace opr } // namespace mgb diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 9a94aea73..fa5af9edd 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -124,6 +124,7 @@ union OperatorParam { param.Softmax = 90, param.Diag = 91, param.GroupNorm = 92, + param.Fill = 93, } table Operator { diff --git a/src/serialization/impl/schema_v2.fbs b/src/serialization/impl/schema_v2.fbs index 87c2ef36a..49b341ffe 100644 --- a/src/serialization/impl/schema_v2.fbs +++ b/src/serialization/impl/schema_v2.fbs @@ -141,6 +141,7 @@ union OperatorParam { param.Softmax = 90, param.Diag = 91, param.GroupNorm = 92, + param.Fill = 93, } table Operator { -- GitLab