Unverified commit 4c82e455, authored by: W WangZhen, committed by: GitHub

[Cherry-Pick][Dy2St]Support call backward() without params in dy2st (#49812) (#50144)

* [Dy2St]Support call backward() without params in dy2st (#49812)

* Support call backward() without params in dy2st

* format code

* format code
Parent commit: 8c5e432b
......@@ -93,7 +93,8 @@ class SelectOutputInferShape : public framework::InferShapeBase {
// Shape-inference check for the select_output op: inputs X and Mask must
// exist; outputs may be empty (allow_null = true) so that running backward()
// on a program with no parameters does not trip this check.
void operator()(framework::InferShapeContext *context) const override {
  OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
  OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
  // NOTE(review): the scraped diff kept both the pre-change check
  // (HasOutputs("Out")) and the post-change one; only the allow-empty form
  // belongs in the final source, so the stale duplicate is removed here.
  OP_INOUT_CHECK(
      context->HasOutputs("Out", true), "Output", "Out", "SelectOutput");
}
};
......
......@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self,
PyObject *kwargs) {
PyThreadState *tstate = nullptr;
try {
auto X = GetTensorListFromArgs("run_program", "X", args, 0, false);
auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false);
auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true);
auto OutScope =
GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class Net(paddle.nn.Layer):
    """A parameter-free layer whose forward pass is compiled by to_static.

    Used to check that backward() works for a dy2st program that owns no
    trainable parameters.
    """

    def __init__(self):
        super(Net, self).__init__()

    @paddle.jit.to_static
    def forward(self, x):
        # Pure element-wise computation; no weights are involved.
        return x + 1
class TestBackwardWithoutParams(unittest.TestCase):
    """backward() on a parameter-free to_static net must still yield input grads."""

    def test_run(self):
        inp = paddle.ones([2, 2])
        inp.stop_gradient = False
        net = Net()
        loss = paddle.mean(net(inp))
        loss.backward()
        # d(mean(x + 1))/dx = 1/numel = 1/4 for a 2x2 input.
        expected = np.full(inp.shape, 0.25)
        np.testing.assert_equal(inp.grad.numpy(), expected)
# Allow running this test file directly (e.g. `python this_file.py`).
if __name__ == '__main__':
    unittest.main()
......@@ -292,7 +292,6 @@ def for_tuple_as_enumerate_value(x_array):
# 20. test for function in a class
class ForwardContainsForLayer(paddle.nn.Layer):
def __init__(self):
super(ForwardContainsForLayer, self).__init__()
self.high = 5
......@@ -328,8 +327,8 @@ def for_original_tuple():
# 23. for zip error
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[None, 10]),
InputSpec(shape=[None, 10])])
input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])]
)
def for_zip_error(x, y):
for i, j in zip(x, y):
a = i + j
......@@ -338,8 +337,8 @@ def for_zip_error(x, y):
# 24. for zip
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[2, 10]),
InputSpec(shape=[2, 10])])
input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])]
)
def for_zip(x, y):
for i, j in zip(x, y):
a = i + j
......@@ -347,10 +346,12 @@ def for_zip(x, y):
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(
0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
self.place = (
fluid.CUDAPlace(0)
if fluid.is_compiled_with_cuda()
else fluid.CPUPlace()
)
self.set_input()
self.set_test_func()
......@@ -359,7 +360,8 @@ class TestTransformBase(unittest.TestCase):
def set_test_func(self):
raise NotImplementedError(
"For Enumerate test should implement set_test_func")
"For Enumerate test should implement set_test_func"
)
def _run(self, to_static):
program_translator.enable(to_static)
......@@ -374,22 +376,21 @@ class TestTransformBase(unittest.TestCase):
class TestTransform(TestTransformBase):
    """Compares dygraph outputs against to_static outputs for the same func."""

    def transformed_result_compare(self):
        # NOTE(review): the scraped diff flattened old and new lines together,
        # leaving duplicate `dy_outs = (dy_outs, )` / `dy_outs = (dy_outs,)`
        # assignments that would double-wrap the tuple; only the post-change
        # line is kept for each pair.
        dy_outs = self.get_dygraph_output()
        if not isinstance(dy_outs, (tuple, list)):
            dy_outs = (dy_outs,)

        self.dygraph_func.eval()
        st_outs = self.get_static_output()
        if not isinstance(st_outs, (tuple, list)):
            st_outs = (st_outs,)

        for x, y in zip(dy_outs, st_outs):
            np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-05)
