Unverified commit 86d4a672 authored by Xiaoyu Zhang, committed by GitHub

Add upsample module backward (#5025)

* add argmax test

* add upsample module backward

* update upsample unittest

* fix unittest bug

* refine upsample backward

* code format

* fix comment

* fix comment
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 121b1423
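In practice, this change lets autograd run through flow.nn.Upsample in eager mode. The snippet below is a minimal sketch distilled from the new unit tests in this commit; the input values, device choice, and variable names mirror _test_upsample2d_backward and are illustrative only.

import numpy as np
import oneflow.experimental as flow

# Requires eager execution, as in the tests added below.
x = flow.Tensor(
    np.arange(1, 5).reshape((1, 1, 2, 2)),
    dtype=flow.float32,
    device=flow.device("cuda"),
    requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=2.0, mode="nearest")
y = m(x)
y.sum().backward()
# Each input pixel feeds a 2x2 output block, so every entry of x.grad is 4.0.
print(x.grad.numpy())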
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
namespace oneflow {
namespace one {
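// State recorded by Capture() and read back in Apply(): the forward op's scale,
// alignment, layout, and interpolation attributes needed to build the upsample_grad op.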
struct UpsampleInterpState : public OpExprInterpState {
bool requires_grad;
float height_scale;
float width_scale;
bool align_corners;
std::string data_format;
std::string interpolation;
};
class Upsample : public OpExprGradFunction<UpsampleInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(UpsampleInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const UpsampleInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::shared_ptr<OpExpr> grad_op_;
};
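// Init() builds a reusable upsample_grad op expression (the attribute values set there are
// placeholders), Capture() snapshots the actual forward attributes into the interp state,
// and Apply() dispatches upsample_grad on dy with those captured attributes.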
Maybe<void> Upsample::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
const std::string& op_name = fw_op_expr->op_name();
const float height_scale = 1.0;
const float width_scale = 1.0;
const bool align_corners = false;
const std::string data_format = "NCHW";
const std::string interpolation = "nearest";
grad_op_ =
JUST(op_expr_helper::UpsampleGradOp(height_scale, width_scale, align_corners, data_format,
interpolation, GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Upsample::Capture(UpsampleInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->height_scale = JUST(composed_attrs.GetAttr<float>("height_scale"));
ctx->width_scale = JUST(composed_attrs.GetAttr<float>("width_scale"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->interpolation = JUST(composed_attrs.GetAttr<std::string>("interpolation"));
return Maybe<void>::Ok();
}
Maybe<void> Upsample::Apply(const UpsampleInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
MutableAttrMap attrs;
JUST(attrs.SetAttr<float>("height_scale", ctx->height_scale));
JUST(attrs.SetAttr<float>("width_scale", ctx->width_scale));
JUST(attrs.SetAttr<bool>("align_corners", ctx->align_corners));
JUST(attrs.SetAttr<std::string>("data_format", ctx->data_format));
JUST(attrs.SetAttr<std::string>("interpolation", ctx->interpolation));
in_grads->resize(1);
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grads.at(0)}, attrs));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample);
} // namespace one
} // namespace oneflow
@@ -547,6 +547,26 @@ Maybe<one::UserOpExpr> PReLUGradOp(const std::string& name) {
.Build();
}
Maybe<one::UserOpExpr> UpsampleGradOp(const float& height_scale, const float& width_scale,
const bool& align_corners, const std::string& data_format,
const std::string& interpolation) {
return UpsampleGradOp(height_scale, width_scale, align_corners, data_format, interpolation,
UniqueOpName("upsample_grad"));
}
Maybe<one::UserOpExpr> UpsampleGradOp(const float& height_scale, const float& width_scale,
const bool& align_corners, const std::string& data_format,
const std::string& interpolation, const std::string& name) {
return one::OpBuilder("upsample_grad", name)
.Input("dy")
.Output("dx")
.Attr<float>("height_scale", height_scale)
.Attr<float>("width_scale", width_scale)
.Attr<bool>("align_corners", align_corners)
.Attr<std::string>("data_format", data_format)
.Attr<std::string>("interpolation", interpolation)
.Build();
}
Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim) {
return DimScatterAddLikeOp(dim, UniqueOpName("dim_scatter_add_like"));
}
......
@@ -180,6 +180,13 @@ Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
Maybe<one::UserOpExpr> PReLUGradOp();
Maybe<one::UserOpExpr> PReLUGradOp(const std::string& name);
Maybe<one::UserOpExpr> UpsampleGradOp(const float& height_scale, const float& width_scale,
const bool& align_corners, const std::string& data_format,
const std::string& interpolation);
Maybe<one::UserOpExpr> UpsampleGradOp(const float& height_scale, const float& width_scale,
const bool& align_corners, const std::string& data_format,
const std::string& interpolation, const std::string& name);
Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim);
Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim, const std::string& name);
Maybe<one::UserOpExpr> TransposeOp(const std::vector<int32_t>& perm);
......
@@ -14,204 +14,282 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
def _test_upsample2d(test_case, device):
    input = flow.Tensor(
        np.arange(1, 5).reshape((1, 1, 2, 2)),
        device=flow.device(device),
        dtype=flow.float32,
    )
    m = flow.nn.Upsample(scale_factor=2.0, mode="nearest")
    of_out = m(input)
    np_out = np.array(
        [
            [
                [
                    [1.0, 1.0, 2.0, 2.0],
                    [1.0, 1.0, 2.0, 2.0],
                    [3.0, 3.0, 4.0, 4.0],
                    [3.0, 3.0, 4.0, 4.0],
                ]
            ]
        ]
    )
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))


def _test_upsample2d_bilinear(test_case, device):
    input = flow.Tensor(
        np.arange(1, 5).reshape((1, 1, 2, 2)),
        device=flow.device(device),
        dtype=flow.float32,
    )
    m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear")
    of_out = m(input)
    np_out = np.array(
        [
            [
                [
                    [1.0000, 1.2500, 1.7500, 2.0000],
                    [1.5000, 1.7500, 2.2500, 2.5000],
                    [2.5000, 2.7500, 3.2500, 3.5000],
                    [3.0000, 3.2500, 3.7500, 4.0000],
                ]
            ]
        ]
    )
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))


def _test_upsample2d_bilinear_aligncorner(test_case, device):
    input = flow.Tensor(
        np.arange(1, 5).reshape((1, 1, 2, 2)),
        device=flow.device(device),
        dtype=flow.float32,
    )
    m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True)
    of_out = m(input)
    np_out = np.array(
        [
            [
                [
                    [1.0000, 1.3333, 1.6667, 2.0000],
                    [1.6667, 2.0000, 2.3333, 2.6667],
                    [2.3333, 2.6667, 3.0000, 3.3333],
                    [3.0000, 3.3333, 3.6667, 4.0000],
                ]
            ]
        ]
    )
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))


def _test_UpsamplingNearest2d(test_case, device):
    input = flow.Tensor(
        np.arange(1, 5).reshape((1, 1, 2, 2)),
        device=flow.device(device),
        dtype=flow.float32,
    )
    m = flow.nn.UpsamplingNearest2d(scale_factor=2.0)
    of_out = m(input)
    np_out = np.array(
        [
            [
                [
                    [1.0, 1.0, 2.0, 2.0],
                    [1.0, 1.0, 2.0, 2.0],
                    [3.0, 3.0, 4.0, 4.0],
                    [3.0, 3.0, 4.0, 4.0],
                ]
            ]
        ]
    )
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))


def _test_UpsamplingBilinear2d(test_case, device):
    input = flow.Tensor(
        np.arange(1, 5).reshape((1, 1, 2, 2)),
        device=flow.device(device),
        dtype=flow.float32,
    )
    m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0)
    of_out = m(input)
    np_out = np.array(
        [
            [
                [
                    [1.0000, 1.3333, 1.6667, 2.0000],
                    [1.6667, 2.0000, 2.3333, 2.6667],
                    [2.3333, 2.6667, 3.0000, 3.3333],
                    [3.0000, 3.3333, 3.6667, 4.0000],
                ]
            ]
        ]
    )
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))


def _test_upsample2d_4dim(test_case, device):
    input = flow.Tensor(
        np.arange(1, 37).reshape((2, 2, 3, 3)),
        device=flow.device(device),
        dtype=flow.float32,
    )
    m = flow.nn.Upsample(scale_factor=2.0, mode="nearest")
    of_out = m(input)
    np_out = np.array(
        [
            [
                [
                    [1.0, 1.0, 2.0, 2.0, 3.0, 3.0],
                    [1.0, 1.0, 2.0, 2.0, 3.0, 3.0],
                    [4.0, 4.0, 5.0, 5.0, 6.0, 6.0],
                    [4.0, 4.0, 5.0, 5.0, 6.0, 6.0],
                    [7.0, 7.0, 8.0, 8.0, 9.0, 9.0],
                    [7.0, 7.0, 8.0, 8.0, 9.0, 9.0],
                ],
                [
                    [10.0, 10.0, 11.0, 11.0, 12.0, 12.0],
                    [10.0, 10.0, 11.0, 11.0, 12.0, 12.0],
                    [13.0, 13.0, 14.0, 14.0, 15.0, 15.0],
                    [13.0, 13.0, 14.0, 14.0, 15.0, 15.0],
                    [16.0, 16.0, 17.0, 17.0, 18.0, 18.0],
                    [16.0, 16.0, 17.0, 17.0, 18.0, 18.0],
                ],
            ],
            [
                [
                    [19.0, 19.0, 20.0, 20.0, 21.0, 21.0],
                    [19.0, 19.0, 20.0, 20.0, 21.0, 21.0],
                    [22.0, 22.0, 23.0, 23.0, 24.0, 24.0],
                    [22.0, 22.0, 23.0, 23.0, 24.0, 24.0],
                    [25.0, 25.0, 26.0, 26.0, 27.0, 27.0],
                    [25.0, 25.0, 26.0, 26.0, 27.0, 27.0],
                ],
                [
                    [28.0, 28.0, 29.0, 29.0, 30.0, 30.0],
                    [28.0, 28.0, 29.0, 29.0, 30.0, 30.0],
                    [31.0, 31.0, 32.0, 32.0, 33.0, 33.0],
                    [31.0, 31.0, 32.0, 32.0, 33.0, 33.0],
                    [34.0, 34.0, 35.0, 35.0, 36.0, 36.0],
                    [34.0, 34.0, 35.0, 35.0, 36.0, 36.0],
                ],
            ],
        ]
    )
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
def _test_upsample2d_bilinear_4dim(test_case, device):
input = flow.Tensor(
np.arange(1, 37).reshape((2, 2, 3, 3)),
device=flow.device(device),
dtype=flow.float32,
)
m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear")
of_out = m(input)
np_out = np.array(
[
[
[
[1.0, 1.25, 1.75, 2.25, 2.75, 3.0],
[1.75, 2.0, 2.5, 3.0, 3.5, 3.75],
[3.25, 3.5, 4.0, 4.5, 5.0, 5.25],
[4.75, 5.0, 5.5, 6.0, 6.5, 6.75],
[6.25, 6.5, 7.0, 7.5, 8.0, 8.25],
[7.0, 7.25, 7.75, 8.25, 8.75, 9.0],
],
[
[10.0, 10.25, 10.75, 11.25, 11.75, 12.0],
[10.75, 11.0, 11.5, 12.0, 12.5, 12.75],
[12.25, 12.5, 13.0, 13.5, 14.0, 14.25],
[13.75, 14.0, 14.5, 15.0, 15.5, 15.75],
[15.25, 15.5, 16.0, 16.5, 17.0, 17.25],
[16.0, 16.25, 16.75, 17.25, 17.75, 18.0],
],
],
[
[
[19.0, 19.25, 19.75, 20.25, 20.75, 21.0],
[19.75, 20.0, 20.5, 21.0, 21.5, 21.75],
[21.25, 21.5, 22.0, 22.5, 23.0, 23.25],
[22.75, 23.0, 23.5, 24.0, 24.5, 24.75],
[24.25, 24.5, 25.0, 25.5, 26.0, 26.25],
[25.0, 25.25, 25.75, 26.25, 26.75, 27.0],
],
[
[28.0, 28.25, 28.75, 29.25, 29.75, 30.0],
[28.75, 29.0, 29.5, 30.0, 30.5, 30.75],
[30.25, 30.5, 31.0, 31.5, 32.0, 32.25],
[31.75, 32.0, 32.5, 33.0, 33.5, 33.75],
[33.25, 33.5, 34.0, 34.5, 35.0, 35.25],
[34.0, 34.25, 34.75, 35.25, 35.75, 36.0],
],
],
]
)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
def _test_upsample2d_backward(test_case, device):
input = flow.Tensor(
np.arange(1, 5).reshape((1, 1, 2, 2)),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=2.0, mode="nearest")
of_out = m(input)
of_out = of_out.sum()
of_out.backward()
np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
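    # With a sum() loss every upstream gradient entry is 1.0; nearest upsampling with
    # scale_factor=2.0 copies each input pixel into a 2x2 output block, so the backward
    # pass accumulates 4.0 per input element, matching np_grad above.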
def _test_upsample2d_bilinear_aligncorner_backward(test_case, device):
input = flow.Tensor(
np.arange(1, 5).reshape((1, 1, 2, 2)),
device=flow.device(device),
dtype=flow.float32,
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True)
of_out = m(input)
of_out = of_out.sum()
of_out.backward()
np_grad = [[[[3.999999523162842, 4.000000476837158], [3.999999761581421, 4.0]]]]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
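    # For bilinear upsampling with align_corners=True, the interpolation weights attached
    # to each input pixel sum to 4.0 across the 4x4 output, so the expected gradients are
    # ~4.0 up to float32 rounding (hence the 3.9999995... entries in np_grad).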
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestUpsample2d(flow.unittest.TestCase):
def test_upsample2d(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_upsample2d,
_test_upsample2d_bilinear,
_test_upsample2d_bilinear_aligncorner,
_test_UpsamplingNearest2d,
_test_UpsamplingBilinear2d,
_test_upsample2d_4dim,
_test_upsample2d_bilinear_4dim,
_test_upsample2d_backward,
_test_upsample2d_bilinear_aligncorner_backward,
]
arg_dict["device"] = ["cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
......