From 47dcdf3e17caac1623765a99bab80ec1f5632279 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Wed, 9 Jun 2021 15:09:17 +0800
Subject: [PATCH] fix(mgb/core): fix dtype and resize modifiers for tensor

GitOrigin-RevId: a9d95a4cd80bdd3b2fea5f89fcf9712347930bcc
---
 dnn/src/naive/convolution/convolution.cpp |  4 +-
 src/core/impl/tensor.cpp                  |  4 +-
 src/gopt/impl/tensor_reformat.cpp         | 10 +--
 src/gopt/test/inference.cpp               | 79 +++++++++++++++++++++++
 4 files changed, 89 insertions(+), 8 deletions(-)

diff --git a/dnn/src/naive/convolution/convolution.cpp b/dnn/src/naive/convolution/convolution.cpp
index 8676c41ef..834a7ca5f 100644
--- a/dnn/src/naive/convolution/convolution.cpp
+++ b/dnn/src/naive/convolution/convolution.cpp
@@ -270,8 +270,8 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &,
 }
 
 ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
-        const TensorLayout& /* src */, const TensorLayout& /* diff */,
-        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
+        const TensorLayout& /* src */, const TensorLayout& /* filter */,
+        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
         const AlgoAttribute& positive_attr,
         const AlgoAttribute& negative_attr) {
     auto algo =
diff --git a/src/core/impl/tensor.cpp b/src/core/impl/tensor.cpp
index ab5eafa45..27ae463a6 100644
--- a/src/core/impl/tensor.cpp
+++ b/src/core/impl/tensor.cpp
@@ -443,7 +443,7 @@ TensorND<TensorStorage>::name
 
 DEF(resize, &)(const TensorShape& shape) {
     mgb_assert(m_layout.dtype.valid());
-    m_layout = TensorLayout(shape, m_layout.dtype);
+    m_layout.init_contiguous_stride(shape);
     m_storage.ensure_size(m_layout.span().dist_byte());
     return static_cast<ChainReturnType&>(*this);
 }
@@ -479,7 +479,7 @@ DEF(storage, &)(const TensorStorage &storage) {
 
 DEF(dtype, &)(DType dtype) {
     if (m_layout.dtype != dtype) {
-        m_layout.dtype = dtype;
+        m_layout.modify_dtype_inplace(dtype);
         m_layout.ndim = 0;
     }
     return static_cast<ChainReturnType&>(*this);
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 9eaa97333..4cb259d1c 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -3833,8 +3833,9 @@ void PaddingChannelPass::apply(OptState& opt) const {
                    inp->dtype().enumv() == DTypeEnum::QuantizedS32);
         TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2],
                           inp->shape()[3]};
-        std::shared_ptr<HostTensorND> host_val = std::make_shared<HostTensorND>(
-                inp->comp_node(), shape, inp->dtype());
+        std::shared_ptr<HostTensorND> host_val =
+                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
+        host_val->resize(shape);
         auto ptr = host_val->raw_ptr();
         size_t size_bytes =
                 TensorLayout{shape, inp->dtype()}.span().dist_byte();
@@ -3853,8 +3854,9 @@ void PaddingChannelPass::apply(OptState& opt) const {
                    inp->dtype().enumv() == DTypeEnum::QuantizedS32);
         TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2],
                           inp->shape()[3]};
-        std::shared_ptr<HostTensorND> host_val = std::make_shared<HostTensorND>(
-                inp->comp_node(), shape, inp->dtype());
+        std::shared_ptr<HostTensorND> host_val =
+                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
+        host_val->resize(shape);
         auto ptr = host_val->raw_ptr();
         size_t size_bytes =
                 TensorLayout{shape, inp->dtype()}.span().dist_byte();
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index d6cb2af8e..9be80442a 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -1208,6 +1208,85 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+#if MGB_OPENCL
+#include "megcore_opencl.h"
+
+#define REQUIRE_OPENCL()                                                 \
+    do {                                                                 \
+        if (!CompNode::get_device_count(CompNode::DeviceType::OPENCL)) { \
+            return;                                                      \
+        }                                                                \
+    } while (0)
+
+TEST(TestGoptInference, ConvertFormatNHWCD4OpenCL) {
+    REQUIRE_OPENCL();
+
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("openclx");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({8, 8, 8, 8}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+
+    opr::Convolution::Param param;
+    param.pad_h = param.pad_w = 0;
+    auto w1 = mkcvar("w1", {4, 8, 3, 3}),
+         conv = opr::Convolution::make(x, w1, param);
+    auto shape_of = opr::GetVarShape::make(conv);
+    auto subtensor = opr::Subtensor::make(
+            shape_of, {opr::Subtensor::AxisIndexer::make_interval(
+                              0, x.make_scalar(2), None, x.make_scalar(1))});
+
+    opr::Resize::Param param_resize;
+    param_resize.format = opr::Resize::Param::Format::NCHW;
+    auto resize = opr::ResizeForward::make(conv, subtensor * 2, param_resize);
+    auto mat = mkcvar("mat", {8, 3, 3}),
+         warp = opr::WarpPerspectiveForward::make(
+                 resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
+
+    auto b = mkvar("b", {1, 4, 1, 1}),
+         elem = opr::Elemwise::make({warp + b},
+                                    opr::Elemwise::Param::Mode::RELU);
+    param.pad_h = param.pad_w = 1;
+    auto w2 = mkcvar("w2", {4, 4, 3, 3}),
+         y = opr::Convolution::make(elem, w2, param),
+         z = opr::AxisAddRemove::make(
+                 y, {opr::AxisAddRemove::AxisDesc::make_add(0)});
+
+    SymbolVar y_opt, z_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nhwcd4();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    unpack_vector(gopt::optimize_for_inference({z}, options), z_opt);
+
+    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
+              find_opr<opr::Convolution>(y_opt).param().format);
+
+    ASSERT_EQ(TensorFormat::Type::DEFAULT,
+              find_opr<opr::AxisAddRemove>(z_opt).input(0)->format().type());
+    ASSERT_EQ(4, find_opr<opr::AxisAddRemove>(z_opt).input(0)->shape().ndim);
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+
+    *host_x = *gen({8, 8, 16, 16}, cn);
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+#undef REQUIRE_OPENCL
+#endif
+
 TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
     // hwcd4 is only supported in naive handle
     NaiveMegDNNHandleScope naive_megdnn_handle;
-- 
GitLab
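
For context, a minimal sketch (not part of the patch) of the construct-then-resize pattern that the tensor.cpp and PaddingChannelPass changes above rely on. The constructor, resize(), raw_ptr() and span().dist_byte() calls are all taken from the diff itself; the helper name make_zeroed and the header path are assumptions for illustration only:

    // Sketch: build a zero-filled host tensor the way PaddingChannelPass
    // now does -- construct with comp_node + dtype first, then resize().
    // With this fix, resize() re-derives contiguous strides via
    // init_contiguous_stride() instead of rebuilding the whole layout,
    // and dtype() goes through modify_dtype_inplace().
    #include <cstring>
    #include "megbrain/tensor.h"  // assumed header for HostTensorND

    using namespace mgb;

    HostTensorND make_zeroed(CompNode cn, DType dtype,
                             const TensorShape& shape) {
        HostTensorND val{cn, dtype};  // dtype set, no shape yet
        val.resize(shape);            // strides + storage ensured
        std::memset(val.raw_ptr(), 0, val.layout().span().dist_byte());
        return val;
    }

The design point of the fix: the old resize() body, m_layout = TensorLayout(shape, m_layout.dtype), rebuilt the layout from scratch, whereas init_contiguous_stride(shape) keeps the existing layout state and only recomputes shape and strides; likewise modify_dtype_inplace(dtype) replaces the bare m_layout.dtype = dtype assignment, which is presumably unsafe for dtypes (e.g. low-bit ones) that affect how the layout's span is computed.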