提交 13b15fb0 编写于 作者: M Megvii Engine Team

feat(megbrain): add correlation opr

GitOrigin-RevId: 6d44598891d2b88c3da21f53b1e205a54e02f11f
上级 1997b1a2
/** /**
* \file dnn/src/cuda/roi_align/roi_align.cu * \file dnn/src/cuda/correlation/correlation_cuda.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......
...@@ -28,9 +28,9 @@ ...@@ -28,9 +28,9 @@
#include "src/naive/convolution/opr_impl.h" #include "src/naive/convolution/opr_impl.h"
#include "src/naive/convolution3d/opr_impl.h" #include "src/naive/convolution3d/opr_impl.h"
#include "src/naive/convpooling/opr_impl.h" #include "src/naive/convpooling/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/cumsum/opr_impl.h" #include "src/naive/cumsum/opr_impl.h"
#include "src/naive/cvt_color/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/dct/opr_impl.h" #include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h" #include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h" #include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "src/naive/elemwise/opr_impl.h" #include "src/naive/elemwise/opr_impl.h"
#include "src/naive/elemwise_multi_type/opr_impl.h" #include "src/naive/elemwise_multi_type/opr_impl.h"
#include "src/naive/eye/opr_impl.h" #include "src/naive/eye/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h"
#include "src/naive/flip/opr_impl.h" #include "src/naive/flip/opr_impl.h"
#include "src/naive/gaussian_blur/opr_impl.h" #include "src/naive/gaussian_blur/opr_impl.h"
#include "src/naive/group_local/opr_impl.h" #include "src/naive/group_local/opr_impl.h"
...@@ -75,13 +76,11 @@ ...@@ -75,13 +76,11 @@
#include "src/naive/tensor_remap/opr_impl.h" #include "src/naive/tensor_remap/opr_impl.h"
#include "src/naive/tile/opr_impl.h" #include "src/naive/tile/opr_impl.h"
#include "src/naive/topk/opr_impl.h" #include "src/naive/topk/opr_impl.h"
#include "src/naive/tqt/opr_impl.h"
#include "src/naive/transpose/opr_impl.h" #include "src/naive/transpose/opr_impl.h"
#include "src/naive/type_cvt/opr_impl.h" #include "src/naive/type_cvt/opr_impl.h"
#include "src/naive/warp_affine/opr_impl.h" #include "src/naive/warp_affine/opr_impl.h"
#include "src/naive/warp_perspective/opr_impl.h" #include "src/naive/warp_perspective/opr_impl.h"
#include "src/naive/remap/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h"
#include "src/naive/tqt/opr_impl.h"
static size_t g_image2d_pitch_alignment = 1; static size_t g_image2d_pitch_alignment = 1;
......
...@@ -45,19 +45,6 @@ inline static std::vector<TestArg> get_args() { ...@@ -45,19 +45,6 @@ inline static std::vector<TestArg> get_args() {
TensorShape{batch_size, channel, height, width}, TensorShape{batch_size, channel, height, width},
TensorShape{batch_size, channel, height, width}); TensorShape{batch_size, channel, height, width});
// cur_param.is_multiply = false;
// cur_param.kernel_size = 1;
// cur_param.max_displacement = 2;
// cur_param.pad_size = 1;
// cur_param.stride1 = 1;
// cur_param.stride2 = 1;
// cur_param.format =
// megdnn::param::Correlation::Format::NCHW;
// args.emplace_back(
// cur_param,
// TensorShape{batch_size, channel, height, width},
// TensorShape{batch_size, channel, height, width});
} }
} }
} }
......
...@@ -106,6 +106,43 @@ def roi_pooling( ...@@ -106,6 +106,43 @@ def roi_pooling(
return result return result
def correlation(
data1: Tensor,
data2: Tensor,
kernel_size: int = 1,
max_displacement: int = 1,
stride1: int = 1,
stride2: int = 1,
pad_size: int = 0,
is_multiply: bool = True,
) -> Tensor:
""" Applies correlation to inputs.
:param data1: Input data1 to the correlation. format must be nchw
:param data2: Input data2 to the correlation. format must be nchw
:param kernel_size: (int (non-negative), optional, default=1) – kernel size for Correlation must be an odd number
:param max_displacement: (int (non-negative), optional, default=1) – Max displacement of Correlation
:param stride1: (int (non-negative), optional, default=1) – stride1 quantize data1 globally
:param stride2: (int (non-negative), optional, default=1) – stride2 quantize data2 within the neighborhood centered around data1
:param pad_size: (int (non-negative), optional, default=0) – pad for Correlation
:param is_multiply: (boolean, optional, default=True) – operation type is either multiplication or absolute difference
"""
op = builtin.Correlation(
format="NCHW",
kernel_size=kernel_size,
max_displacement=max_displacement,
stride1=stride1,
stride2=stride2,
pad_size=pad_size,
is_multiply=is_multiply,
)
result, *_ = apply(op, data1, data2)
return result
def roi_align( def roi_align(
inp: Tensor, inp: Tensor,
rois: Tensor, rois: Tensor,
......
...@@ -228,6 +228,106 @@ def test_roi_align(): ...@@ -228,6 +228,106 @@ def test_roi_align():
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
if random:
inp_feat1 = np.random.randn(
image_shape[0], image_shape[1], image_shape[2], image_shape[3]
)
inp_feat2 = np.random.randn(
image_shape[0], image_shape[1], image_shape[2], image_shape[3]
)
else:
inp_feat1 = np.ones(image_shape) * constant
inp_feat2 = np.ones(image_shape) * constant
return tensor(inp_feat1), tensor(inp_feat2)
def test_correlation():
##test case 0 check the grad shape
data1, data2 = _gen_correlation()
grad = Grad().wrt(data1, callback=_save_to(data1))
out_feat = F.vision.correlation(
data1,
data2,
kernel_size=5,
max_displacement=4,
stride1=2,
stride2=2,
pad_size=2,
is_multiply=True,
)
grad(out_feat, tensor(F.ones_like(out_feat)))
assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)
##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194
data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
out_feat = F.vision.correlation(
data1,
data2,
kernel_size=3,
max_displacement=0,
stride1=1,
stride2=1,
pad_size=0,
is_multiply=True,
)
assert abs(out_feat.sum() - 1) < 1e-9
##test case 2 check same image subduction
data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
out_feat = F.vision.correlation(
data1,
data2,
kernel_size=3,
max_displacement=0,
stride1=1,
stride2=1,
pad_size=0,
is_multiply=False,
)
assert out_feat.sum() < 1e-9
##test case 3 check same image subduction
data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
out_feat = F.vision.correlation(
data1,
data2,
kernel_size=3,
max_displacement=0,
stride1=1,
stride2=1,
pad_size=0,
is_multiply=False,
)
assert out_feat.sum() < 1e-9
##test case 4 check correlation
data1, _ = _gen_correlation(
random=False, image_shape=(1, 1, 220, 220), constant=2.0
)
_, data2 = _gen_correlation(
random=False, image_shape=(1, 1, 220, 220), constant=1.0
)
out_feat = F.vision.correlation(
data1,
data2,
kernel_size=3,
max_displacement=2,
stride1=1,
stride2=2,
pad_size=0,
is_multiply=False,
)
assert abs(out_feat.mean() - 1) < 1e-9
def test_roi_pooling(): def test_roi_pooling():
inp_feat, rois = _gen_roi_inp() inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h" #include "megbrain/opr/blas.h"
...@@ -445,6 +446,21 @@ OP_TRAIT_REG(ROIAlign, ROIAlign) ...@@ -445,6 +446,21 @@ OP_TRAIT_REG(ROIAlign, ROIAlign)
.fallback(); .fallback();
}} // roi_align }} // roi_align
namespace { namespace correlation {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Correlation&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Correlation::make(
inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Correlation, Correlation)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // correlation
#if MGB_CUDA #if MGB_CUDA
namespace { namespace nvof { namespace { namespace nvof {
auto apply_on_var_node( auto apply_on_var_node(
......
...@@ -82,6 +82,7 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio ...@@ -82,6 +82,7 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;
def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>;
def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
......
/**
* \file src/opr/impl/dnn/correlation.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"
#include "../internal/megdnn_opr_wrapper.inl"
using namespace mgb;
using namespace opr;
/* ==================== CorrelationForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationForward);
CorrelationForward::CorrelationForward(VarNode* data1, VarNode* data2,
const Param& param,
const OperatorNodeConfig& config)
: Super{data1->owner_graph(), config, "correlation", {data1, data2}} {
init_megdnn_opr(*this, param);
mgb_assert(data1->dtype() == data2->dtype());
mgb_assert(data1->dtype().category() == DTypeCategory::FLOAT);
add_input({data1, data2});
output(0)->dtype(data1->dtype());
}
SymbolVar CorrelationForward::make(SymbolVar data1, SymbolVar data2,
const Param& param,
const OperatorNodeConfig& config) {
return data1.insert_single_output_opr<CorrelationForward>(
data1.node(), data2.node(), param, config);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CorrelationForward) {
if (wrt_idx == 0) {
// wrt src
SymbolVar grad = CorrelationBackwardData1::make(
out_grad[0], opr.input(0), opr.input(1), opr.param(),
opr.config());
return grad.node();
} else {
mgb_assert(wrt_idx == 1);
SymbolVar grad = CorrelationBackwardData2::make(
out_grad[0], opr.input(0), opr.input(1), opr.param(),
opr.config());
return grad.node();
}
}
#endif
/* ==================== CorrelationBackwardData1 ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData1);
MEGDNN_OPR_INIT3(CorrelationBackwardData1, "correlation_backward_data1", 1,
true);
void CorrelationBackwardData1::scn_do_execute() {
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output(1)));
}
size_t CorrelationBackwardData1::get_workspace_size_bytes(
const TensorShapeArray& inp_shapes,
const TensorShapeArray& out_shapes) const {
TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()},
data1{inp_shapes[1], input(1)->dtype(), input(1)->format()},
data2{inp_shapes[2], input(2)->dtype(), input(2)->format()},
grad1{out_shapes[0], output(0)->dtype(), output(0)->format()};
return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad1);
}
/* ==================== CorrelationBackwardData2 ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData2);
MEGDNN_OPR_INIT3(CorrelationBackwardData2, "correlation_backward_data2", 1,
true);
void CorrelationBackwardData2::scn_do_execute() {
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output(1)));
}
size_t CorrelationBackwardData2::get_workspace_size_bytes(
const TensorShapeArray& inp_shapes,
const TensorShapeArray& out_shapes) const {
TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()},
data1{inp_shapes[1], input(1)->dtype(), input(1)->format()},
data2{inp_shapes[2], input(2)->dtype(), input(2)->format()},
grad2{out_shapes[0], output(0)->dtype(), output(0)->format()};
return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad2);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/adaptive_pooling.h"
...@@ -573,6 +574,10 @@ MGB_SEREG_OPR(DeformableConvForwardV1, 0); ...@@ -573,6 +574,10 @@ MGB_SEREG_OPR(DeformableConvForwardV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0); MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0); MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0);
MGB_SEREG_OPR(CorrelationForward, 2);
MGB_SEREG_OPR(CorrelationBackwardData1, 3);
MGB_SEREG_OPR(CorrelationBackwardData2, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
......
/**
* \file src/opr/include/megbrain/opr/dnn/correlation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(CorrelationForward,
intl::MegDNNOprWrapperFwd<megdnn::CorrelationForward>) // {
public:
CorrelationForward(VarNode* data1, VarNode* data2, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar data1, SymbolVar data2,
const Param& param = {},
const OperatorNodeConfig& config = {});
};
using Correlation = CorrelationForward;
MGB_DEFINE_OPR_CLASS(
CorrelationBackwardData1, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData1>) // {
public:
CorrelationBackwardData1(VarNode* diff, VarNode* data1, VarNode* data2,
const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2,
const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void scn_do_execute() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
};
MGB_DEFINE_OPR_CLASS(
CorrelationBackwardData2, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData2>) // {
public:
CorrelationBackwardData2(VarNode* diff, VarNode* data1, VarNode* data2,
const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2,
const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void scn_do_execute() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
};
} // namespace opr
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/opr/test/dnn/correlation.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megdnn/oprs.h"
#include <cmath>
#include <iomanip>
#include <random>
#include <sstream>
using namespace mgb;
namespace {
using Param = opr::CorrelationForward::Param;
void run_forward(bool is_multiply) {
RNGxorshf rng{next_rand_seed()};
using Checker = AutoOprChecker<2, 1>;
Param param;
param.format = Param::Format::NCHW;
param.is_multiply = is_multiply;
param.kernel_size = 3;
param.max_displacement = 2;
param.pad_size = 1;
param.stride1 = 2;
param.stride2 = 2;
auto make_graph =
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto o0 = opr::CorrelationForward::make(inputs[0], inputs[1], param);
return {o0};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr = megdnn_naive_handle()
->create_operator<megdnn::CorrelationForward>();
opr->param() = param;
auto inp_shape = inp[0]->shape();
auto num = inp_shape[0];
auto height = inp_shape[2];
auto width = inp_shape[3];
uint32_t pad_size = param.pad_size;
uint32_t kernel_size = param.kernel_size;
uint32_t stride1 = param.stride1;
uint32_t stride2 = param.stride2;
uint32_t max_displacement = param.max_displacement;
int paddedbottomheight = height + 2 * pad_size;
int paddedbottomwidth = width + 2 * pad_size;
uint32_t kernel_radius = (kernel_size - 1) / 2;
uint32_t border_size = max_displacement + kernel_radius;
uint32_t top_width =
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) /
static_cast<float>(stride1));
uint32_t top_height =
ceil(static_cast<float>(paddedbottomheight - border_size * 2) /
static_cast<float>(stride1));
uint32_t neighborhood_grid_radius = max_displacement / stride2;
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
uint32_t top_channels =
neighborhood_grid_width * neighborhood_grid_width;
megdnn::TensorShape target_shape{num, top_channels, top_height,
top_width};
dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize(target_shape);
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(),
{});
};
auto rand_real = [&](float lo, float hi) {
std::uniform_real_distribution<float> dist(lo, hi);
return dist(rng);
};
auto gen_inp1 = [&](HostTensorND &inp) {
auto ptr = inp.ptr<float>();
for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) {
ptr[i] = rand_real(0.06f, 0.1f);
};
};
auto gen_inp2 = [&](HostTensorND &inp) {
auto ptr = inp.ptr<float>();
for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) {
ptr[i] = rand_real(0.01f, 0.04f);
};
};
Checker::RunOptions option;
option.numdiff_eps = 1e-3;
option.numdiff_max_err = 1e-2;
Checker checker{make_graph, fwd};
checker.set_input_generator(0, gen_inp1);
checker.set_input_generator(1, gen_inp2);
checker.run({TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 10, 10}}, option)
.run({TensorShape{1, 3, 50, 50}, TensorShape{1, 3, 50, 50}}, option)
.run({TensorShape{1, 1, 100, 100}, TensorShape{1, 1, 100, 100}},
option);
}
TEST(TestOprDNN, CorrelationForwardMultiply) {
// TODO: fix me, add correct backward of cpu
REQUIRE_GPU(1);
run_forward(true);
}
TEST(TestOprDNN, CorrelationForwardSubstract) {
// TODO: fix me, add correct backward of cpu
REQUIRE_GPU(1);
run_forward(false);
}
} // anonymous namespace
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -106,6 +106,7 @@ union OperatorParam { ...@@ -106,6 +106,7 @@ union OperatorParam {
param.DctChannelSelect = 72, param.DctChannelSelect = 72,
param.FakeQuant = 73, param.FakeQuant = 73,
param.TQT = 74, param.TQT = 74,
param.Correlation = 75,
} }
table Operator { table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册