diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index f5ae3da91dadc6cabc29af972027455b00ebdc09..3f91abdebb9c397ac85358e7a8f218684a4dfd8c 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -887,7 +887,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         # Forward Full Logic
         function_name = forward_api_name
         if len(intermediate_outputs) > 0:
-            function_name = GetIntermediateAPIFunctionName(function_name)
+            if is_inplaced:
+                function_name = GetIntermediateAPIFunctionName(
+                    forward_api_name[:-1]) + '_'
+            else:
+                function_name = GetIntermediateAPIFunctionName(function_name)
 
         forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
 
diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc
index 2e185fc0ca22bce314906cc3c6043ad0e0912cac..8acdd8b34f7d1dd099b848359140b3514cc03d93 100644
--- a/paddle/phi/core/dense_tensor.cc
+++ b/paddle/phi/core/dense_tensor.cc
@@ -36,6 +36,7 @@ DenseTensor::DenseTensor(const std::shared_ptr<phi::Allocation>& holder,
 
 DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
   holder_ = other.holder_;
+  inplace_version_counter_ = other.inplace_version_counter_;
 
 #ifdef PADDLE_WITH_MKLDNN
   format_ = other.format_;
@@ -45,6 +46,7 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
 DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
   meta_ = other.meta();
   holder_ = other.holder_;
+  inplace_version_counter_ = other.inplace_version_counter_;
 #ifdef PADDLE_WITH_MKLDNN
   format_ = other.format_;
 #endif
@@ -54,6 +56,7 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
 DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
   meta_ = std::move(other.meta_);
   std::swap(holder_, other.holder_);
+  std::swap(inplace_version_counter_, other.inplace_version_counter_);
   return *this;
 }
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 199029a3a094aacbe175a7f32d1c6af856192b80..fc1554f9a6c53f4f7cee6c1e4a1aa83eee052e1e 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1527,8 +1527,8 @@ void ReshapeInferMeta(const MetaTensor& x,
 
 void ReshapeWithXShapeInferMeta(const MetaTensor& x,
                                 const ScalarArray& shape,
-                                MetaTensor* xshape,
                                 MetaTensor* out,
+                                MetaTensor* xshape,
                                 MetaConfig config) {
   PADDLE_ENFORCE_NOT_NULL(
       xshape,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index bae8083ef7191677cc625a720da448c452f61522..8a9876b11625c5aeaafac744496ff4240ec8cde0 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -238,8 +238,8 @@ void ReshapeInferMeta(const MetaTensor& x,
 
 void ReshapeWithXShapeInferMeta(const MetaTensor& x,
                                 const ScalarArray& shape,
-                                MetaTensor* xshape,
                                 MetaTensor* out,
+                                MetaTensor* xshape,
                                 MetaConfig config = MetaConfig());
 
 void ReverseInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/kernels/reshape_kernel.cc b/paddle/phi/kernels/reshape_kernel.cc
index f758d7c70518f067188242fdc9f014b5b414e885..12a75a838058a4b068b7a622c13916e09b013c53 100644
--- a/paddle/phi/kernels/reshape_kernel.cc
+++ b/paddle/phi/kernels/reshape_kernel.cc
@@ -45,8 +45,8 @@ template <typename Context>
 void ReshapeWithXShape(const Context& dev_ctx,
                        const DenseTensor& x,
                        const ScalarArray& shape,
-                       DenseTensor* xshape,
-                       DenseTensor* out) {
+                       DenseTensor* out,
+                       DenseTensor* xshape) {
   ReshapeKernel(dev_ctx, x, shape, out);
 }
 
diff --git a/paddle/phi/kernels/reshape_kernel.h b/paddle/phi/kernels/reshape_kernel.h
index 848f162a2a881ddc4d4ea136313216fd569accfd..11b19766a918bbc884f96e68a1a44db09609ce14 100644
--- a/paddle/phi/kernels/reshape_kernel.h
+++ b/paddle/phi/kernels/reshape_kernel.h
@@ -31,8 +31,8 @@ template <typename Context>
 void ReshapeWithXShape(const Context& dev_ctx,
                        const DenseTensor& x,
                        const ScalarArray& shape,
-                       DenseTensor* xshape,
-                       DenseTensor* out);
+                       DenseTensor* out,
+                       DenseTensor* xshape);
 
 template <typename Context>
 DenseTensor Reshape(const Context& dev_ctx,
diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc
index ccae6aad02dd981caff89d5c2b0d8cbb3035ee75..6b528efe6d056264c689818a1e0c318046995f62 100644
--- a/paddle/phi/ops/compat/reshape_sig.cc
+++ b/paddle/phi/ops/compat/reshape_sig.cc
@@ -20,13 +20,13 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasOutput("XShape")) {
     if (ctx.InputSize("ShapeTensor") > 0) {
       return KernelSignature(
-          "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"XShape", "Out"});
+          "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
     } else if (ctx.HasInput("Shape")) {
       return KernelSignature(
-          "reshape_with_xshape", {"X"}, {"Shape"}, {"XShape", "Out"});
+          "reshape_with_xshape", {"X"}, {"Shape"}, {"Out", "XShape"});
     } else {
       return KernelSignature(
-          "reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
+          "reshape_with_xshape", {"X"}, {"shape"}, {"Out", "XShape"});
     }
   } else {
     if (ctx.InputSize("ShapeTensor") > 0) {
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 6752e65d016ea9b128cd91e9b19ded80936b9561..d1ef9d6d8b4ea7f080ad8a5791a5cc25d6f4e1b8 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -6269,8 +6269,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
                 item.numpy().item(0) if isinstance(item, Variable) else item
                 for item in shape
             ]
-            out, _ = _C_ops.reshape2(x, None, 'shape', shape)
+            out = _C_ops.final_state_reshape(x, shape)
         elif isinstance(shape, tmp_tensor_type):
+            # TODO: Tensor shape in final_state_reshape has not been tested
            shape.stop_gradient = True
            out, _ = _C_ops.reshape2(x, shape)
        else:
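With the changes above, `reshape` in the final-state dygraph mode behaves as a view: the output shares the input's allocation and its inplace version counter. A minimal sketch of the intended behavior, mirroring the updated test below (shapes are illustrative, and it assumes a dygraph session where `paddle.reshape` takes the final-state path):

```python
import paddle

var = paddle.rand([3, 4])
view_var = paddle.reshape(var, [2, 2, 3])  # a view: shares var's allocation

var.exp_()  # in-place write through the source tensor
# because the inplace version counter is shared, the view sees the bump too
assert var.inplace_version == 1
assert view_var.inplace_version == 1
```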
diff --git a/python/paddle/fluid/tests/unittests/test_view_op_reuse_allocation.py b/python/paddle/fluid/tests/unittests/test_view_op_reuse_allocation.py
index 9cabcf49bc05558ebb2115e07b96a9771205d7aa..85f1999ec878da6de2ddcf42c7e8877a3dd8c10c 100644
--- a/python/paddle/fluid/tests/unittests/test_view_op_reuse_allocation.py
+++ b/python/paddle/fluid/tests/unittests/test_view_op_reuse_allocation.py
@@ -19,6 +19,7 @@ import numpy as np
 from op_test import OpTest
 
 import paddle
+from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
 
 # NOTE(pangyoki): Tensor View Strategy.
@@ -28,8 +29,13 @@ import paddle
 # View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`
 class TestDygraphViewReuseAllocation(unittest.TestCase):
     def setUp(self):
+        self.set_flag_to_test_eager_mode()
         self.init_shape()
 
+    # Some ops do not support the eager final state yet.
+    def set_flag_to_test_eager_mode(self):
+        self.flag_test_eager_mode = False
+
     def init_shape(self):
         self.input_shape = [2, 3, 1]
         self.output_shape = [2, 3]
@@ -37,10 +43,13 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
     def view_api_processing(self, var):
         return paddle.squeeze(var)
 
-    def test_view_api(self):
+    def func_test_view_api(self):
         var = paddle.rand(self.input_shape)
         view_var = self.view_api_processing(var)
-        view_var[0] = 2.
+        # setitem does not support inplace yet.
+        # Use the inplace exp_ op instead of setitem for now.
+        # view_var[0] = 2.
+        view_var.exp_()
 
         self.assertEqual(var.shape, self.input_shape)
         self.assertEqual(view_var.shape, self.output_shape)
@@ -48,24 +57,38 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
         var_numpy = var.numpy()
         view_var_numpy = view_var.numpy()
         self.assertTrue(np.array_equal(var_numpy, view_var_numpy))
 
-    def test_forward_version(self):
+    def test_view_api(self):
+        if self.flag_test_eager_mode:
+            with _test_eager_guard():
+                self.func_test_view_api()
+        self.func_test_view_api()
+
+    def func_test_forward_version(self):
         var = paddle.rand(self.input_shape)
         self.assertEqual(var.inplace_version, 0)
 
         view_var = self.view_api_processing(var)
         self.assertEqual(view_var.inplace_version, 0)
 
-        var[0] = 2.
+        # var[0] = 2.
+        var.exp_()
         self.assertEqual(var.inplace_version, 1)
         self.assertEqual(view_var.inplace_version, 1)
 
         view_var_2 = self.view_api_processing(var)
         self.assertEqual(view_var_2.inplace_version, 1)
 
-        var[0] = 3.
+        # var[0] = 3.
+        var.exp_()
         self.assertEqual(view_var.inplace_version, 2)
         self.assertEqual(view_var_2.inplace_version, 2)
 
-    def test_backward_error(self):
+    def test_forward_version(self):
+        if self.flag_test_eager_mode:
+            with _test_eager_guard():
+                self.func_test_forward_version()
+        self.func_test_forward_version()
+
+    def func_test_backward_error(self):
         # It raises an error because the inplace operator will result
         # in incorrect gradient computation.
         with paddle.fluid.dygraph.guard():
@@ -77,17 +100,34 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
             # Here, the gradient computation will use the value of var_b
             var_c = var_b**2
             view_var_b = self.view_api_processing(var_b)
-            view_var_b[0] = 2.  # var_b is modified inplace
+            # view_var_b[0] = 2.  # var_b is modified inplace
+            view_var_b.exp_()
 
             loss = paddle.nn.functional.relu(var_c)
-            with self.assertRaisesRegexp(
-                    RuntimeError,
-                    "received tensor_version:{} != wrapper_version_snapshot:{}".
-                    format(1, 0)):
-                loss.backward()
+            if in_dygraph_mode():
+                with self.assertRaisesRegexp(
+                        RuntimeError,
+                        "received current_inplace_version:{} != inplace_version_snapshot_:{}".
+                        format(1, 0)):
+                    loss.backward()
+            else:
+                with self.assertRaisesRegexp(
+                        RuntimeError,
+                        "received tensor_version:{} != wrapper_version_snapshot:{}".
+                        format(1, 0)):
+                    loss.backward()
+
+    def test_backward_error(self):
+        if self.flag_test_eager_mode:
+            with _test_eager_guard():
+                self.func_test_backward_error()
+        self.func_test_backward_error()
 
 
 class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
+    def set_flag_to_test_eager_mode(self):
+        self.flag_test_eager_mode = False
+
     def init_shape(self):
         self.input_shape = [2, 3]
         self.output_shape = [2, 3, 1]
@@ -97,6 +137,9 @@ class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
 
 
 class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
+    def set_flag_to_test_eager_mode(self):
+        self.flag_test_eager_mode = True
+
     def init_shape(self):
         self.input_shape = [3, 4]
         self.output_shape = [2, 2, 3]
@@ -106,6 +149,9 @@ class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
 
 
 class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
+    def set_flag_to_test_eager_mode(self):
+        self.flag_test_eager_mode = False
+
     def init_shape(self):
         self.input_shape = [3, 4]
         self.output_shape = [12]
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 6788c5899ac7e6545f9009485134ff5af2af5c54..09bfe746271092e5bd5777671893b5930ff76bdb 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -1126,12 +1126,15 @@
 
 - api : reshape
   args : (Tensor x, ScalarArray shape)
-  output : Tensor(out)
+  output : Tensor(out), Tensor(xshape)
   infer_meta :
-    func : ReshapeInferMeta
+    func : ReshapeWithXShapeInferMeta
   kernel :
-    func : reshape
+    func : reshape_with_xshape
   inplace : (x -> out)
+  view : (x -> out)
+  intermediate : xshape
+  backward : reshape_grad
 
 - api : scale
   args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
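The `inplace`, `view`, `intermediate`, and `backward` entries added to the `reshape` API above are consumed by the code generators below. A standalone sketch of the `(in -> out)` mapping syntax that `inplace` and `view` share, mirroring the regex in `parse_inplace_and_view`:

```python
import re

# "x -> out" maps an output name back to the input tensor it aliases
for item in "x -> out".split(','):
    m = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
    print(m.group('in'), '=>', m.group('out'))  # prints: x => out
```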
diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py
index 35d4cc7b5fa16d790979f935aa266be55e3eaac1..438e6f788ec56e4ac96308f967d64a074e1ff844 100644
--- a/python/paddle/utils/code_gen/api_base.py
+++ b/python/paddle/utils/code_gen/api_base.py
@@ -50,7 +50,8 @@ class BaseAPI(object):
         self.support_selected_rows_kernel = False if len(self.kernel[
             'func']) == 1 else True
         self.data_transform = self.parse_data_transform(api_item_yaml)
-        self.inplace_map = self.parse_inplace(api_item_yaml)
+        self.inplace_map, self.view_map = self.parse_inplace_and_view(
+            api_item_yaml)
 
     def get_api_name(self, api_item_yaml):
         return api_item_yaml['api']
@@ -273,24 +274,30 @@ class BaseAPI(object):
 
         return data_transform
 
-    def parse_inplace(self, api_item_yaml):
-        if 'inplace' in api_item_yaml:
-            inplace_map = {}
-            inplace_list = api_item_yaml['inplace'].split(',')
-            for item in inplace_list:
-                result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
-                in_val = result.group('in')
-                out_val = result.group('out')
-                assert in_val in self.inputs['names'], \
-                    f"{self.api} : Inplace input error: the input var name('{in_val}') is not found in the input args of {self.api}."
-                assert out_val in self.outputs['names'], \
-                    f"{self.api} : Inplace output error: the output var name('{out_val}') is not found in the output args of {self.api}."
-
-                inplace_map[out_val] = in_val
-
-            return inplace_map
-        else:
-            return None
+    def parse_inplace_and_view(self, api_item_yaml):
+        inplace_map, view_map = None, None
+        for mode in ['inplace', 'view']:
+            if mode in api_item_yaml:
+                if mode == 'inplace':
+                    inplace_map = {}
+                else:
+                    view_map = {}
+                in_out_mapping_list = api_item_yaml[mode].split(',')
+                for item in in_out_mapping_list:
+                    result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
+                    in_val = result.group('in')
+                    out_val = result.group('out')
+                    assert in_val in self.inputs['names'], \
+                        f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
+                    assert out_val in self.outputs['names'], \
+                        f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."
+
+                    if mode == 'inplace':
+                        inplace_map[out_val] = in_val
+                    else:
+                        view_map[out_val] = in_val
+
+        return inplace_map, view_map
 
     # Override by child class
     def get_return_type(self, out_type_list):
diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py
index b0ed7886a710149e7487585e94862b74893e6da9..c8644a8812bd256f736343187086535e1ea79e9a 100644
--- a/python/paddle/utils/code_gen/api_gen.py
+++ b/python/paddle/utils/code_gen/api_gen.py
@@ -17,7 +17,7 @@ import yaml
 import argparse
 import re
 
-from api_base import BaseAPI
+from api_base import BaseAPI, PREFIX_TENSOR_NAME
 
 
 class ForwardAPI(BaseAPI):
@@ -94,6 +94,13 @@ class ForwardAPI(BaseAPI):
 {code_indent}  {self.outputs['return_type']} api_output{inplace_assign};
 {code_indent}  auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
 
+            if not inplace_flag and self.view_map is not None and self.outputs[
+                    'names'][0] in self.view_map:
+                output_create = output_create + f"""
+{code_indent}  kernel_out->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
+{code_indent}  kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
+{code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
+
         elif len(output_type_list) > 1:
             output_create = f"""
 {code_indent}  {self.outputs['return_type']} api_output;"""
@@ -109,6 +116,13 @@ class ForwardAPI(BaseAPI):
                 output_create = output_create + f"""
 {code_indent}  auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""
 
+                if not inplace_flag and self.view_map is not None and self.outputs[
+                        'names'][i] in self.view_map:
+                    output_create = output_create + f"""
+{code_indent}  kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
+{code_indent}  kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
+{code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
+
             kernel_output = kernel_output[:-2]
         else:
             raise ValueError(
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index ab231cc36bb0ee1dff2db49b4045e4e970008f67..43f512540ec4b24629b02dd0130831fb5387b97e 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -659,6 +659,20 @@
   kernel :
     func : relu_grad
 
+- backward_api : reshape_grad
+  forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)
+  args : (Tensor xshape, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : KernelWithXShapeInferMeta
+    param : [xshape]
+  kernel :
+    func : reshape_grad
+    param : [out_grad]
+    data_type : out_grad
+    backend : out_grad
+    layout : out_grad
+
 - backward_api : scale_grad
   forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
   args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true)
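With the `reshape_grad` entry above, the backward pass takes its metadata from the saved `xshape` (via `KernelWithXShapeInferMeta`) while the kernel itself only consumes `out_grad`. A usage-level sketch of the expected gradient flow (shapes are illustrative):

```python
import paddle

x = paddle.rand([3, 4])
x.stop_gradient = False
out = paddle.reshape(x, [2, 6])
out.sum().backward()
# reshape_grad recovers the input shape from xshape, so x.grad matches x
print(x.grad.shape)  # [3, 4]
```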