未验证 提交 4c82e455 编写于 作者: W WangZhen 提交者: GitHub

[Cherry-Pick][Dy2St]Support call backward() without params in dy2st (#49812) (#50144)

* [Dy2St]Support call backward() without params in dy2st (#49812)

* Support call backward() without params in dy2st

* format code

* format code
上级 8c5e432b
...@@ -93,7 +93,8 @@ class SelectOutputInferShape : public framework::InferShapeBase { ...@@ -93,7 +93,8 @@ class SelectOutputInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput"); OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput"); OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
OP_INOUT_CHECK(context->HasOutputs("Out"), "Output", "Out", "SelectOutput"); OP_INOUT_CHECK(
context->HasOutputs("Out", true), "Output", "Out", "SelectOutput");
} }
}; };
......
...@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self, ...@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self,
PyObject *kwargs) { PyObject *kwargs) {
PyThreadState *tstate = nullptr; PyThreadState *tstate = nullptr;
try { try {
auto X = GetTensorListFromArgs("run_program", "X", args, 0, false); auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true); auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false); auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true);
auto OutScope = auto OutScope =
GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false); GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true); auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
@paddle.jit.to_static
def forward(self, x):
out = x + 1
return out
class TestBackwardWithoutParams(unittest.TestCase):
def test_run(self):
net = Net()
x = paddle.ones([2, 2])
x.stop_gradient = False
out = net(x)
loss = paddle.mean(out)
loss.backward()
np.testing.assert_equal(x.grad.numpy(), np.full(x.shape, 0.25))
if __name__ == '__main__':
unittest.main()
...@@ -292,7 +292,6 @@ def for_tuple_as_enumerate_value(x_array): ...@@ -292,7 +292,6 @@ def for_tuple_as_enumerate_value(x_array):
# 20. test for function in a class # 20. test for function in a class
class ForwardContainsForLayer(paddle.nn.Layer): class ForwardContainsForLayer(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(ForwardContainsForLayer, self).__init__() super(ForwardContainsForLayer, self).__init__()
self.high = 5 self.high = 5
...@@ -328,8 +327,8 @@ def for_original_tuple(): ...@@ -328,8 +327,8 @@ def for_original_tuple():
# 23. for zip error # 23. for zip error
@paddle.jit.to_static( @paddle.jit.to_static(
input_spec=[InputSpec(shape=[None, 10]), input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])]
InputSpec(shape=[None, 10])]) )
def for_zip_error(x, y): def for_zip_error(x, y):
for i, j in zip(x, y): for i, j in zip(x, y):
a = i + j a = i + j
...@@ -338,8 +337,8 @@ def for_zip_error(x, y): ...@@ -338,8 +337,8 @@ def for_zip_error(x, y):
# 24. for zip # 24. for zip
@paddle.jit.to_static( @paddle.jit.to_static(
input_spec=[InputSpec(shape=[2, 10]), input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])]
InputSpec(shape=[2, 10])]) )
def for_zip(x, y): def for_zip(x, y):
for i, j in zip(x, y): for i, j in zip(x, y):
a = i + j a = i + j
...@@ -347,10 +346,12 @@ def for_zip(x, y): ...@@ -347,10 +346,12 @@ def for_zip(x, y):
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace( self.place = (
0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace() fluid.CUDAPlace(0)
if fluid.is_compiled_with_cuda()
else fluid.CPUPlace()
)
self.set_input() self.set_input()
self.set_test_func() self.set_test_func()
...@@ -359,7 +360,8 @@ class TestTransformBase(unittest.TestCase): ...@@ -359,7 +360,8 @@ class TestTransformBase(unittest.TestCase):
def set_test_func(self): def set_test_func(self):
raise NotImplementedError( raise NotImplementedError(
"For Enumerate test should implement set_test_func") "For Enumerate test should implement set_test_func"
)
def _run(self, to_static): def _run(self, to_static):
program_translator.enable(to_static) program_translator.enable(to_static)
...@@ -374,22 +376,21 @@ class TestTransformBase(unittest.TestCase): ...@@ -374,22 +376,21 @@ class TestTransformBase(unittest.TestCase):
class TestTransform(TestTransformBase): class TestTransform(TestTransformBase):
def transformed_result_compare(self): def transformed_result_compare(self):
dy_outs = self.get_dygraph_output() dy_outs = self.get_dygraph_output()
if not isinstance(dy_outs, (tuple, list)): if not isinstance(dy_outs, (tuple, list)):
dy_outs = (dy_outs, ) dy_outs = (dy_outs,)
self.dygraph_func.eval()
st_outs = self.get_static_output() st_outs = self.get_static_output()
if not isinstance(st_outs, (tuple, list)): if not isinstance(st_outs, (tuple, list)):
st_outs = (st_outs, ) st_outs = (st_outs,)
for x, y in zip(dy_outs, st_outs): for x, y in zip(dy_outs, st_outs):
np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-05) np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-05)
class TestTransformForOriginalList(TestTransform): class TestTransformForOriginalList(TestTransform):
def _run(self, to_static): def _run(self, to_static):
program_translator.enable(to_static) program_translator.enable(to_static)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -397,7 +398,6 @@ class TestTransformForOriginalList(TestTransform): ...@@ -397,7 +398,6 @@ class TestTransformForOriginalList(TestTransform):
class TestTransformError(TestTransformBase): class TestTransformError(TestTransformBase):
def transformed_error(self, etype): def transformed_error(self, etype):
with self.assertRaises(etype): with self.assertRaises(etype):
dy_out = self.get_dygraph_output() dy_out = self.get_dygraph_output()
...@@ -405,7 +405,6 @@ class TestTransformError(TestTransformBase): ...@@ -405,7 +405,6 @@ class TestTransformError(TestTransformBase):
class TestForInRange(TestTransform): class TestForInRange(TestTransform):
def set_input(self): def set_input(self):
self.input = np.array([5]) self.input = np.array([5])
...@@ -417,7 +416,6 @@ class TestForInRange(TestTransform): ...@@ -417,7 +416,6 @@ class TestForInRange(TestTransform):
class TestForIterList(TestTransform): class TestForIterList(TestTransform):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_iter_list self.dygraph_func = for_iter_list
...@@ -426,19 +424,16 @@ class TestForIterList(TestTransform): ...@@ -426,19 +424,16 @@ class TestForIterList(TestTransform):
class TestForEnumerateSimple(TestForIterList): class TestForEnumerateSimple(TestForIterList):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_list self.dygraph_func = for_enumerate_list
class TestForInRangeWithBreak(TestForInRange): class TestForInRangeWithBreak(TestForInRange):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_in_range_with_break self.dygraph_func = for_in_range_with_break
class TestForIterVarNumpy(TestTransform): class TestForIterVarNumpy(TestTransform):
def set_input(self): def set_input(self):
self.input = np.array([1, 2, 3, 4, 5]) self.input = np.array([1, 2, 3, 4, 5])
...@@ -450,103 +445,86 @@ class TestForIterVarNumpy(TestTransform): ...@@ -450,103 +445,86 @@ class TestForIterVarNumpy(TestTransform):
class TestForEnumerateVarNumpy(TestForIterVarNumpy): class TestForEnumerateVarNumpy(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_numpy self.dygraph_func = for_enumerate_var_numpy
class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_numpy_with_start self.dygraph_func = for_enumerate_var_numpy_with_start
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_numpy_with_break self.dygraph_func = for_enumerate_var_numpy_with_break
class TestForEnumerateVarNumpyWithContinue(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithContinue(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_numpy_with_continue self.dygraph_func = for_enumerate_var_numpy_with_continue
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_numpy_with_start_break self.dygraph_func = for_enumerate_var_numpy_with_start_break
class TestForEnumerateVarNumpyWithStartAndContinue(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithStartAndContinue(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_numpy_with_start_continue self.dygraph_func = for_enumerate_var_numpy_with_start_continue
class TestForIterVar(TestForIterVarNumpy): class TestForIterVar(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_iter_var self.dygraph_func = for_iter_var
class TestForIterVarIdx(TestForIterVarNumpy): class TestForIterVarIdx(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_iter_var_idx self.dygraph_func = for_iter_var_idx
class TestForEnumerateVar(TestForIterVarNumpy): class TestForEnumerateVar(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var self.dygraph_func = for_enumerate_var
class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy): class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_with_nested_range self.dygraph_func = for_enumerate_var_with_nested_range
class TestForIterVarList(TestForInRange): class TestForIterVarList(TestForInRange):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_iter_var_list self.dygraph_func = for_iter_var_list
class TestForEnumerateVarList(TestForInRange): class TestForEnumerateVarList(TestForInRange):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var_list self.dygraph_func = for_enumerate_var_list
class TestForTupleAsIterVar(TestForIterVarNumpy): class TestForTupleAsIterVar(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_tuple_as_iter_var self.dygraph_func = for_tuple_as_iter_var
class TestForTupleAsEnumerateIter(TestForIterVarNumpy): class TestForTupleAsEnumerateIter(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_tuple_as_enumerate_iter self.dygraph_func = for_tuple_as_enumerate_iter
class TestForTupleAsEnumerateValue(TestForIterVarNumpy): class TestForTupleAsEnumerateValue(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_tuple_as_enumerate_value self.dygraph_func = for_tuple_as_enumerate_value
class TestForwardContainsForLayer(TestForIterVarNumpy): class TestForwardContainsForLayer(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = ForwardContainsForLayer() self.dygraph_func = ForwardContainsForLayer()
class TestForOriginalList(TestTransformForOriginalList): class TestForOriginalList(TestTransformForOriginalList):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_original_list self.dygraph_func = for_original_list
...@@ -555,7 +533,6 @@ class TestForOriginalList(TestTransformForOriginalList): ...@@ -555,7 +533,6 @@ class TestForOriginalList(TestTransformForOriginalList):
class TestForOriginalTuple(TestTransformForOriginalList): class TestForOriginalTuple(TestTransformForOriginalList):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_original_tuple self.dygraph_func = for_original_tuple
...@@ -564,7 +541,6 @@ class TestForOriginalTuple(TestTransformForOriginalList): ...@@ -564,7 +541,6 @@ class TestForOriginalTuple(TestTransformForOriginalList):
class TestForZip(unittest.TestCase): class TestForZip(unittest.TestCase):
def setUp(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册