Unverified commit 7170c687, authored by pangyoki, committed by GitHub

support inplace in tensor_method_setitem (#40915)

* support inplace in tensor_method_setitem

* delete bump_inplace_version

* optimize inplace unittest

* fix

* fix setitem bug

* update eager_generator

* optimize inplace unittest

* little change
Parent 9fcb6a1d
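At a high level, this commit makes tensor __setitem__ use the inplace set_value_ op in eager mode, so an indexed assignment mutates the tensor in place and bumps its inplace version counter, letting the eager autograd engine detect stale reads during backward. A minimal behavioral sketch, assuming eager mode and mirroring the updated unit tests below (only paddle's public API plus the _test_eager_guard test helper are used):

import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
    assert var.inplace_version == 0  # no inplace op has touched var yet
    var[0] = 1.1                     # setitem now runs set_value_ inplace
    assert var.inplace_version == 1
    var[0] = 2                       # every indexed assignment bumps again
    assert var.inplace_version == 2

# If backward later reads a tensor that was modified inplace after being
# saved, eager mode raises "received current_inplace_version:1 !=
# inplace_version_snapshot_:0" (legacy dygraph keeps the older
# tensor_version / wrapper_version_snapshot wording).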
@@ -2108,6 +2108,10 @@ static std::string GenerateSingleOpBase(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grads_position);
} else {
if (dispensable_input_name_set.count(fwd_name) &&
grad_ins_fwd_slotname_map.count(fwd_name)) {
continue;
}
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) {
@@ -2144,6 +2148,42 @@ static std::string GenerateSingleOpBase(
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
if (fwd_inputs_name_pos_map.count(fwd_name)) {
if (dispensable_input_name_set.count(fwd_name) &&
grad_ins_fwd_slotname_map.count(fwd_name)) {
if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
" if(%s.size() > 0) %s[\"%s\"] = egr::EagerUtils::CreateVars( "
"this->OutputMeta()[%d].size() );\n";
generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
grad_output_name, fwd_input_position);
} else {
const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
" if(%s.initialized()) %s[\"%s\"] = "
"{std::make_shared<egr::EagerVariable>(egr::Controller::"
"Instance().GenerateUniqueName())};\n";
generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
grad_output_name);
}
}
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_output_name));
}
}
VLOG(6) << "Generated Outs Map";
@@ -703,8 +703,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
}
});
// TODO(pangyoki) add inplace(BumpInplaceVersion) if need
// 1. Check arguments
bool parse_index = true;
@@ -753,12 +751,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
if (PyCheckTensor(value_obj)) {
value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
// pass the stop_gradient from value to tensor
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
}
} else if (py::isinstance<py::array>(value_obj)) {
paddle::experimental::Tensor value_tensor_tmp(
std::make_shared<phi::DenseTensor>(),
@@ -858,8 +850,18 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
{
// Release gil and do tracing
py::gil_scoped_release release;
self->tensor = set_value_dygraph_function(self->tensor, value_tensor, {},
{}, {}, attrs);
// use inplace set_value_ operator
self->tensor = set_value__dygraph_function(self->tensor, value_tensor, {},
{}, {}, attrs);
}
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
}
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
@@ -1179,6 +1181,15 @@ static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__bump_inplace_version(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
self->tensor.bump_inplace_version();
return Py_None;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
@@ -1287,6 +1298,9 @@ PyMethodDef variable_methods[] = {
/***the method of sparse tensor****/
{"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_bump_inplace_version",
(PyCFunction)(void (*)(void))tensor__bump_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_selected_rows",
(PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
METH_VARARGS | METH_KEYWORDS, NULL},
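The new _bump_inplace_version binding exposes the DenseTensor inplace version counter to Python, so code paths that mutate storage without going through a traced inplace op (such as _setitem_impl_ at the end of this diff) can still record the write. A small usage sketch, assuming a DenseTensor-backed tensor (the hunks below show that other tensor types raise Unimplemented):

import paddle

var = paddle.ones([4, 2, 3])
assert var.inplace_version == 0  # reads InplaceVersionCounter.CurrentVersion()
var._bump_inplace_version()      # private helper; calls InplaceVersionCounter.Bump()
assert var.inplace_version == 1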
@@ -364,11 +364,7 @@ void Tensor::bump_inplace_version() {
auto &inplace_version_counter =
std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
->InplaceVersionCounter();
VLOG(3) << "yoki: before bump inplace version: "
<< inplace_version_counter.CurrentVersion();
inplace_version_counter.Bump();
VLOG(3) << "yoki: after bump inplace version: "
<< inplace_version_counter.CurrentVersion();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"bump_inplace_version is only supported on DenseTensor now."));
@@ -380,8 +376,6 @@ uint32_t Tensor::current_inplace_version() {
auto &inplace_version_counter =
std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
->InplaceVersionCounter();
VLOG(3) << "yoki: print version: "
<< inplace_version_counter.CurrentVersion();
return inplace_version_counter.CurrentVersion();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
@@ -965,7 +965,6 @@ set_tests_properties(test_bicubic_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_eager_fluid PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_entropy2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_fetch_unmerged PROPERTIES TIMEOUT 120)
@@ -19,10 +19,11 @@ import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
class TestInplace(unittest.TestCase):
def test_forward_version(self):
def func_test_forward_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
self.assertEqual(var.inplace_version, 0)
@@ -30,7 +31,11 @@ class TestInplace(unittest.TestCase):
var[0] = 1.1
self.assertEqual(var.inplace_version, 1)
paddle.assign(paddle.ones(shape=[3]), var)
# TODO1: assign doesn't support inplace for now
if in_dygraph_mode():
var[0] = 2
else:
paddle.assign(paddle.ones(shape=[3]), var)
# NOTE(liym27): assign(input, output) is an inplace operation for output.
# There is inplace-related processing for api assign, so var.inplace_version should be 2, not 1.
@@ -39,7 +44,12 @@ class TestInplace(unittest.TestCase):
var[2] = 3
self.assertEqual(var.inplace_version, 3)
def test_backward_error(self):
def test_forward_version(self):
with _test_eager_guard():
self.func_test_forward_version()
self.func_test_forward_version()
def func_test_backward_error(self):
# It raises an error because the inplace operator will result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
@@ -55,13 +65,25 @@ class TestInplace(unittest.TestCase):
var_d = var_b**2
loss = paddle.nn.functional.relu(var_c + var_d)
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_success_1(self):
def test_backward_error(self):
with _test_eager_guard():
self.func_test_backward_error()
self.func_test_backward_error()
def func_test_backward_success_1(self):
# var_b is modified inplace before using it, the inplace operator doesn't result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
@@ -76,7 +98,12 @@ class TestInplace(unittest.TestCase):
loss = var_c.sum()
loss.backward()
def test_backward_success_2(self):
def test_backward_success_1(self):
with _test_eager_guard():
self.func_test_backward_success_1()
self.func_test_backward_success_1()
def func_test_backward_success_2(self):
# Although var_b is modified inplace after using it, it is not used in gradient computation.
# The inplace operator doesn't result in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
@@ -94,6 +121,12 @@ class TestInplace(unittest.TestCase):
loss.backward()
def test_backward_success_2(self):
# TODO2: need to process no_need_buffer in eager mode
# with _test_eager_guard():
# self.func_test_backward_success_2()
self.func_test_backward_success_2()
class TestDygraphInplace(unittest.TestCase):
def setUp(self):
@@ -113,7 +146,7 @@ class TestDygraphInplace(unittest.TestCase):
def inplace_api_processing(self, var):
return paddle.squeeze_(var)
def test_inplace_api(self):
def func_test_inplace_api(self):
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
inplace_var = self.inplace_api_processing(var)
self.assertTrue(id(var) == id(inplace_var))
@@ -121,7 +154,12 @@ class TestDygraphInplace(unittest.TestCase):
inplace_var[0] = 2.
self.assertTrue(np.array_equal(var.numpy(), inplace_var.numpy()))
def test_forward_version(self):
def test_inplace_api(self):
with _test_eager_guard():
self.func_test_inplace_api()
self.func_test_inplace_api()
def func_test_forward_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
self.assertEqual(var.inplace_version, 0)
@@ -135,7 +173,12 @@ class TestDygraphInplace(unittest.TestCase):
inplace_var = self.inplace_api_processing(inplace_var)
self.assertEqual(var.inplace_version, 3)
def test_leaf_inplace_var_error(self):
def test_forward_version(self):
with _test_eager_guard():
self.func_test_forward_version()
self.func_test_forward_version()
def func_test_leaf_inplace_var_error(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var.stop_gradient = False
@@ -145,7 +188,12 @@ class TestDygraphInplace(unittest.TestCase):
self.assertRaises(ValueError, leaf_inplace_error)
def test_backward_error(self):
def test_leaf_inplace_var_error(self):
with _test_eager_guard():
self.func_test_leaf_inplace_var_error()
self.func_test_leaf_inplace_var_error()
def func_test_backward_error(self):
# It raises an error because the inplace operator will result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
@@ -159,13 +207,25 @@ class TestDygraphInplace(unittest.TestCase):
self.inplace_api_processing(var_b)
loss = paddle.nn.functional.relu(var_c)
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_success_1(self):
def test_backward_error(self):
with _test_eager_guard():
self.func_test_backward_error()
self.func_test_backward_error()
def func_test_backward_success_1(self):
# var_b is modified inplace before using it, the inplace operator doesn't result
# in incorrect gradient computation.
grad_var_a, grad_var_a_inplace = 0, 1
@@ -196,7 +256,12 @@ class TestDygraphInplace(unittest.TestCase):
self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
def test_backward_success_2(self):
def test_backward_success_1(self):
with _test_eager_guard():
self.func_test_backward_success_1()
self.func_test_backward_success_1()
def func_test_backward_success_2(self):
# Although var_b is modified inplace after using it, it is not used in gradient computation.
# The inplace operator doesn't result in incorrect gradient computation.
grad_var_a, grad_var_a_inplace = 0, 1
@@ -221,8 +286,7 @@ class TestDygraphInplace(unittest.TestCase):
var_b = var_a**2
var_c = self.non_inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_c = self.non_inplace_api_processing(var_b)
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
@@ -231,6 +295,11 @@ class TestDygraphInplace(unittest.TestCase):
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
def test_backward_success_2(self):
with _test_eager_guard():
self.func_test_backward_success_2()
self.func_test_backward_success_2()
class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
def non_inplace_api_processing(self, var):
@@ -391,26 +460,29 @@ class TestDygraphInplaceAdd(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.rand(2, 3, 4)
self.dtype = "float32"
input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype)
self.input_var_2 = paddle.to_tensor(input_var_numpy_2)
self.input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype)
def non_inplace_api_processing(self, var):
return var.add(self.input_var_2)
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.add(input_var_2)
def inplace_api_processing(self, var):
return var.add_(self.input_var_2)
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.add_(input_var_2)
class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):
def non_inplace_api_processing(self, var):
return var.subtract(self.input_var_2)
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.subtract(input_var_2)
def inplace_api_processing(self, var):
return var.subtract_(self.input_var_2)
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.subtract_(input_var_2)
class TestLossIsInplaceVar(unittest.TestCase):
def test_loss_is_inplace_var(self):
def func_test_loss_is_inplace_var(self):
with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
@@ -433,9 +505,14 @@ class TestLossIsInplaceVar(unittest.TestCase):
self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
def test_loss_is_inplace_var(self):
with _test_eager_guard():
self.func_test_loss_is_inplace_var()
self.func_test_loss_is_inplace_var()
class TestContinuouslyInplace(unittest.TestCase):
def test_continuously_inplace(self):
def func_test_continuously_inplace(self):
a = paddle.rand([2, 3])
a.stop_gradient = False
b = a * 2
@@ -446,6 +523,11 @@ class TestContinuouslyInplace(unittest.TestCase):
b.backward()
def test_continuously_inplace(self):
with _test_eager_guard():
self.func_test_continuously_inplace()
self.func_test_continuously_inplace()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
class TestDygraphInplace(unittest.TestCase):
def setUp(self):
self.init_data()
self.set_np_compare_func()
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
self.dtype = "float32"
def set_np_compare_func(self):
self.np_compare = np.array_equal
def non_inplace_api_processing(self, var):
return paddle.squeeze(var)
def inplace_api_processing(self, var):
return paddle.squeeze_(var)
def test_inplace_api(self):
with _test_eager_guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
inplace_var = self.inplace_api_processing(var)
self.assertTrue(id(var) == id(inplace_var))
inplace_var.exp_()
self.assertTrue(np.array_equal(var.numpy(), inplace_var.numpy()))
def test_forward_version(self):
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
self.assertEqual(var.inplace_version, 0)
inplace_var = self.inplace_api_processing(var)
self.assertEqual(var.inplace_version, 1)
inplace_var.exp_()
self.assertEqual(var.inplace_version, 2)
inplace_var = self.inplace_api_processing(inplace_var)
self.assertEqual(var.inplace_version, 3)
def test_leaf_inplace_var_error(self):
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var.stop_gradient = False
def leaf_inplace_error():
self.inplace_api_processing(var)
self.assertRaises(ValueError, leaf_inplace_error)
def test_backward_error(self):
# It raises an error because the inplace operator will result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
# Here, the gradient computation will use the value of var_b
var_c = var_b**2
self.inplace_api_processing(var_b)
loss = paddle.nn.functional.relu(var_c)
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
def test_backward_success_1(self):
# var_b is modified inplace before using it, the inplace operator doesn't result
# in incorrect gradient computation.
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
# Here, the gradient computation will use the value of var_b
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.non_inplace_api_processing(var_b)
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
def test_backward_success_2(self):
# Although var_b is modified inplace after using it, it is not used in gradient computation.
# The inplace operator doesn't result in incorrect gradient computation.
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.non_inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
# inplace + hook
def test_backward_success_3(self):
# var_b is modified inplace before using it, the inplace operator doesn't result
# in incorrect gradient computation.
def double_hook(grad):
grad = grad * 2
return grad
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
helper = var_a.register_hook(double_hook)
var_b = var_a**2
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
# Here, the gradient computation will use the value of var_b
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
helper = var_a.register_hook(double_hook)
var_b = var_a**2
var_c = self.non_inplace_api_processing(var_b)
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
# inplace + hook
def test_backward_success_4(self):
# Although var_b is modified inplace after using it, it is not used in gradient computation.
# The inplace operator doesn't result in incorrect gradient computation.
def double_hook(grad):
grad = grad * 2
return grad
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_a.register_hook(double_hook)
var_b = var_a**2
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_a.register_hook(double_hook)
var_b = var_a**2
var_c = self.non_inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
# inplace + hook
def test_backward_success_5(self):
# var_b is modified inplace before using it, the inplace operator doesn't result
# in incorrect gradient computation.
def double_hook(grad):
grad = grad * 2
return grad
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_b.register_hook(double_hook)
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
# Here, the gradient computation will use the value of var_b
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_b.register_hook(double_hook)
var_c = self.non_inplace_api_processing(var_b)
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
# inplace + hook
def test_backward_success_6(self):
# Although var_b is modified inplace before using it, it is not used in gradient computation.
# The inplace operator doesn't result in incorrect gradient computation.
def double_hook(grad):
grad = grad * 2
return grad
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_b.register_hook(double_hook)
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(
self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_b.register_hook(double_hook)
var_c = self.non_inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.unsqueeze(var, -1)
def inplace_api_processing(self, var):
return paddle.unsqueeze_(var, -1)
class TestDygraphInplaceReshape(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.reshape(var, [-1])
def inplace_api_processing(self, var):
return paddle.reshape_(var, [-1])
class TestDygraphInplaceFlatten(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.flatten()
def inplace_api_processing(self, var):
return var.flatten_()
"""
# This case will fail while using `_C_ops.final_state_scatter`.
class TestDygraphInplaceScatter(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]])
self.dtype = "float32"
def non_inplace_api_processing(self, var):
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
updates = paddle.to_tensor(
[[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
return paddle.scatter(var, index, updates, overwrite=False)
def inplace_api_processing(self, var):
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
updates = paddle.to_tensor(
[[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
return paddle.scatter_(var, index, updates, overwrite=False)
"""
class TestDygraphInplaceElu(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.nn.functional.elu(var)
def inplace_api_processing(self, var):
return paddle.nn.functional.elu_(var)
class TestDygraphInplaceRelu(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.nn.functional.relu(var)
def inplace_api_processing(self, var):
return paddle.nn.functional.relu_(var)
class TestDygraphInplaceSoftmax(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.nn.functional.softmax(var)
def inplace_api_processing(self, var):
return paddle.nn.functional.softmax_(var)
class TestDygraphInplaceTanh(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.tanh(var)
def inplace_api_processing(self, var):
return paddle.tanh_(var)
class TestDygraphInplaceCeil(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.ceil()
def inplace_api_processing(self, var):
return var.ceil_()
class TestDygraphInplaceFloor(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.floor()
def inplace_api_processing(self, var):
return var.floor_()
class TestDygraphInplaceExp(TestDygraphInplace):
def set_np_compare_func(self):
self.np_compare = np.allclose
def non_inplace_api_processing(self, var):
return var.exp()
def inplace_api_processing(self, var):
return var.exp_()
class TestDygraphInplaceReciprocal(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.reciprocal()
def inplace_api_processing(self, var):
return var.reciprocal_()
class TestDygraphInplaceRound(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.round()
def inplace_api_processing(self, var):
return var.round_()
class TestDygraphInplaceSqrt(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(0, 5, [10, 20, 1])
self.dtype = "float32"
def non_inplace_api_processing(self, var):
return var.sqrt()
def inplace_api_processing(self, var):
return var.sqrt_()
class TestDygraphInplaceRsqrt(TestDygraphInplaceSqrt):
def non_inplace_api_processing(self, var):
return var.rsqrt()
def inplace_api_processing(self, var):
return var.rsqrt_()
class TestDygraphInplaceClip(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.clip(0.6, 1.5)
def inplace_api_processing(self, var):
return var.clip_(0.6, 1.5)
class TestDygraphInplaceScale(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.scale(scale=2.0, bias=3.0)
def inplace_api_processing(self, var):
return var.scale_(scale=2.0, bias=3.0)
class TestDygraphInplaceAdd(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.rand(2, 3, 4)
self.dtype = "float32"
self.input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype)
def non_inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.add(input_var_2)
def inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.add_(input_var_2)
class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):
def non_inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.subtract(input_var_2)
def inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.subtract_(input_var_2)
class TestLossIsInplaceVar(unittest.TestCase):
def test_loss_is_inplace_var(self):
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
var_b = var_a * 2
loss = var_b.tanh_()
loss.backward()
inplace_grad_var_a = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
with _test_eager_guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
var_b = var_a * 2
loss = var_b.tanh()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
class TestContinuouslyInplace(unittest.TestCase):
def test_continuously_inplace(self):
with _test_eager_guard():
a = paddle.rand([2, 3])
a.stop_gradient = False
b = a * 2
b.reshape_([-1])
b.reshape_([2, 3])
b.reshape_([-1])
b.backward()
if __name__ == '__main__':
unittest.main()
@@ -1011,10 +1011,7 @@ class TestBackward(unittest.TestCase):
loss.backward()
self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
#
# TODO(pangyoki) add inplace and delete if
if _in_legacy_dygraph():
self.assertTrue((0 == x.grad[0, :, 0, 0]).all())
self.assertTrue((0 == x.grad[0, :, 0, 0]).all())
def test_dynamic(self):
with _test_eager_guard():
@@ -1192,8 +1189,8 @@ class TestGradientTruncated(unittest.TestCase):
x[0, :] = value
self.assertTrue(~x.stop_gradient)
self.assertTrue(~x.is_leaf)
self.assertTrue(not x.stop_gradient)
self.assertTrue(not x.is_leaf)
def test_consistent_with_competitor(self):
with _test_eager_guard():
@@ -674,8 +674,7 @@ def _setitem_impl_(var, item, value):
"paddle.Tensor to a paddle.Tensor, but received {}".format(
type(value)))
if paddle.fluid.framework._in_legacy_dygraph():
# TODO(pangyoki) add inplace(BumpInplaceVersion) if need
if paddle.fluid.framework._non_static_mode():
var._bump_inplace_version()
cur_block = default_main_program().current_block()
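After this change, _setitem_impl_ bumps the version in any non-static mode (legacy dygraph or eager) rather than only under _in_legacy_dygraph. A simplified sketch of the resulting control flow; the elided parts (index parsing, attrs construction, op creation) follow the real function in paddle/fluid/variable_index.py:

def _setitem_impl_(var, item, value):
    # ... validate item and value, build the set_value attrs (elided) ...
    if paddle.fluid.framework._non_static_mode():
        # Record the inplace write so the autograd version check sees a
        # consistent counter in both legacy dygraph and eager mode.
        var._bump_inplace_version()
    cur_block = default_main_program().current_block()
    # ... append the set_value op to cur_block (elided) ...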