diff --git a/oneflow/core/autograd/gradient_funcs/upsample.cpp b/oneflow/core/autograd/gradient_funcs/upsample.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c44f0fa5918c98e289e2458dcd4e705d42ba4e25 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/upsample.cpp @@ -0,0 +1,95 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_expr_helper.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" + +namespace oneflow { +namespace one { + +struct UpsampleInterpState : public OpExprInterpState { + bool requires_grad; + float height_scale; + float width_scale; + float align_corners; + std::string data_format; + std::string interpolation; +}; + +class Upsample : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override; + Maybe Capture(UpsampleInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const UpsampleInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; + std::shared_ptr grad_op_; +}; + +Maybe Upsample::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + const std::string& op_name = fw_op_expr->op_name(); + const float height_scale = 1.0; + const float width_scale = 1.0; + const bool align_corners = false; + const std::string data_format = "NCHW"; + const std::string interpolation = "nearest"; + grad_op_ = + JUST(op_expr_helper::UpsampleGradOp(height_scale, width_scale, align_corners, data_format, + interpolation, GradientOpName(op_name))); + return Maybe::Ok(); +} + +Maybe Upsample::Capture(UpsampleInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); + ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); + ctx->interpolation = JUST(composed_attrs.GetAttr("interpolation")); + return Maybe::Ok(); +} + +Maybe Upsample::Apply(const UpsampleInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if (!ctx->requires_grad) { return Maybe::Ok(); } + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + + MutableAttrMap attrs; + JUST(attrs.SetAttr("height_scale", ctx->height_scale)); + JUST(attrs.SetAttr("width_scale", ctx->width_scale)); + JUST(attrs.SetAttr("align_corners", ctx->align_corners)); + JUST(attrs.SetAttr("data_format", ctx->data_format)); + JUST(attrs.SetAttr("interpolation", ctx->interpolation)); + in_grads->resize(1); + in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0)}, attrs)); + return Maybe::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp index 2c469154fd84250ab4b18c92a73a1ed1fe5ce666..7a9f8cb12a7d4ae6e397819005bb9800a2b8f209 100644 --- a/oneflow/core/framework/op_expr_helper.cpp +++ b/oneflow/core/framework/op_expr_helper.cpp @@ -547,6 +547,26 @@ Maybe PReLUGradOp(const std::string& name) { .Build(); } +Maybe UpsampleGradOp(const float& height_scale, const float& width_scale, + const bool& align_corners, const std::string& data_format, + const std::string& interpolation) { + return UpsampleGradOp(height_scale, width_scale, align_corners, data_format, interpolation, + UniqueOpName("upsample_grad")); +} +Maybe UpsampleGradOp(const float& height_scale, const float& width_scale, + const bool& align_corners, const std::string& data_format, + const std::string& interpolation, const std::string& name) { + return one::OpBuilder("upsample_grad", name) + .Input("dy") + .Output("dx") + .Attr("height_scale", height_scale) + .Attr("width_scale", width_scale) + .Attr("align_corners", align_corners) + .Attr("data_format", data_format) + .Attr("interpolation", interpolation) + .Build(); +} + Maybe DimScatterAddLikeOp(const int32_t dim) { return DimScatterAddLikeOp(dim, UniqueOpName("dim_scatter_add_like")); } diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h index 13cb4782c5b5aad93382628b8b2731cf99704076..49a564f2b65170cef6eda00b2bab9d5e12dd291c 100644 --- a/oneflow/core/framework/op_expr_helper.h +++ b/oneflow/core/framework/op_expr_helper.h @@ -180,6 +180,13 @@ Maybe SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth, Maybe PReLUGradOp(); Maybe PReLUGradOp(const std::string& name); +Maybe UpsampleGradOp(const float& height_scale, const float& width_scale, + const bool& align_corners, const std::string& data_format, + const std::string& interpolation); +Maybe UpsampleGradOp(const float& height_scale, const float& width_scale, + const bool& align_corners, const std::string& data_format, + const std::string& interpolation, const std::string& name); + Maybe DimScatterAddLikeOp(const int32_t dim); Maybe DimScatterAddLikeOp(const int32_t dim, const std::string& name); Maybe TransposeOp(const std::vector& perm); diff --git a/oneflow/python/test/modules/test_upsample2d.py b/oneflow/python/test/modules/test_upsample2d.py index b9bf791e111b4fec3d698d0c0ac9f6ad1dc5532b..051390b36ceb64f5828d82ad9f410bedb783af8d 100644 --- a/oneflow/python/test/modules/test_upsample2d.py +++ b/oneflow/python/test/modules/test_upsample2d.py @@ -14,204 +14,282 @@ See the License for the specific language governing permissions and limitations under the License. """ import unittest +from collections import OrderedDict import numpy as np + import oneflow.experimental as flow +from test_util import GenArgList -@unittest.skipIf( - not flow.unittest.env.eager_execution_enabled(), - ".numpy() doesn't work in lazy mode", -) -class TestUpsample2d(flow.unittest.TestCase): - def test_upsample2d(test_case): - input = flow.Tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") - of_out = m(input) - np_out = np.array( +def _test_upsample2d(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") + of_out = m(input) + np_out = np.array( + [ [ [ - [ - [1.0, 1.0, 2.0, 2.0], - [1.0, 1.0, 2.0, 2.0], - [3.0, 3.0, 4.0, 4.0], - [3.0, 3.0, 4.0, 4.0], - ] + [1.0, 1.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0], + [3.0, 3.0, 4.0, 4.0], + [3.0, 3.0, 4.0, 4.0], ] ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) - - def test_upsample2d_bilinear(test_case): - input = flow.Tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear") - of_out = m(input) - np_out = np.array( + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_upsample2d_bilinear(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear") + of_out = m(input) + np_out = np.array( + [ [ [ - [ - [1.0000, 1.2500, 1.7500, 2.0000], - [1.5000, 1.7500, 2.2500, 2.5000], - [2.5000, 2.7500, 3.2500, 3.5000], - [3.0000, 3.2500, 3.7500, 4.0000], - ] + [1.0000, 1.2500, 1.7500, 2.0000], + [1.5000, 1.7500, 2.2500, 2.5000], + [2.5000, 2.7500, 3.2500, 3.5000], + [3.0000, 3.2500, 3.7500, 4.0000], ] ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) - - def test_upsample2d_bilinear_aligncorner(test_case): - input = flow.Tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) - of_out = m(input) - np_out = np.array( + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_upsample2d_bilinear_aligncorner(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) + of_out = m(input) + np_out = np.array( + [ [ [ - [ - [1.0000, 1.3333, 1.6667, 2.0000], - [1.6667, 2.0000, 2.3333, 2.6667], - [2.3333, 2.6667, 3.0000, 3.3333], - [3.0000, 3.3333, 3.6667, 4.0000], - ] + [1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000], ] ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-3, 1e-3)) - - def test_UpsamplingNearest2d(test_case): - input = flow.Tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.UpsamplingNearest2d(scale_factor=2.0) - of_out = m(input) - np_out = np.array( + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +def _test_UpsamplingNearest2d(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.UpsamplingNearest2d(scale_factor=2.0) + of_out = m(input) + np_out = np.array( + [ [ [ - [ - [1.0, 1.0, 2.0, 2.0], - [1.0, 1.0, 2.0, 2.0], - [3.0, 3.0, 4.0, 4.0], - [3.0, 3.0, 4.0, 4.0], - ] + [1.0, 1.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0], + [3.0, 3.0, 4.0, 4.0], + [3.0, 3.0, 4.0, 4.0], ] ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) - - def test_UpsamplingBilinear2d(test_case): - input = flow.Tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) - of_out = m(input) - np_out = np.array( + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_UpsamplingBilinear2d(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) + of_out = m(input) + np_out = np.array( + [ [ [ - [ - [1.0000, 1.3333, 1.6667, 2.0000], - [1.6667, 2.0000, 2.3333, 2.6667], - [2.3333, 2.6667, 3.0000, 3.3333], - [3.0000, 3.3333, 3.6667, 4.0000], - ] + [1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000], ] ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-3, 1e-3)) - - def test_upsample2d_4dim(test_case): - input = flow.Tensor(np.arange(1, 37).reshape((2, 2, 3, 3)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") - of_out = m(input) - np_out = np.array( + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +def _test_upsample2d_4dim(test_case, device): + input = flow.Tensor( + np.arange(1, 37).reshape((2, 2, 3, 3)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") + of_out = m(input) + np_out = np.array( + [ [ [ - [ - [1.0, 1.0, 2.0, 2.0, 3.0, 3.0,], - [1.0, 1.0, 2.0, 2.0, 3.0, 3.0,], - [4.0, 4.0, 5.0, 5.0, 6.0, 6.0,], - [4.0, 4.0, 5.0, 5.0, 6.0, 6.0,], - [7.0, 7.0, 8.0, 8.0, 9.0, 9.0,], - [7.0, 7.0, 8.0, 8.0, 9.0, 9.0,], - ], - [ - [10.0, 10.0, 11.0, 11.0, 12.0, 12.0,], - [10.0, 10.0, 11.0, 11.0, 12.0, 12.0,], - [13.0, 13.0, 14.0, 14.0, 15.0, 15.0,], - [13.0, 13.0, 14.0, 14.0, 15.0, 15.0,], - [16.0, 16.0, 17.0, 17.0, 18.0, 18.0,], - [16.0, 16.0, 17.0, 17.0, 18.0, 18.0,], - ], + [1.0, 1.0, 2.0, 2.0, 3.0, 3.0,], + [1.0, 1.0, 2.0, 2.0, 3.0, 3.0,], + [4.0, 4.0, 5.0, 5.0, 6.0, 6.0,], + [4.0, 4.0, 5.0, 5.0, 6.0, 6.0,], + [7.0, 7.0, 8.0, 8.0, 9.0, 9.0,], + [7.0, 7.0, 8.0, 8.0, 9.0, 9.0,], ], [ - [ - [19.0, 19.0, 20.0, 20.0, 21.0, 21.0,], - [19.0, 19.0, 20.0, 20.0, 21.0, 21.0,], - [22.0, 22.0, 23.0, 23.0, 24.0, 24.0,], - [22.0, 22.0, 23.0, 23.0, 24.0, 24.0,], - [25.0, 25.0, 26.0, 26.0, 27.0, 27.0,], - [25.0, 25.0, 26.0, 26.0, 27.0, 27.0,], - ], - [ - [28.0, 28.0, 29.0, 29.0, 30.0, 30.0,], - [28.0, 28.0, 29.0, 29.0, 30.0, 30.0,], - [31.0, 31.0, 32.0, 32.0, 33.0, 33.0,], - [31.0, 31.0, 32.0, 32.0, 33.0, 33.0,], - [34.0, 34.0, 35.0, 35.0, 36.0, 36.0,], - [34.0, 34.0, 35.0, 35.0, 36.0, 36.0,], - ], + [10.0, 10.0, 11.0, 11.0, 12.0, 12.0,], + [10.0, 10.0, 11.0, 11.0, 12.0, 12.0,], + [13.0, 13.0, 14.0, 14.0, 15.0, 15.0,], + [13.0, 13.0, 14.0, 14.0, 15.0, 15.0,], + [16.0, 16.0, 17.0, 17.0, 18.0, 18.0,], + [16.0, 16.0, 17.0, 17.0, 18.0, 18.0,], ], - ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) - - def test_upsample2d_bilinear_4dim(test_case): - input = flow.Tensor(np.arange(1, 37).reshape((2, 2, 3, 3)), dtype=flow.float32) - input = input.to("cuda") - m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear") - of_out = m(input) - np_out = np.array( + ], [ [ - [ - [1.0, 1.25, 1.75, 2.25, 2.75, 3.0], - [1.75, 2.0, 2.5, 3.0, 3.5, 3.75], - [3.25, 3.5, 4.0, 4.5, 5.0, 5.25], - [4.75, 5.0, 5.5, 6.0, 6.5, 6.75], - [6.25, 6.5, 7.0, 7.5, 8.0, 8.25], - [7.0, 7.25, 7.75, 8.25, 8.75, 9.0], - ], - [ - [10.0, 10.25, 10.75, 11.25, 11.75, 12.0], - [10.75, 11.0, 11.5, 12.0, 12.5, 12.75], - [12.25, 12.5, 13.0, 13.5, 14.0, 14.25], - [13.75, 14.0, 14.5, 15.0, 15.5, 15.75], - [15.25, 15.5, 16.0, 16.5, 17.0, 17.25], - [16.0, 16.25, 16.75, 17.25, 17.75, 18.0], - ], + [19.0, 19.0, 20.0, 20.0, 21.0, 21.0,], + [19.0, 19.0, 20.0, 20.0, 21.0, 21.0,], + [22.0, 22.0, 23.0, 23.0, 24.0, 24.0,], + [22.0, 22.0, 23.0, 23.0, 24.0, 24.0,], + [25.0, 25.0, 26.0, 26.0, 27.0, 27.0,], + [25.0, 25.0, 26.0, 26.0, 27.0, 27.0,], ], [ - [ - [19.0, 19.25, 19.75, 20.25, 20.75, 21.0], - [19.75, 20.0, 20.5, 21.0, 21.5, 21.75], - [21.25, 21.5, 22.0, 22.5, 23.0, 23.25], - [22.75, 23.0, 23.5, 24.0, 24.5, 24.75], - [24.25, 24.5, 25.0, 25.5, 26.0, 26.25], - [25.0, 25.25, 25.75, 26.25, 26.75, 27.0], - ], - [ - [28.0, 28.25, 28.75, 29.25, 29.75, 30.0], - [28.75, 29.0, 29.5, 30.0, 30.5, 30.75], - [30.25, 30.5, 31.0, 31.5, 32.0, 32.25], - [31.75, 32.0, 32.5, 33.0, 33.5, 33.75], - [33.25, 33.5, 34.0, 34.5, 35.0, 35.25], - [34.0, 34.25, 34.75, 35.25, 35.75, 36.0], - ], + [28.0, 28.0, 29.0, 29.0, 30.0, 30.0,], + [28.0, 28.0, 29.0, 29.0, 30.0, 30.0,], + [31.0, 31.0, 32.0, 32.0, 33.0, 33.0,], + [31.0, 31.0, 32.0, 32.0, 33.0, 33.0,], + [34.0, 34.0, 35.0, 35.0, 36.0, 36.0,], + [34.0, 34.0, 35.0, 35.0, 36.0, 36.0,], ], - ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + ], + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_upsample2d_bilinear_4dim(test_case, device): + input = flow.Tensor( + np.arange(1, 37).reshape((2, 2, 3, 3)), + device=flow.device(device), + dtype=flow.float32, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear") + of_out = m(input) + np_out = np.array( + [ + [ + [ + [1.0, 1.25, 1.75, 2.25, 2.75, 3.0], + [1.75, 2.0, 2.5, 3.0, 3.5, 3.75], + [3.25, 3.5, 4.0, 4.5, 5.0, 5.25], + [4.75, 5.0, 5.5, 6.0, 6.5, 6.75], + [6.25, 6.5, 7.0, 7.5, 8.0, 8.25], + [7.0, 7.25, 7.75, 8.25, 8.75, 9.0], + ], + [ + [10.0, 10.25, 10.75, 11.25, 11.75, 12.0], + [10.75, 11.0, 11.5, 12.0, 12.5, 12.75], + [12.25, 12.5, 13.0, 13.5, 14.0, 14.25], + [13.75, 14.0, 14.5, 15.0, 15.5, 15.75], + [15.25, 15.5, 16.0, 16.5, 17.0, 17.25], + [16.0, 16.25, 16.75, 17.25, 17.75, 18.0], + ], + ], + [ + [ + [19.0, 19.25, 19.75, 20.25, 20.75, 21.0], + [19.75, 20.0, 20.5, 21.0, 21.5, 21.75], + [21.25, 21.5, 22.0, 22.5, 23.0, 23.25], + [22.75, 23.0, 23.5, 24.0, 24.5, 24.75], + [24.25, 24.5, 25.0, 25.5, 26.0, 26.25], + [25.0, 25.25, 25.75, 26.25, 26.75, 27.0], + ], + [ + [28.0, 28.25, 28.75, 29.25, 29.75, 30.0], + [28.75, 29.0, 29.5, 30.0, 30.5, 30.75], + [30.25, 30.5, 31.0, 31.5, 32.0, 32.25], + [31.75, 32.0, 32.5, 33.0, 33.5, 33.75], + [33.25, 33.5, 34.0, 34.5, 35.0, 35.25], + [34.0, 34.25, 34.75, 35.25, 35.75, 36.0], + ], + ], + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_upsample2d_backward(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + dtype=flow.float32, + device=flow.device(device), + requires_grad=True, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="nearest") + of_out = m(input) + of_out = of_out.sum() + of_out.backward() + np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]] + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5)) + + +def _test_upsample2d_bilinear_aligncorner_backward(test_case, device): + input = flow.Tensor( + np.arange(1, 5).reshape((1, 1, 2, 2)), + device=flow.device(device), + dtype=flow.float32, + requires_grad=True, + ) + m = flow.nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) + of_out = m(input) + of_out = of_out.sum() + of_out.backward() + np_grad = [[[[3.999999523162842, 4.000000476837158], [3.999999761581421, 4.0]]]] + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5)) + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestUpsample2d(flow.unittest.TestCase): + def test_upsample2d(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_upsample2d, + _test_upsample2d_bilinear, + _test_upsample2d_bilinear_aligncorner, + _test_UpsamplingNearest2d, + _test_UpsamplingBilinear2d, + _test_upsample2d_4dim, + _test_upsample2d_bilinear_4dim, + _test_upsample2d_backward, + _test_upsample2d_bilinear_aligncorner_backward, + ] + arg_dict["device"] = ["cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) if __name__ == "__main__":