class TestTransformForOriginalList(TestTransform):
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard():
......@@ -397,7 +398,6 @@ class TestTransformForOriginalList(TestTransform):
class TestTransformError(TestTransformBase):
def transformed_error(self, etype):
with self.assertRaises(etype):
dy_out = self.get_dygraph_output()
......@@ -405,7 +405,6 @@ class TestTransformError(TestTransformBase):
class TestForInRange(TestTransform):
def set_input(self):
self.input = np.array([5])
......@@ -417,7 +416,6 @@ class TestForInRange(TestTransform):
class TestForIterList(TestTransform):
def set_test_func(self):
self.dygraph_func = for_iter_list
......@@ -426,19 +424,16 @@ class TestForIterList(TestTransform):
class TestForEnumerateSimple(TestForIterList):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_enumerate_list
class TestForInRangeWithBreak(TestForInRange):
    def set_test_func(self):
        # Exercises a `for ... in range(...)` loop containing `break`.
        self.dygraph_func = for_in_range_with_break
class TestForIterVarNumpy(TestTransform):
def set_input(self):
self.input = np.array([1, 2, 3, 4, 5])
......@@ -450,103 +445,86 @@ class TestForIterVarNumpy(TestTransform):
class TestForEnumerateVarNumpy(TestForIterVarNumpy):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_enumerate_var_numpy
class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises enumerate(...) with a non-default start argument.
        self.dygraph_func = for_enumerate_var_numpy_with_start
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises an enumerate loop containing `break`.
        self.dygraph_func = for_enumerate_var_numpy_with_break
class TestForEnumerateVarNumpyWithContinue(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises an enumerate loop containing `continue`.
        self.dygraph_func = for_enumerate_var_numpy_with_continue
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
    def set_test_func(self):
        # Combines a non-default enumerate start with `break`.
        self.dygraph_func = for_enumerate_var_numpy_with_start_break
class TestForEnumerateVarNumpyWithStartAndContinue(TestForIterVarNumpy):
    def set_test_func(self):
        # Combines a non-default enumerate start with `continue`.
        self.dygraph_func = for_enumerate_var_numpy_with_start_continue
class TestForIterVar(TestForIterVarNumpy):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_iter_var
class TestForIterVarIdx(TestForIterVarNumpy):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_iter_var_idx
class TestForEnumerateVar(TestForIterVarNumpy):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_enumerate_var
class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises an enumerate loop with a nested range loop inside it.
        self.dygraph_func = for_enumerate_var_with_nested_range
class TestForIterVarList(TestForInRange):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_iter_var_list
class TestForEnumerateVarList(TestForInRange):
    def set_test_func(self):
        # Function under test for the inherited dygraph-vs-static comparison.
        self.dygraph_func = for_enumerate_var_list
class TestForTupleAsIterVar(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises tuple unpacking as the loop iteration variable.
        self.dygraph_func = for_tuple_as_iter_var
class TestForTupleAsEnumerateIter(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises a tuple as the enumerate iteration target.
        self.dygraph_func = for_tuple_as_enumerate_iter
class TestForTupleAsEnumerateValue(TestForIterVarNumpy):
    def set_test_func(self):
        # Exercises tuple unpacking of the enumerate value.
        self.dygraph_func = for_tuple_as_enumerate_value
class TestForwardContainsForLayer(TestForIterVarNumpy):
    def set_test_func(self):
        # Uses a Layer instance (its forward contains a for-loop) as the callable
        # under test instead of a plain function.
        self.dygraph_func = ForwardContainsForLayer()
class TestForOriginalList(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_list
......@@ -555,7 +533,6 @@ class TestForOriginalList(TestTransformForOriginalList):
class TestForOriginalTuple(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_tuple
......@@ -564,7 +541,6 @@ class TestForOriginalTuple(TestTransformForOriginalList):
class TestForZip(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册