From 2d0c96906a287c1539b470ef55af45a93a8a195e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 26 Oct 2020 20:00:00 +0800 Subject: [PATCH] fix(opr): add mask check for dct GitOrigin-RevId: a452092cc7bc0aded75be64aa3ef9ff202094bdb --- src/opr/impl/imgproc.cpp | 70 +++++++++++++++++++++++++- src/opr/impl/io.cpp | 3 ++ src/opr/include/megbrain/opr/imgproc.h | 3 ++ src/opr/include/megbrain/opr/io.h | 2 + src/opr/test/imgproc.cpp | 50 ++++++++++++++++++ src/opr/test/io.cpp | 30 +++++++++++ 6 files changed, 156 insertions(+), 2 deletions(-) diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp index b84c8397f..94887296c 100644 --- a/src/opr/impl/imgproc.cpp +++ b/src/opr/impl/imgproc.cpp @@ -13,6 +13,7 @@ #include "./internal/megdnn_opr_wrapper.inl" #include "megbrain/graph/grad_impl.h" #include "megbrain/opr/imgproc.h" +#include "megbrain/opr/io.h" #include "megbrain/opr/utility.h" using namespace mgb; @@ -486,6 +487,7 @@ struct MegDNNOprInitPostCtor { } // namespace intl } // namespace opr } // namespace mgb + void DctChannelSelectForward::get_output_var_shape( const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { auto mo = megdnn_opr(); @@ -504,6 +506,7 @@ void DctChannelSelectForward::get_output_var_shape( } out_shape[0] = dst; } + size_t DctChannelSelectForward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const { @@ -513,6 +516,7 @@ size_t DctChannelSelectForward::get_workspace_size_bytes( {input_shapes[0], input(0)->dtype(), input(0)->format()}, {}, {}, {output_shapes[0], output(0)->dtype(), output(0)->format()}); } + void DctChannelSelectForward::scn_do_execute() { auto&& inp = input(); auto mo = megdnn_opr(); @@ -524,7 +528,6 @@ void DctChannelSelectForward::scn_do_execute() { } else { mgb_assert(inp.size() == 3, "no support input tensor num %zu", inp.size()); - mo->exec(inp[0]->dev_tensor().as_megdnn(), inp[1]->dev_tensor().as_megdnn(), inp[2]->dev_tensor().as_megdnn(), @@ -533,7 +536,70 @@ void DctChannelSelectForward::scn_do_execute() { } } -MEGDNN_OPR_INIT3(DctChannelSelectForward, "dct_channel_select") +void DctChannelSelectForward::valid_mask(const int* mask_offset, int mask_len, + const int* mask_val, int mask_val_len, + const Param& param) { + if (mask_len <= 0) + return; + mgb_assert(mask_offset[0] == 0, + "The first element of mask_offset must be zero, but got %d. For " + "example mask offset [0, 15, 20] indicate there are 2 ic, and " + "ic_0 will have (15 - 0) oc, ic_1 have (20 - 15) oc", + mask_offset[0]); + for (int i = 1; i < mask_len; ++i) { + if (param.format == Param::Format::NCHW4) { + mgb_assert(mask_offset[i] % 4 == 0, + "Invalid mask offset %d at %d, it should be times of " + "4 when using nchw4 format", + mask_offset[i], i); + } + mgb_assert(mask_offset[i] >= mask_offset[i - 1], + "The offset of mask must be increasing, but %d(%d) is less " + "than %d(%d)", + mask_offset[i], i, mask_offset[i - 1], i - 1); + } + const int max_mask = param.dct_block_size * param.dct_block_size; + for (int i = 0; i < mask_val_len; ++i) { + mgb_assert(0 <= mask_val[i] && mask_val[i] < max_mask, + "Invalid mask_val, assert 0 <= mask_val[%d] < %d, aka 0 <= " + "%d < %d", + i, max_mask, mask_val[i], max_mask); + } +} + +DctChannelSelectForward::DctChannelSelectForward( + VarNode* src, VarNode* mask_offset, VarNode* mask_val, + const Param& param, const OperatorNodeConfig& config) + : Super(OperatorNodeBaseCtorParam{ + src->owner_graph(), config, "dct_channel_select", {src}}) { + init_megdnn_opr(*this, param); + add_input({src, mask_offset, mask_val}); + if (mask_offset != nullptr) { + mgb_assert(mask_val, + "mask_val should not be null when mask_offset is not null"); + auto host_offset = mask_offset->owner_opr() + ->cast_final_safe() + .host_value(); + auto host_val = mask_val->owner_opr() + ->cast_final_safe() + .host_value(); + + valid_mask(host_offset.ptr(), + host_offset.layout().total_nr_elems(), host_val.ptr(), + host_val.layout().total_nr_elems(), param); + } + intl::MegDNNOprInitPostCtor::apply(*this); +} + +SymbolVar DctChannelSelectForward::make(SymbolVar src, SymbolVar mask_offset, + SymbolVar mask_val, const Param& param, + const OperatorNodeConfig& config) { + intl::MegDNNOprInitInputsModifier::apply( + param, {&src, &mask_offset, &mask_val}); + return src.insert_single_output_opr( + src.node(), mask_offset.node(), mask_val.node(), param, config); +} + MEGDNN_OPR_INIT1(DctChannelSelectForward, "dct_channel_select") // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index a31049394..6bf35a342 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -639,6 +639,9 @@ SymbolVar ImmutableTensor::make(ComputingGraph &graph, const DTypeScalar &val, const DeviceTensorND& ImmutableTensor::value() const { return m_value.dev(); } +const DeviceTensorND& ImmutableTensor::host_value() { + return const_cast(&m_value)->static_infer(); +} SymbolVar ImmutableTensor::make_from_value( ComputingGraph &graph, diff --git a/src/opr/include/megbrain/opr/imgproc.h b/src/opr/include/megbrain/opr/imgproc.h index 26016fac2..d4fabc9e9 100644 --- a/src/opr/include/megbrain/opr/imgproc.h +++ b/src/opr/include/megbrain/opr/imgproc.h @@ -286,6 +286,9 @@ size_t get_workspace_size_bytes( const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const override; void scn_do_execute() override; + +void valid_mask(const int* mask_offset, int mask_len, const int* mask_val, + int mask_val_len, const Param& param); }; using DctChannelSelect = DctChannelSelectForward; diff --git a/src/opr/include/megbrain/opr/io.h b/src/opr/include/megbrain/opr/io.h index 14fe1c275..8c842efab 100644 --- a/src/opr/include/megbrain/opr/io.h +++ b/src/opr/include/megbrain/opr/io.h @@ -378,6 +378,8 @@ MGB_DEFINE_OPR_CLASS(ImmutableTensor, intl::DeviceTensorHolder) // { //! get underlying value on device const DeviceTensorND& value() const; + const DeviceTensorND& host_value(); + SymbolVar shallow_copy( ComputingGraph &graph, const OperatorNodeConfig &config) const { return make_from_value(graph, m_value, m_value_refkeep, config); diff --git a/src/opr/test/imgproc.cpp b/src/opr/test/imgproc.cpp index e54085e31..521d8fba4 100644 --- a/src/opr/test/imgproc.cpp +++ b/src/opr/test/imgproc.cpp @@ -803,4 +803,54 @@ TEST(TestOprImgproc, DCT) { MGB_MARK_USED_VAR(fwd3); MGB_MARK_USED_VAR(gen_mask); } + +TEST(TestOprImgproc, DCT_BAD_MASK) { + HostTensorGenerator gen_u8; + HostTensorGenerator gen_s32; + TensorShape src_shape({1, 2, 256, 256}), mask_offset_shape({3}), + mask_val_shape({8}); + opr::DctChannelSelectForward::Param param; + + auto graph = ComputingGraph::make(); + + auto src_tensor = gen_u8(src_shape); + auto mask_offset_tensor = gen_s32(mask_offset_shape); + auto mask_val_tensor = gen_s32(mask_val_shape); + auto mask_offset_ptr = mask_offset_tensor->ptr(); + auto mask_val_ptr = mask_val_tensor->ptr(); + mask_offset_ptr[0] = 1; + mask_val_ptr[0] = 64; + auto src_sym = opr::ImmutableTensor::make(*graph, *src_tensor); + auto mask_offset_sym = + opr::ImmutableTensor::make(*graph, *mask_offset_tensor); + auto mask_val_sym = opr::ImmutableTensor::make(*graph, *mask_val_tensor); + + ASSERT_THROW(opr::DctChannelSelect::make(src_sym, mask_offset_sym, + mask_val_sym, param), + MegBrainError); + + mask_offset_ptr[0] = 0; + mask_offset_ptr[1] = 2; + mask_offset_ptr[2] = 8; + mask_offset_sym = opr::ImmutableTensor::make(*graph, *mask_offset_tensor); + ASSERT_THROW(opr::DctChannelSelect::make(src_sym, mask_offset_sym, + mask_val_sym, param), + MegBrainError); + + mask_val_ptr[0] = 0; + mask_val_ptr[1] = 1; + mask_val_ptr[2] = 2; + mask_val_ptr[3] = 3; + mask_val_ptr[4] = 4; + mask_val_ptr[5] = 5; + mask_val_ptr[6] = 6; + mask_val_ptr[7] = 7; + mask_val_sym = opr::ImmutableTensor::make(*graph, *mask_val_tensor); + opr::DctChannelSelect::make(src_sym, mask_offset_sym, mask_val_sym, param); + + param.format = opr::DctChannelSelect::Param::Format::NCHW4; + ASSERT_THROW(opr::DctChannelSelect::make(src_sym, mask_offset_sym, + mask_val_sym, param), + MegBrainError); +} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/io.cpp b/src/opr/test/io.cpp index 7930bffb6..e2a9f6099 100644 --- a/src/opr/test/io.cpp +++ b/src/opr/test/io.cpp @@ -150,6 +150,36 @@ TEST(TestOprIO, ImmutableTensor) { } +TEST(TestOprIO, ImmutableTensorHostvalue) { + HostTensorGenerator<> gen; + TensorShape shape({2, 3}); + auto host_x = gen(shape); + auto graph = ComputingGraph::make(); + auto x = opr::ImmutableTensor::make(*graph, *host_x); + auto y = x.node()->owner_opr() + ->cast_final_safe() + .host_value(); + for (size_t i = 0; i < shape.total_nr_elems(); ++i) { + ASSERT_EQ(host_x->ptr()[i], y.ptr()[i]); + } +} + +TEST(TestOprIO, ImmutableTensorHostvalueGPU) { + REQUIRE_GPU(1); + auto gpu_cn = CompNode::load("gpu0"); + HostTensorGenerator<> gen; + TensorShape shape({2, 3}); + auto host_x = gen(shape); + auto graph = ComputingGraph::make(); + auto x = opr::ImmutableTensor::make(*graph, *host_x, {gpu_cn}); + auto y = x.node()->owner_opr() + ->cast_final_safe() + .host_value(); + for (size_t i = 0; i < shape.total_nr_elems(); ++i) { + ASSERT_EQ(host_x->ptr()[i], y.ptr()[i]); + } +} + TEST(TestOprIO, ImmutableTensorLarge) { HostTensorGenerator<> gen; auto host_x = gen({1025}); -- GitLab