未验证 提交 495ca4aa 编写于 作者: P pangyoki 提交者: GitHub

support view strategy in dygraph eager_final state (#40891)

* support view strategy in eager_final state

* perfect reshape kernel

* fix bugs of sig

* add unittest for reshape_sig

* fix bugs when run converage

* fix inplace bug in final_state eager_gen

* fix python_c_gen

* support view strategy for final state

* fix order of out and xshape in reshape

* fix Coverage_CI unittest timeout error

* support reshape view

* fix reshape_sig

* fix yml and api_base
Co-authored-by: NYuanRisheng <yuanrisheng@baidu.com>
上级 775ddb5a
......@@ -887,6 +887,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
# Forward Full Logic
function_name = forward_api_name
if len(intermediate_outputs) > 0:
if is_inplaced:
function_name = GetIntermediateAPIFunctionName(
forward_api_name[:-1]) + '_'
else:
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
......
......@@ -36,6 +36,7 @@ DenseTensor::DenseTensor(const std::shared_ptr<phi::Allocation>& holder,
DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
holder_ = other.holder_;
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
......@@ -45,6 +46,7 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
meta_ = other.meta();
holder_ = other.holder_;
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
#endif
......@@ -54,6 +56,7 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
meta_ = std::move(other.meta_);
std::swap(holder_, other.holder_);
std::swap(inplace_version_counter_, other.inplace_version_counter_);
return *this;
}
......
......@@ -1527,8 +1527,8 @@ void ReshapeInferMeta(const MetaTensor& x,
void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
xshape,
......
......@@ -238,8 +238,8 @@ void ReshapeInferMeta(const MetaTensor& x,
void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());
void ReverseInferMeta(const MetaTensor& x,
......
......@@ -45,8 +45,8 @@ template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out) {
DenseTensor* out,
DenseTensor* xshape) {
ReshapeKernel(dev_ctx, x, shape, out);
}
......
......@@ -31,8 +31,8 @@ template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out);
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
DenseTensor Reshape(const Context& dev_ctx,
......
......@@ -20,13 +20,13 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
if (ctx.InputSize("ShapeTensor") > 0) {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"XShape", "Out"});
"reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
} else if (ctx.HasInput("Shape")) {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"Shape"}, {"XShape", "Out"});
"reshape_with_xshape", {"X"}, {"Shape"}, {"Out", "XShape"});
} else {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
"reshape_with_xshape", {"X"}, {"shape"}, {"Out", "XShape"});
}
} else {
if (ctx.InputSize("ShapeTensor") > 0) {
......
......@@ -6269,8 +6269,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = _C_ops.reshape2(x, None, 'shape', shape)
out = _C_ops.final_state_reshape(x, shape)
elif isinstance(shape, tmp_tensor_type):
# TODO: Tensor shape in final_state_reshape has not been tested
shape.stop_gradient = True
out, _ = _C_ops.reshape2(x, shape)
else:
......
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
# NOTE(pangyoki): Tensor View Strategy.
......@@ -28,8 +29,13 @@ import paddle
# View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`
class TestDygraphViewReuseAllocation(unittest.TestCase):
def setUp(self):
self.set_flag_to_test_eager_mode()
self.init_shape()
# some op don't suport eager_final_state in temporary
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False
def init_shape(self):
self.input_shape = [2, 3, 1]
self.output_shape = [2, 3]
......@@ -37,10 +43,13 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
def view_api_processing(self, var):
return paddle.squeeze(var)
def test_view_api(self):
def func_test_view_api(self):
var = paddle.rand(self.input_shape)
view_var = self.view_api_processing(var)
view_var[0] = 2.
# setitem don't support inplace in temporary.
# replace setitem with inplace exp_ in temporary.
# view_var[0] = 2.
view_var.exp_()
self.assertEqual(var.shape, self.input_shape)
self.assertEqual(view_var.shape, self.output_shape)
......@@ -48,24 +57,38 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
view_var_numpy = view_var.numpy()
self.assertTrue(np.array_equal(var_numpy, view_var_numpy))
def test_forward_version(self):
def test_view_api(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_view_api()
self.func_test_view_api()
def func_test_forward_version(self):
var = paddle.rand(self.input_shape)
self.assertEqual(var.inplace_version, 0)
view_var = self.view_api_processing(var)
self.assertEqual(view_var.inplace_version, 0)
var[0] = 2.
# var[0] = 2.
var.exp_()
self.assertEqual(var.inplace_version, 1)
self.assertEqual(view_var.inplace_version, 1)
view_var_2 = self.view_api_processing(var)
self.assertEqual(view_var_2.inplace_version, 1)
var[0] = 3.
# var[0] = 3.
var.exp_()
self.assertEqual(view_var.inplace_version, 2)
self.assertEqual(view_var_2.inplace_version, 2)
def test_backward_error(self):
def test_forward_version(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_forward_version()
self.func_test_forward_version()
def func_test_backward_error(self):
# It raises an error because the inplace operator will result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
......@@ -77,17 +100,34 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
# Here, the gradient computation will use the value of var_b
var_c = var_b**2
view_var_b = self.view_api_processing(var_b)
view_var_b[0] = 2. # var_b is modified inplace
# view_var_b[0] = 2. # var_b is modified inplace
view_var_b.exp_()
loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_error(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_backward_error()
self.func_test_backward_error()
class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False
def init_shape(self):
self.input_shape = [2, 3]
self.output_shape = [2, 3, 1]
......@@ -97,6 +137,9 @@ class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = True
def init_shape(self):
self.input_shape = [3, 4]
self.output_shape = [2, 2, 3]
......@@ -106,6 +149,9 @@ class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False
def init_shape(self):
self.input_shape = [3, 4]
self.output_shape = [12]
......
......@@ -1126,12 +1126,15 @@
- api : reshape
args : (Tensor x, ScalarArray shape)
output : Tensor(out)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : ReshapeInferMeta
func : ReshapeWithXShapeInferMeta
kernel :
func : reshape
func : reshape_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward: reshape_grad
- api : scale
args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
......
......@@ -50,7 +50,8 @@ class BaseAPI(object):
self.support_selected_rows_kernel = False if len(self.kernel[
'func']) == 1 else True
self.data_transform = self.parse_data_transform(api_item_yaml)
self.inplace_map = self.parse_inplace(api_item_yaml)
self.inplace_map, self.view_map = self.parse_inplace_and_view(
api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['api']
......@@ -273,24 +274,30 @@ class BaseAPI(object):
return data_transform
def parse_inplace(self, api_item_yaml):
if 'inplace' in api_item_yaml:
def parse_inplace_and_view(self, api_item_yaml):
inplace_map, view_map = None, None
for mode in ['inplace', 'view']:
if mode in api_item_yaml:
if mode == 'inplace':
inplace_map = {}
inplace_list = api_item_yaml['inplace'].split(',')
for item in inplace_list:
else:
view_map = {}
in_out_mapping_list = api_item_yaml[mode].split(',')
for item in in_out_mapping_list:
result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
in_val = result.group('in')
out_val = result.group('out')
assert in_val in self.inputs['names'], \
f"{self.api} : Inplace input error: the input var name('{in_val}') is not found in the input args of {self.api}."
f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
assert out_val in self.outputs['names'], \
f"{self.api} : Inplace output error: the output var name('{out_val}') is not found in the output args of {self.api}."
f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."
if mode == 'inplace':
inplace_map[out_val] = in_val
return inplace_map
else:
return None
view_map[out_val] = in_val
return inplace_map, view_map
# Override by child class
def get_return_type(self, out_type_list):
......
......@@ -17,7 +17,7 @@ import yaml
import argparse
import re
from api_base import BaseAPI
from api_base import BaseAPI, PREFIX_TENSOR_NAME
class ForwardAPI(BaseAPI):
......@@ -94,6 +94,13 @@ class ForwardAPI(BaseAPI):
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
if not inplace_flag and self.view_map is not None and self.outputs[
'names'][0] in self.view_map:
output_create = output_create + f"""
{code_indent} kernel_out->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent} kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
elif len(output_type_list) > 1:
output_create = f"""
{code_indent} {self.outputs['return_type']} api_output;"""
......@@ -109,6 +116,13 @@ class ForwardAPI(BaseAPI):
output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""
if not inplace_flag and self.view_map is not None and self.outputs[
'names'][i] in self.view_map:
output_create = output_create + f"""
{code_indent} kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
......
......@@ -659,6 +659,20 @@
kernel :
func : relu_grad
- backward_api : reshape_grad
forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param : [xshape]
kernel :
func : reshape_grad
param : [out_grad]
data_type: out_grad
backend: out_grad
layout: out_grad
- backward_api : scale_grad
forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册