Unverified commit 21d94dd3, authored by Weilong Wu, committed by GitHub

[Eager] Support test_diff_op switch to eager mode (#42360)

Parent 05d6be7e
...@@ -553,7 +553,13 @@ std::vector<paddle::experimental::Tensor> RunBackward(
   for (size_t i = 0; i < tensors.size(); i++) {
     const paddle::experimental::Tensor& tensor = tensors[i];
-    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(tensor);
+    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
+    if (auto_grad_meta == nullptr) {
+      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
+                 "stop_gradient=True: "
+              << tensor.name();
+      continue;
+    }
     // Get grad input info from target tensors
     auto input_info = auto_grad_meta->OutRankInfo();
......
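On the C++ side, RunBackward now looks up autograd metadata with EagerUtils::nullable_autograd_meta instead of unsafe_autograd_meta, and skips any target tensor that has none (for example, one created with stop_gradient=True) rather than asserting. A minimal sketch of the same skip-on-null pattern, written in Python for readability; run_backward and autograd_meta_of are hypothetical names, not Paddle APIs:

# Illustrative sketch only, mirroring the null check added above.
def run_backward(tensors, autograd_meta_of):
    for tensor in tensors:
        meta = autograd_meta_of(tensor)  # nullable lookup, never raises
        if meta is None:
            # no grad op for this var, or stop_gradient=True: skip, don't fail
            continue
        # ... use meta (e.g. its rank info) to seed the backward pass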
...@@ -19,8 +19,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.core as core
-from paddle.fluid.framework import _enable_legacy_dygraph
-_enable_legacy_dygraph()
+from paddle.fluid.framework import _test_eager_guard


 class TestDiffOp(unittest.TestCase):
...@@ -55,7 +54,7 @@ class TestDiffOp(unittest.TestCase):
         if core.is_compiled_with_cuda():
             self.places.append(paddle.CUDAPlace(0))

-    def test_dygraph(self):
+    def func_dygraph(self):
         for place in self.places:
             paddle.disable_static()
             x = paddle.to_tensor(self.input, place=place)
...@@ -71,6 +70,13 @@ class TestDiffOp(unittest.TestCase):
                 append=self.append)
             self.assertTrue((out.numpy() == self.output).all(), True)

+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.setUp()
+            self.func_dygraph()
+        self.setUp()
+        self.func_dygraph()
+
     def test_static(self):
         paddle.enable_static()
         places = [fluid.CPUPlace()]
...@@ -110,7 +116,7 @@ class TestDiffOp(unittest.TestCase):
                 fetch_list=[out])
             self.assertTrue((fetches[0] == self.output).all(), True)

-    def test_grad(self):
+    def func_grad(self):
         for place in self.places:
             x = paddle.to_tensor(self.input, place=place, stop_gradient=False)
             if self.prepend is not None:
...@@ -129,6 +135,13 @@ class TestDiffOp(unittest.TestCase):
         except:
             raise RuntimeError("Check Diff Gradient Failed")

+    def test_grad(self):
+        with _test_eager_guard():
+            self.setUp()
+            self.func_grad()
+        self.setUp()
+        self.func_grad()
+

 class TestDiffOpAxis(TestDiffOp):
     def set_args(self):
......
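The test changes above follow the pattern Paddle used for its eager-mode migration: the test body moves into a func_* method, and the test_* wrapper runs it twice, once inside _test_eager_guard() (eager dygraph) and once outside it (legacy dygraph), calling setUp again between passes so state from the first run does not leak into the second. A minimal self-contained sketch of that pattern, with a stand-in guard so it runs without Paddle installed:

import unittest
from contextlib import contextmanager

@contextmanager
def fake_eager_guard():
    # stand-in for paddle.fluid.framework._test_eager_guard, which switches
    # dygraph execution to eager mode for the duration of the block
    yield

class DualModeDemo(unittest.TestCase):
    def setUp(self):
        self.runs = []

    def func_body(self):
        self.runs.append("ran")

    def test_body(self):
        with fake_eager_guard():   # pass 1: eager mode
            self.setUp()
            self.func_body()
        self.setUp()               # reset state, then pass 2: legacy mode
        self.func_body()

if __name__ == "__main__":
    unittest.main()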
...@@ -4260,18 +4260,19 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
             ends_2 = [dim_len]
             attrs_2 += ('ends', ends_2)
         if in_dygraph_mode():
-            input_back = input_front = _C_ops.final_state_slice(new_input, axes, starts_2, ends_2, infer_flags,
-                                      [])
+            input_back = _C_ops.final_state_slice(new_input, axes, starts_2, ends_2, infer_flags,
+                                      [])
         else:
             input_back = _C_ops.slice(new_input, None, None, None, None, 'axes', axes, \
                 'infer_flags', infer_flags, *attrs_2)

         if x.dtype == paddle.bool:
-            op = getattr(_C_ops, "logical_xor")
-            out = op(input_back, input_front)
+            if in_dygraph_mode():
+                return _C_ops.final_state_logical_xor(input_back, input_front)
+            else:
+                return _C_ops.logical_xor(input_back, input_front)
         else:
-            out = elementwise_sub(input_back, input_front, axis=axis)
-        return out
+            return elementwise_sub(input_back, input_front, axis=axis)
     else:
         check_variable_and_dtype(x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff')
......
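Two things change in the diff hunk above: input_back is no longer aliased to input_front (the old input_back = input_front = ... bound both names to the same slice), and the bool branch now dispatches to final_state_logical_xor under eager mode while keeping legacy logical_xor otherwise. The xor is there because diff on booleans means "did the value change between neighbors". A quick NumPy illustration of those semantics (not Paddle code):

import numpy as np

# diff semantics along an axis: out[i] = x[i+1] - x[i] for numeric dtypes
x = np.array([1, 3, 6, 10])
print(x[1:] - x[:-1])                   # [2 3 4]

# for bool, subtraction is replaced by xor: True where neighbors differ
b = np.array([True, True, False, True])
print(np.logical_xor(b[1:], b[:-1]))    # [False  True  True]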