未验证 提交 37385f63 编写于 作者: W WeiXin 提交者: GitHub

replace 'InnerSetOverridedStopGradient' with 'SetOverridedStopGradient'. (#33303)

* replace 'InnerSetOverridedStopGradient' with 'SetOverridedStopGradient'.

* improve coverage.

* polish error message.
上级 45d1ae21
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -32,7 +33,17 @@ bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) { ...@@ -32,7 +33,17 @@ bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& var_base : name_pair.second) { for (const auto& var_base : name_pair.second) {
if (!var_base->OverridedStopGradient()) { if (!var_base->OverridedStopGradient()) {
PassStopGradient(outs, var_base->OverridedStopGradient()); for (const auto& pair : outs) {
for (const auto& var : pair.second) {
if (var) {
var->SetOverridedStopGradient(false);
SetForwardDataTypeOfGradVar(var);
VLOG(3) << "Set output: " << var->Name()
<< "'s OverridedStopGradient as "
<< var->OverridedStopGradient();
}
}
}
return true; return true;
} }
} }
...@@ -78,28 +89,36 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -78,28 +89,36 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
// process args,`input_vars` only collect `imperative::VarBase` // process args,`input_vars` only collect `imperative::VarBase`
if (!args.empty()) { if (!args.empty()) {
for (auto ptr = args.begin(); ptr != args.end(); ptr++) { for (auto ptr = args.begin(); ptr != args.end(); ptr++) {
// Only collect Tensor type in 'args' and pass them to backward. Ignore
// other types of input temporarily.
if (py::isinstance<imperative::VarBase>(*ptr)) {
try { try {
if (Py_None != ptr->ptr()) {
auto a = ptr->cast<std::shared_ptr<VarBase>>(); auto a = ptr->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a); input_vars.push_back(a);
}
} catch (py::cast_error& err) { } catch (py::cast_error& err) {
// Only collect Tensor type in 'args' and pass them to backward. Ignore PADDLE_THROW(platform::errors::InvalidArgument(
// other types of input temporarily. "The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
} }
} }
} }
// process kwargs, only collect `imperative::VarBase` // process kwargs, only collect `imperative::VarBase`
if (!kwargs.empty()) { if (!kwargs.empty()) {
for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) { for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) {
// Only collect Tensor type in 'kwargs' and pass them to backward.
// Ignore other types of input temporarily.
if (py::isinstance<imperative::VarBase>(*ptr->second)) {
try { try {
if (Py_None != ptr->second.ptr()) {
auto a = ptr->second.cast<std::shared_ptr<VarBase>>(); auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a); input_vars.push_back(a);
}
} catch (py::cast_error&) { } catch (py::cast_error&) {
// Only collect Tensor type in 'kwargs' and pass them to backward. PADDLE_THROW(platform::errors::InvalidArgument(
// Ignore other types of input temporarily. "The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
} }
} }
} }
...@@ -110,33 +129,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -110,33 +129,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
PyList_Check(result_forward.ptr())) { PyList_Check(result_forward.ptr())) {
auto tuple_result = result_forward.cast<py::tuple>(); auto tuple_result = result_forward.cast<py::tuple>();
for (size_t i = 0; i < tuple_result.size(); i++) { for (size_t i = 0; i < tuple_result.size(); i++) {
if (Py_None != tuple_result[i].ptr()) { // Only collect Tensor type of output and pass them to backward.
// Ignore other types of input temporarily.
if (py::isinstance<imperative::VarBase>(tuple_result[i])) {
try { try {
auto temp_out = auto temp_out =
tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>(); tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>();
output_vars.push_back(temp_out); output_vars.push_back(temp_out);
} catch (py::cast_error&) { } catch (py::cast_error&) {
// Only collect Tensor type in 'kwargs' and pass them to backward. PADDLE_THROW(platform::errors::InvalidArgument(
// Ignore other types of input temporarily. "The `PyLayer.forward` function returns invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
tuple_result[i].ptr()->ob_type->tp_name));
} }
} else {
// Only collect Tensor type in 'kwargs' and pass them to backward.
// Ignore other types of input temporarily.
} }
} }
} else { } else {
if (Py_None != result_forward.ptr()) { // Only collect Tensor type of output and pass them to backward.
// Ignore other types of input temporarily.
if (py::isinstance<imperative::VarBase>(result_forward)) {
try { try {
auto temp_out = auto temp_out =
result_forward.cast<std::shared_ptr<imperative::VarBase>>(); result_forward.cast<std::shared_ptr<imperative::VarBase>>();
output_vars.push_back(temp_out); output_vars.push_back(temp_out);
} catch (py::cast_error&) { } catch (py::cast_error&) {
// Only collect Tensor type in 'kwargs' and pass them to backward. PADDLE_THROW(platform::errors::InvalidArgument(
// Ignore other types of input temporarily. "The `PyLayer.forward` function returns invalid argument, the `%s` "
"type argument can not be cast into `Tensor`.",
result_forward.ptr()->ob_type->tp_name));
} }
} else {
// Only collect Tensor type in 'kwargs' and pass them to backward.
// Ignore other types of input temporarily.
} }
} }
if (output_vars.size() == 0) { if (output_vars.size() == 0) {
......
...@@ -62,13 +62,22 @@ void RunPyObject(py::object *py_object, ...@@ -62,13 +62,22 @@ void RunPyObject(py::object *py_object,
for (size_t i = 0; i < result_tuple.size(); i++) { for (size_t i = 0; i < result_tuple.size(); i++) {
if ((*outs)[i] != nullptr) { if ((*outs)[i] != nullptr) {
if (Py_None != result_tuple[i].ptr()) { if (Py_None != result_tuple[i].ptr()) {
if (py::isinstance<imperative::VarBase>(result_tuple[i])) {
try { try {
auto result_var = auto result_var =
result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>(); result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
*(*outs)[i] = result_var->Var(); *(*outs)[i] = result_var->Var();
} catch (py::cast_error &) { } catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`.")); "The `PyLayer.backward` function returns invalid argument, "
"the `%s` type argument can not be cast into `Tensor`.",
result_tuple[i].ptr()->ob_type->tp_name));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`, but "
"received `%s`.",
result_tuple[i].ptr()->ob_type->tp_name));
} }
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -94,13 +103,22 @@ void RunPyObject(py::object *py_object, ...@@ -94,13 +103,22 @@ void RunPyObject(py::object *py_object,
} }
if ((*outs)[0] != nullptr) { if ((*outs)[0] != nullptr) {
if (Py_None != py_result.ptr()) { if (Py_None != py_result.ptr()) {
if (py::isinstance<imperative::VarBase>(py_result)) {
try { try {
auto result_var = auto result_var =
py_result.cast<std::shared_ptr<imperative::VarBase>>(); py_result.cast<std::shared_ptr<imperative::VarBase>>();
*((*outs)[0]) = result_var->Var(); *((*outs)[0]) = result_var->Var();
} catch (py::cast_error &) { } catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`.")); "The `PyLayer.backward` function returns invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
py_result.ptr()->ob_type->tp_name));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`, but "
"received `%s`",
py_result.ptr()->ob_type->tp_name));
} }
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
......
...@@ -21,6 +21,11 @@ import paddle ...@@ -21,6 +21,11 @@ import paddle
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
class FakeTensor(paddle.fluid.core.VarBase):
def __init__(self):
pass
class TestPyLayer(unittest.TestCase): class TestPyLayer(unittest.TestCase):
def test_simple_pylayer_multiple_output(self): def test_simple_pylayer_multiple_output(self):
class tanh(PyLayer): class tanh(PyLayer):
...@@ -426,6 +431,129 @@ class TestPyLayer(unittest.TestCase): ...@@ -426,6 +431,129 @@ class TestPyLayer(unittest.TestCase):
z = paddle.tanh(data) z = paddle.tanh(data)
z = cus_tanh.apply(data) z = cus_tanh.apply(data)
def test_return_to_tensor(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
y1 = paddle.tanh(x1)
ctx.save_for_backward(y1)
tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
return y1, 5, None, "helloworld", tensor_1
@staticmethod
def backward(ctx, dy1, dy2):
y1, = ctx.saved_tensor()
re1 = dy1 * (1 - paddle.square(y1))
return dy1
input1 = paddle.randn([2, 3]).astype("float32")
input2 = input1.detach().clone()
input1.stop_gradient = False
input2.stop_gradient = False
z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1)
z.mean().backward()
class TestPyLayerReturnType(unittest.TestCase):
def test_forward_args_fake_tensor(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
y1 = FakeTensor()
return y1, x1
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = FakeTensor()
with self.assertRaises(ValueError):
y1, y2 = Tanh.apply(input1)
def test_forward_kwargs_fake_tensor(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
return x1
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = FakeTensor()
with self.assertRaises(ValueError):
y = Tanh.apply(x1=input1)
def test_forward_return_fake_tensor(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
return FakeTensor()
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = paddle.randn([3, 2])
with self.assertRaises(ValueError):
y = Tanh.apply(x1=input1)
def test_forward_return_fake_tensor_tuple(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
return FakeTensor(), FakeTensor()
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = paddle.randn([3, 2])
with self.assertRaises(ValueError):
y = Tanh.apply(x1=input1)
def test_backward_return_fake_tensor_tuple(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 + 1, x1 + 2
@staticmethod
def backward(ctx, dy1, dy2):
return FakeTensor(), 2
input1 = paddle.randn([3, 2])
input1.stop_gradient = False
y, _ = Tanh.apply(input1, 1 + input1)
with self.assertRaises(ValueError):
y.mean().backward()
def test_backward_return_fake_tensor(self):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
return x1 + 1, x1 + 2
@staticmethod
def backward(ctx, dy1, dy2):
return FakeTensor()
input1 = paddle.randn([3, 2])
input1.stop_gradient = False
y, _ = Tanh.apply(input1)
with self.assertRaises(ValueError):
y.mean().backward()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册