Commit 2d0c9690 authored by Megvii Engine Team

fix(opr): add mask check for dct

GitOrigin-RevId: a452092cc7bc0aded75be64aa3ef9ff202094bdb
Parent 9ffde212
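The change below adds a host-side validity check (valid_mask) for the DCT channel-select operator's mask inputs. The convention the check enforces, per the assertion messages in the diff: mask_offset must start at 0 and be non-decreasing (each offset additionally a multiple of 4 in NCHW4 format), and every mask_val entry must lie in [0, dct_block_size * dct_block_size). A minimal sketch of a well-formed mask, assuming an 8x8 DCT block; the names and values here are illustrative, not from the commit:

    // Two input channels, assuming dct_block_size == 8 (so values < 64).
    // mask_offset [0, 15, 20]: ic_0 keeps 15 - 0 = 15 oc, ic_1 keeps 20 - 15 = 5.
    int mask_offset[] = {0, 15, 20};
    int mask_val[20];
    for (int i = 0; i < 20; ++i)
        mask_val[i] = i;  // any value in [0, 64) passes the check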
@@ -13,6 +13,7 @@
 #include "./internal/megdnn_opr_wrapper.inl"
 #include "megbrain/graph/grad_impl.h"
 #include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/io.h"
 #include "megbrain/opr/utility.h"
 
 using namespace mgb;
@@ -486,6 +487,7 @@ struct MegDNNOprInitPostCtor<DctChannelSelectForward> {
 }  // namespace intl
 }  // namespace opr
 }  // namespace mgb
+
 void DctChannelSelectForward::get_output_var_shape(
         const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
     auto mo = megdnn_opr();
@@ -504,6 +506,7 @@ void DctChannelSelectForward::get_output_var_shape(
     }
     out_shape[0] = dst;
 }
+
 size_t DctChannelSelectForward::get_workspace_size_bytes(
         const TensorShapeArray& input_shapes,
         const TensorShapeArray& output_shapes) const {
@@ -513,6 +516,7 @@ size_t DctChannelSelectForward::get_workspace_size_bytes(
             {input_shapes[0], input(0)->dtype(), input(0)->format()}, {}, {},
             {output_shapes[0], output(0)->dtype(), output(0)->format()});
 }
+
 void DctChannelSelectForward::scn_do_execute() {
     auto&& inp = input();
     auto mo = megdnn_opr();
@@ -524,7 +528,6 @@ void DctChannelSelectForward::scn_do_execute() {
     } else {
         mgb_assert(inp.size() == 3, "no support input tensor num %zu",
                    inp.size());
-
         mo->exec(inp[0]->dev_tensor().as_megdnn(),
                  inp[1]->dev_tensor().as_megdnn(),
                  inp[2]->dev_tensor().as_megdnn(),
@@ -533,7 +536,70 @@ void DctChannelSelectForward::scn_do_execute() {
     }
 }
 
-MEGDNN_OPR_INIT3(DctChannelSelectForward, "dct_channel_select")
+void DctChannelSelectForward::valid_mask(const int* mask_offset, int mask_len,
+                                         const int* mask_val, int mask_val_len,
+                                         const Param& param) {
+    if (mask_len <= 0)
+        return;
+    mgb_assert(mask_offset[0] == 0,
+               "The first element of mask_offset must be zero, but got %d. "
+               "For example, mask_offset [0, 15, 20] indicates there are 2 "
+               "ic: ic_0 will have (15 - 0) oc and ic_1 will have (20 - 15) "
+               "oc",
+               mask_offset[0]);
+    for (int i = 1; i < mask_len; ++i) {
+        if (param.format == Param::Format::NCHW4) {
+            mgb_assert(mask_offset[i] % 4 == 0,
+                       "Invalid mask offset %d at %d, it should be a "
+                       "multiple of 4 when using the nchw4 format",
+                       mask_offset[i], i);
+        }
+        mgb_assert(mask_offset[i] >= mask_offset[i - 1],
+                   "The mask offsets must be non-decreasing, but %d(%d) is "
+                   "less than %d(%d)",
+                   mask_offset[i], i, mask_offset[i - 1], i - 1);
+    }
+    const int max_mask = param.dct_block_size * param.dct_block_size;
+    for (int i = 0; i < mask_val_len; ++i) {
+        mgb_assert(0 <= mask_val[i] && mask_val[i] < max_mask,
+                   "Invalid mask_val, expected 0 <= mask_val[%d] < %d, i.e. "
+                   "0 <= %d < %d",
+                   i, max_mask, mask_val[i], max_mask);
+    }
+}
+
+DctChannelSelectForward::DctChannelSelectForward(
+        VarNode* src, VarNode* mask_offset, VarNode* mask_val,
+        const Param& param, const OperatorNodeConfig& config)
+        : Super(OperatorNodeBaseCtorParam{
+                  src->owner_graph(), config, "dct_channel_select", {src}}) {
+    init_megdnn_opr(*this, param);
+    add_input({src, mask_offset, mask_val});
+    if (mask_offset != nullptr) {
+        mgb_assert(mask_val,
+                   "mask_val should not be null when mask_offset is not null");
+        auto host_offset = mask_offset->owner_opr()
+                                   ->cast_final_safe<opr::ImmutableTensor>()
+                                   .host_value();
+        auto host_val = mask_val->owner_opr()
+                                ->cast_final_safe<opr::ImmutableTensor>()
+                                .host_value();
+        valid_mask(host_offset.ptr<int>(),
+                   host_offset.layout().total_nr_elems(), host_val.ptr<int>(),
+                   host_val.layout().total_nr_elems(), param);
+    }
+    intl::MegDNNOprInitPostCtor<DctChannelSelectForward>::apply(*this);
+}
+
+SymbolVar DctChannelSelectForward::make(SymbolVar src, SymbolVar mask_offset,
+                                        SymbolVar mask_val, const Param& param,
+                                        const OperatorNodeConfig& config) {
+    intl::MegDNNOprInitInputsModifier<DctChannelSelectForward>::apply(
+            param, {&src, &mask_offset, &mask_val});
+    return src.insert_single_output_opr<DctChannelSelectForward>(
+            src.node(), mask_offset.node(), mask_val.node(), param, config);
+}
+
 MEGDNN_OPR_INIT1(DctChannelSelectForward, "dct_channel_select")
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
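One thing to note about the constructor above: it reads the masks through cast_final_safe<opr::ImmutableTensor>, so mask_offset and mask_val must be ImmutableTensor vars whose contents are known when the graph is built; valid_mask then runs at graph-construction time rather than at execution. A minimal construction sketch under that assumption (it mirrors the DCT_BAD_MASK test further down; shapes and values are illustrative):

    // Build mask inputs as ImmutableTensor so valid_mask() runs at graph
    // construction time; a malformed mask raises MegBrainError immediately.
    HostTensorGenerator<dtype::Int32> gen_s32;
    auto offset_host = gen_s32({3});
    auto val_host = gen_s32({8});
    int32_t* offset_ptr = offset_host->ptr<int32_t>();
    offset_ptr[0] = 0;  // first offset must be zero
    offset_ptr[1] = 2;  // non-decreasing; multiples of 4 required for NCHW4
    offset_ptr[2] = 8;
    for (int i = 0; i < 8; ++i)
        val_host->ptr<int32_t>()[i] = i;  // each entry < dct_block_size^2
    auto graph = ComputingGraph::make();
    auto mask_offset = opr::ImmutableTensor::make(*graph, *offset_host);
    auto mask_val = opr::ImmutableTensor::make(*graph, *val_host);
    // opr::DctChannelSelect::make(src, mask_offset, mask_val, param);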
@@ -639,6 +639,9 @@ SymbolVar ImmutableTensor::make(ComputingGraph &graph, const DTypeScalar &val,
 const DeviceTensorND& ImmutableTensor::value() const {
     return m_value.dev();
 }
+const DeviceTensorND& ImmutableTensor::host_value() {
+    return const_cast<Value*>(&m_value)->static_infer();
+}
 
 SymbolVar ImmutableTensor::make_from_value(
         ComputingGraph &graph,
......
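The host_value() accessor added above is what makes the constructor-time check possible: it returns the tensor's value through the static-infer path so the contents can be read on the host, even when the ImmutableTensor was created on a GPU comp node (the new ImmutableTensorHostvalueGPU test below exercises exactly that case). A minimal usage sketch, assuming graph and a float host_x are already in scope as in the tests:

    // Read back an ImmutableTensor's contents on the host side.
    auto x = opr::ImmutableTensor::make(*graph, *host_x);
    auto&& val = x.node()->owner_opr()
                         ->cast_final_safe<opr::ImmutableTensor>()
                         .host_value();
    float first = val.ptr<float>()[0];  // directly readable on host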
@@ -286,6 +286,9 @@ size_t get_workspace_size_bytes(
             const TensorShapeArray& input_shapes,
             const TensorShapeArray& output_shapes) const override;
     void scn_do_execute() override;
+
+    void valid_mask(const int* mask_offset, int mask_len, const int* mask_val,
+                    int mask_val_len, const Param& param);
 };
 
 using DctChannelSelect = DctChannelSelectForward;
......
@@ -378,6 +378,8 @@ MGB_DEFINE_OPR_CLASS(ImmutableTensor, intl::DeviceTensorHolder) // {
     //! get underlying value on device
     const DeviceTensorND& value() const;
 
+    const DeviceTensorND& host_value();
+
     SymbolVar shallow_copy(
             ComputingGraph &graph, const OperatorNodeConfig &config) const {
         return make_from_value(graph, m_value, m_value_refkeep, config);
......
@@ -803,4 +803,54 @@ TEST(TestOprImgproc, DCT) {
     MGB_MARK_USED_VAR(fwd3);
     MGB_MARK_USED_VAR(gen_mask);
 }
+
+TEST(TestOprImgproc, DCT_BAD_MASK) {
+    HostTensorGenerator<dtype::Uint8> gen_u8;
+    HostTensorGenerator<dtype::Int32> gen_s32;
+    TensorShape src_shape({1, 2, 256, 256}), mask_offset_shape({3}),
+            mask_val_shape({8});
+    opr::DctChannelSelectForward::Param param;
+    auto graph = ComputingGraph::make();
+    auto src_tensor = gen_u8(src_shape);
+    auto mask_offset_tensor = gen_s32(mask_offset_shape);
+    auto mask_val_tensor = gen_s32(mask_val_shape);
+    auto mask_offset_ptr = mask_offset_tensor->ptr<int32_t>();
+    auto mask_val_ptr = mask_val_tensor->ptr<int32_t>();
+    mask_offset_ptr[0] = 1;
+    mask_val_ptr[0] = 64;
+    auto src_sym = opr::ImmutableTensor::make(*graph, *src_tensor);
+    auto mask_offset_sym =
+            opr::ImmutableTensor::make(*graph, *mask_offset_tensor);
+    auto mask_val_sym = opr::ImmutableTensor::make(*graph, *mask_val_tensor);
+    ASSERT_THROW(opr::DctChannelSelect::make(src_sym, mask_offset_sym,
+                                             mask_val_sym, param),
+                 MegBrainError);
+    mask_offset_ptr[0] = 0;
+    mask_offset_ptr[1] = 2;
+    mask_offset_ptr[2] = 8;
+    mask_offset_sym = opr::ImmutableTensor::make(*graph, *mask_offset_tensor);
+    ASSERT_THROW(opr::DctChannelSelect::make(src_sym, mask_offset_sym,
+                                             mask_val_sym, param),
+                 MegBrainError);
+    mask_val_ptr[0] = 0;
+    mask_val_ptr[1] = 1;
+    mask_val_ptr[2] = 2;
+    mask_val_ptr[3] = 3;
+    mask_val_ptr[4] = 4;
+    mask_val_ptr[5] = 5;
+    mask_val_ptr[6] = 6;
+    mask_val_ptr[7] = 7;
+    mask_val_sym = opr::ImmutableTensor::make(*graph, *mask_val_tensor);
+    opr::DctChannelSelect::make(src_sym, mask_offset_sym, mask_val_sym, param);
+    param.format = opr::DctChannelSelect::Param::Format::NCHW4;
+    ASSERT_THROW(opr::DctChannelSelect::make(src_sym, mask_offset_sym,
+                                             mask_val_sym, param),
+                 MegBrainError);
+}
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -150,6 +150,36 @@ TEST(TestOprIO, ImmutableTensor) {
     }
 }
+
+TEST(TestOprIO, ImmutableTensorHostvalue) {
+    HostTensorGenerator<> gen;
+    TensorShape shape({2, 3});
+    auto host_x = gen(shape);
+    auto graph = ComputingGraph::make();
+    auto x = opr::ImmutableTensor::make(*graph, *host_x);
+    auto y = x.node()->owner_opr()
+                     ->cast_final_safe<opr::ImmutableTensor>()
+                     .host_value();
+    for (size_t i = 0; i < shape.total_nr_elems(); ++i) {
+        ASSERT_EQ(host_x->ptr<float>()[i], y.ptr<float>()[i]);
+    }
+}
+
+TEST(TestOprIO, ImmutableTensorHostvalueGPU) {
+    REQUIRE_GPU(1);
+    auto gpu_cn = CompNode::load("gpu0");
+    HostTensorGenerator<> gen;
+    TensorShape shape({2, 3});
+    auto host_x = gen(shape);
+    auto graph = ComputingGraph::make();
+    auto x = opr::ImmutableTensor::make(*graph, *host_x, {gpu_cn});
+    auto y = x.node()->owner_opr()
+                     ->cast_final_safe<opr::ImmutableTensor>()
+                     .host_value();
+    for (size_t i = 0; i < shape.total_nr_elems(); ++i) {
+        ASSERT_EQ(host_x->ptr<float>()[i], y.ptr<float>()[i]);
+    }
+}
 
 TEST(TestOprIO, ImmutableTensorLarge) {
     HostTensorGenerator<> gen;
     auto host_x = gen({1025});
......