From 13b15fb08cc89b65ec92b795a9274eb9c3855012 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 13 May 2021 14:02:16 +0800 Subject: [PATCH] feat(megbrain): add correlation opr GitOrigin-RevId: 6d44598891d2b88c3da21f53b1e205a54e02f11f --- dnn/src/cuda/correlation/correlation_cuda.cu | 2 +- dnn/src/naive/handle.cpp | 7 +- dnn/test/common/correlation.h | 13 -- .../python/megengine/functional/vision.py | 37 +++++ .../test/unit/functional/test_functional.py | 100 +++++++++++++ imperative/src/impl/ops/specializations.cpp | 16 +++ src/core/include/megbrain/ir/ops.td | 1 + src/opr/impl/dnn/correlation.cpp | 109 ++++++++++++++ src/opr/impl/dnn/dnn.sereg.h | 5 + .../include/megbrain/opr/dnn/correlation.h | 69 +++++++++ src/opr/test/dnn/correlation.cpp | 134 ++++++++++++++++++ src/serialization/impl/schema.fbs | 1 + 12 files changed, 476 insertions(+), 18 deletions(-) create mode 100644 src/opr/impl/dnn/correlation.cpp create mode 100644 src/opr/include/megbrain/opr/dnn/correlation.h create mode 100644 src/opr/test/dnn/correlation.cpp diff --git a/dnn/src/cuda/correlation/correlation_cuda.cu b/dnn/src/cuda/correlation/correlation_cuda.cu index beb6db45a..22f6aa347 100644 --- a/dnn/src/cuda/correlation/correlation_cuda.cu +++ b/dnn/src/cuda/correlation/correlation_cuda.cu @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/roi_align/roi_align.cu + * \file dnn/src/cuda/correlation/correlation_cuda.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index e641737b3..e6138c4ae 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -28,9 +28,9 @@ #include "src/naive/convolution/opr_impl.h" #include "src/naive/convolution3d/opr_impl.h" #include "src/naive/convpooling/opr_impl.h" +#include "src/naive/correlation/opr_impl.h" #include "src/naive/cumsum/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h" -#include "src/naive/correlation/opr_impl.h" #include "src/naive/dct/opr_impl.h" #include "src/naive/deformable_conv/opr_impl.h" #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" @@ -38,6 +38,7 @@ #include "src/naive/elemwise/opr_impl.h" #include "src/naive/elemwise_multi_type/opr_impl.h" #include "src/naive/eye/opr_impl.h" +#include "src/naive/fake_quant/opr_impl.h" #include "src/naive/flip/opr_impl.h" #include "src/naive/gaussian_blur/opr_impl.h" #include "src/naive/group_local/opr_impl.h" @@ -75,13 +76,11 @@ #include "src/naive/tensor_remap/opr_impl.h" #include "src/naive/tile/opr_impl.h" #include "src/naive/topk/opr_impl.h" +#include "src/naive/tqt/opr_impl.h" #include "src/naive/transpose/opr_impl.h" #include "src/naive/type_cvt/opr_impl.h" #include "src/naive/warp_affine/opr_impl.h" #include "src/naive/warp_perspective/opr_impl.h" -#include "src/naive/remap/opr_impl.h" -#include "src/naive/fake_quant/opr_impl.h" -#include "src/naive/tqt/opr_impl.h" static size_t g_image2d_pitch_alignment = 1; diff --git a/dnn/test/common/correlation.h b/dnn/test/common/correlation.h index 373751244..bf5ac2683 100644 --- a/dnn/test/common/correlation.h +++ b/dnn/test/common/correlation.h @@ -45,19 +45,6 @@ inline static std::vector get_args() { TensorShape{batch_size, channel, height, width}, TensorShape{batch_size, channel, height, width}); - // cur_param.is_multiply = false; - // cur_param.kernel_size = 1; - // cur_param.max_displacement = 2; - // cur_param.pad_size = 1; - // cur_param.stride1 = 1; - // cur_param.stride2 = 1; - // cur_param.format = - // megdnn::param::Correlation::Format::NCHW; - - // args.emplace_back( - // cur_param, - // TensorShape{batch_size, channel, height, width}, - // TensorShape{batch_size, channel, height, width}); } } } diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index db6a1869d..8ed39a732 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -106,6 +106,43 @@ def roi_pooling( return result +def correlation( + data1: Tensor, + data2: Tensor, + kernel_size: int = 1, + max_displacement: int = 1, + stride1: int = 1, + stride2: int = 1, + pad_size: int = 0, + is_multiply: bool = True, +) -> Tensor: + """ Applies correlation to inputs. + + :param data1: Input data1 to the correlation. format must be nchw + :param data2: Input data2 to the correlation. format must be nchw + :param kernel_size: (int (non-negative), optional, default=1) – kernel size for Correlation must be an odd number + :param max_displacement: (int (non-negative), optional, default=1) – Max displacement of Correlation + :param stride1: (int (non-negative), optional, default=1) – stride1 quantize data1 globally + :param stride2: (int (non-negative), optional, default=1) – stride2 quantize data2 within the neighborhood centered around data1 + :param pad_size: (int (non-negative), optional, default=0) – pad for Correlation + :param is_multiply: (boolean, optional, default=True) – operation type is either multiplication or absolute difference + + """ + + op = builtin.Correlation( + format="NCHW", + kernel_size=kernel_size, + max_displacement=max_displacement, + stride1=stride1, + stride2=stride2, + pad_size=pad_size, + is_multiply=is_multiply, + ) + + result, *_ = apply(op, data1, data2) + return result + + def roi_align( inp: Tensor, rois: Tensor, diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index c2a178a8e..c96865358 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -228,6 +228,106 @@ def test_roi_align(): assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) +def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): + if random: + inp_feat1 = np.random.randn( + image_shape[0], image_shape[1], image_shape[2], image_shape[3] + ) + inp_feat2 = np.random.randn( + image_shape[0], image_shape[1], image_shape[2], image_shape[3] + ) + else: + inp_feat1 = np.ones(image_shape) * constant + inp_feat2 = np.ones(image_shape) * constant + + return tensor(inp_feat1), tensor(inp_feat2) + + +def test_correlation(): + ##test case 0 check the grad shape + data1, data2 = _gen_correlation() + grad = Grad().wrt(data1, callback=_save_to(data1)) + + out_feat = F.vision.correlation( + data1, + data2, + kernel_size=5, + max_displacement=4, + stride1=2, + stride2=2, + pad_size=2, + is_multiply=True, + ) + + grad(out_feat, tensor(F.ones_like(out_feat))) + assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) + + ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194 + data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3)) + + out_feat = F.vision.correlation( + data1, + data2, + kernel_size=3, + max_displacement=0, + stride1=1, + stride2=1, + pad_size=0, + is_multiply=True, + ) + assert abs(out_feat.sum() - 1) < 1e-9 + + ##test case 2 check same image subduction + data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3)) + + out_feat = F.vision.correlation( + data1, + data2, + kernel_size=3, + max_displacement=0, + stride1=1, + stride2=1, + pad_size=0, + is_multiply=False, + ) + assert out_feat.sum() < 1e-9 + + ##test case 3 check same image subduction + data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3)) + + out_feat = F.vision.correlation( + data1, + data2, + kernel_size=3, + max_displacement=0, + stride1=1, + stride2=1, + pad_size=0, + is_multiply=False, + ) + assert out_feat.sum() < 1e-9 + + ##test case 4 check correlation + data1, _ = _gen_correlation( + random=False, image_shape=(1, 1, 220, 220), constant=2.0 + ) + _, data2 = _gen_correlation( + random=False, image_shape=(1, 1, 220, 220), constant=1.0 + ) + + out_feat = F.vision.correlation( + data1, + data2, + kernel_size=3, + max_displacement=2, + stride1=1, + stride2=2, + pad_size=0, + is_multiply=False, + ) + assert abs(out_feat.mean() - 1) < 1e-9 + + def test_roi_pooling(): inp_feat, rois = _gen_roi_inp() grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 71cbebb5c..bde440cce 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -19,6 +19,7 @@ #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/roi_align.h" +#include "megbrain/opr/dnn/correlation.h" #include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/blas.h" @@ -445,6 +446,21 @@ OP_TRAIT_REG(ROIAlign, ROIAlign) .fallback(); }} // roi_align +namespace { namespace correlation { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{op.make_name()}; + return opr::Correlation::make( + inputs[0], inputs[1], op.param(), config); +} +OP_TRAIT_REG(Correlation, Correlation) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // correlation + #if MGB_CUDA namespace { namespace nvof { auto apply_on_var_node( diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index a7e2c6555..164510d24 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -82,6 +82,7 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; +def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; diff --git a/src/opr/impl/dnn/correlation.cpp b/src/opr/impl/dnn/correlation.cpp new file mode 100644 index 000000000..494225b1f --- /dev/null +++ b/src/opr/impl/dnn/correlation.cpp @@ -0,0 +1,109 @@ +/** + * \file src/opr/impl/dnn/correlation.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/opr/dnn/correlation.h" + +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" +#include "megbrain/opr/utility.h" + +#include "../internal/megdnn_opr_wrapper.inl" + +using namespace mgb; +using namespace opr; + +/* ==================== CorrelationForward ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationForward); +CorrelationForward::CorrelationForward(VarNode* data1, VarNode* data2, + const Param& param, + const OperatorNodeConfig& config) + : Super{data1->owner_graph(), config, "correlation", {data1, data2}} { + init_megdnn_opr(*this, param); + mgb_assert(data1->dtype() == data2->dtype()); + mgb_assert(data1->dtype().category() == DTypeCategory::FLOAT); + + add_input({data1, data2}); + output(0)->dtype(data1->dtype()); +} + +SymbolVar CorrelationForward::make(SymbolVar data1, SymbolVar data2, + const Param& param, + const OperatorNodeConfig& config) { + return data1.insert_single_output_opr( + data1.node(), data2.node(), param, config); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(CorrelationForward) { + if (wrt_idx == 0) { + // wrt src + SymbolVar grad = CorrelationBackwardData1::make( + out_grad[0], opr.input(0), opr.input(1), opr.param(), + opr.config()); + return grad.node(); + } else { + mgb_assert(wrt_idx == 1); + SymbolVar grad = CorrelationBackwardData2::make( + out_grad[0], opr.input(0), opr.input(1), opr.param(), + opr.config()); + return grad.node(); + } +} +#endif + +/* ==================== CorrelationBackwardData1 ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData1); +MEGDNN_OPR_INIT3(CorrelationBackwardData1, "correlation_backward_data1", 1, + true); + +void CorrelationBackwardData1::scn_do_execute() { + megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), + input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output(1))); +} + +size_t CorrelationBackwardData1::get_workspace_size_bytes( + const TensorShapeArray& inp_shapes, + const TensorShapeArray& out_shapes) const { + TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()}, + data1{inp_shapes[1], input(1)->dtype(), input(1)->format()}, + data2{inp_shapes[2], input(2)->dtype(), input(2)->format()}, + grad1{out_shapes[0], output(0)->dtype(), output(0)->format()}; + return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad1); +} + +/* ==================== CorrelationBackwardData2 ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData2); +MEGDNN_OPR_INIT3(CorrelationBackwardData2, "correlation_backward_data2", 1, + true); + +void CorrelationBackwardData2::scn_do_execute() { + megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), + input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output(1))); +} + +size_t CorrelationBackwardData2::get_workspace_size_bytes( + const TensorShapeArray& inp_shapes, + const TensorShapeArray& out_shapes) const { + TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()}, + data1{inp_shapes[1], input(1)->dtype(), input(1)->format()}, + data2{inp_shapes[2], input(2)->dtype(), input(2)->format()}, + grad2{out_shapes[0], output(0)->dtype(), output(0)->format()}; + return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad2); +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 2082f42fd..01cedfa50 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -11,6 +11,7 @@ #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/dnn/correlation.h" #include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/adaptive_pooling.h" @@ -573,6 +574,10 @@ MGB_SEREG_OPR(DeformableConvForwardV1, 0); MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0); MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0); +MGB_SEREG_OPR(CorrelationForward, 2); +MGB_SEREG_OPR(CorrelationBackwardData1, 3); +MGB_SEREG_OPR(CorrelationBackwardData2, 3); + MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); diff --git a/src/opr/include/megbrain/opr/dnn/correlation.h b/src/opr/include/megbrain/opr/dnn/correlation.h new file mode 100644 index 000000000..bc58d887f --- /dev/null +++ b/src/opr/include/megbrain/opr/dnn/correlation.h @@ -0,0 +1,69 @@ +/** + * \file src/opr/include/megbrain/opr/dnn/correlation.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megdnn/oprs.h" + +namespace mgb { +namespace opr { + +MGB_DEFINE_OPR_CLASS(CorrelationForward, + intl::MegDNNOprWrapperFwd) // { +public: + CorrelationForward(VarNode* data1, VarNode* data2, const Param& param, + const OperatorNodeConfig& config); + + static SymbolVar make(SymbolVar data1, SymbolVar data2, + const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; +using Correlation = CorrelationForward; + +MGB_DEFINE_OPR_CLASS( + CorrelationBackwardData1, intl::MegDNNOprWrapperBwd) // { +public: + CorrelationBackwardData1(VarNode* diff, VarNode* data1, VarNode* data2, + const Param& param, const OperatorNodeConfig& config); + + static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2, + const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: + void scn_do_execute() override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; +}; + +MGB_DEFINE_OPR_CLASS( + CorrelationBackwardData2, intl::MegDNNOprWrapperBwd) // { +public: + CorrelationBackwardData2(VarNode* diff, VarNode* data1, VarNode* data2, + const Param& param, const OperatorNodeConfig& config); + + static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2, + const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: + void scn_do_execute() override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; +}; + +} // namespace opr +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/dnn/correlation.cpp b/src/opr/test/dnn/correlation.cpp new file mode 100644 index 000000000..ddc53744d --- /dev/null +++ b/src/opr/test/dnn/correlation.cpp @@ -0,0 +1,134 @@ +/** + * \file src/opr/test/dnn/correlation.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/opr/dnn/correlation.h" +#include "megbrain/test/autocheck.h" +#include "megbrain/test/helper.h" +#include "megbrain/test/megdnn_helper.h" + +#include "megdnn/oprs.h" + +#include +#include +#include +#include + +using namespace mgb; + +namespace { +using Param = opr::CorrelationForward::Param; + +void run_forward(bool is_multiply) { + RNGxorshf rng{next_rand_seed()}; + using Checker = AutoOprChecker<2, 1>; + + Param param; + param.format = Param::Format::NCHW; + param.is_multiply = is_multiply; + param.kernel_size = 3; + param.max_displacement = 2; + param.pad_size = 1; + param.stride1 = 2; + param.stride2 = 2; + + auto make_graph = + [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + auto o0 = opr::CorrelationForward::make(inputs[0], inputs[1], param); + return {o0}; + }; + + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + auto opr = megdnn_naive_handle() + ->create_operator(); + opr->param() = param; + auto inp_shape = inp[0]->shape(); + auto num = inp_shape[0]; + auto height = inp_shape[2]; + auto width = inp_shape[3]; + + uint32_t pad_size = param.pad_size; + uint32_t kernel_size = param.kernel_size; + uint32_t stride1 = param.stride1; + uint32_t stride2 = param.stride2; + uint32_t max_displacement = param.max_displacement; + + int paddedbottomheight = height + 2 * pad_size; + int paddedbottomwidth = width + 2 * pad_size; + uint32_t kernel_radius = (kernel_size - 1) / 2; + uint32_t border_size = max_displacement + kernel_radius; + uint32_t top_width = + ceil(static_cast(paddedbottomwidth - border_size * 2) / + static_cast(stride1)); + uint32_t top_height = + ceil(static_cast(paddedbottomheight - border_size * 2) / + static_cast(stride1)); + uint32_t neighborhood_grid_radius = max_displacement / stride2; + uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + uint32_t top_channels = + neighborhood_grid_width * neighborhood_grid_width; + megdnn::TensorShape target_shape{num, top_channels, top_height, + top_width}; + + dest[0].dtype(dtype::Float32()) + .comp_node(inp[0]->comp_node()) + .resize(target_shape); + opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), + {}); + }; + + auto rand_real = [&](float lo, float hi) { + std::uniform_real_distribution dist(lo, hi); + return dist(rng); + }; + auto gen_inp1 = [&](HostTensorND &inp) { + auto ptr = inp.ptr(); + for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) { + ptr[i] = rand_real(0.06f, 0.1f); + }; + }; + + auto gen_inp2 = [&](HostTensorND &inp) { + auto ptr = inp.ptr(); + for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) { + ptr[i] = rand_real(0.01f, 0.04f); + }; + }; + + Checker::RunOptions option; + option.numdiff_eps = 1e-3; + option.numdiff_max_err = 1e-2; + Checker checker{make_graph, fwd}; + + checker.set_input_generator(0, gen_inp1); + checker.set_input_generator(1, gen_inp2); + + checker.run({TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 10, 10}}, option) + .run({TensorShape{1, 3, 50, 50}, TensorShape{1, 3, 50, 50}}, option) + .run({TensorShape{1, 1, 100, 100}, TensorShape{1, 1, 100, 100}}, + option); +} + +TEST(TestOprDNN, CorrelationForwardMultiply) { + // TODO: fix me, add correct backward of cpu + REQUIRE_GPU(1); + run_forward(true); +} + +TEST(TestOprDNN, CorrelationForwardSubstract) { + // TODO: fix me, add correct backward of cpu + REQUIRE_GPU(1); + run_forward(false); +} + +} // anonymous namespace + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index e12ef1672..7b3f98474 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -106,6 +106,7 @@ union OperatorParam { param.DctChannelSelect = 72, param.FakeQuant = 73, param.TQT = 74, + param.Correlation = 75, } table Operator { -- GitLab