未验证 提交 495ca4aa 编写于 作者: P pangyoki 提交者: GitHub

support view strategy in dygraph eager_final state (#40891)

* support view strategy in eager_final state

* perfect reshape kernel

* fix bugs of sig

* add unittest for reshape_sig

* fix bugs when run converage

* fix inplace bug in final_state eager_gen

* fix python_c_gen

* support view strategy for final state

* fix order of out and xshape in reshape

* fix Coverage_CI unittest timeout error

* support reshape view

* fix reshape_sig

* fix yml and api_base
Co-authored-by: NYuanRisheng <yuanrisheng@baidu.com>
上级 775ddb5a
...@@ -887,7 +887,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -887,7 +887,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
# Forward Full Logic # Forward Full Logic
function_name = forward_api_name function_name = forward_api_name
if len(intermediate_outputs) > 0: if len(intermediate_outputs) > 0:
function_name = GetIntermediateAPIFunctionName(function_name) if is_inplaced:
function_name = GetIntermediateAPIFunctionName(
forward_api_name[:-1]) + '_'
else:
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
......
...@@ -36,6 +36,7 @@ DenseTensor::DenseTensor(const std::shared_ptr<phi::Allocation>& holder, ...@@ -36,6 +36,7 @@ DenseTensor::DenseTensor(const std::shared_ptr<phi::Allocation>& holder,
DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) { DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
holder_ = other.holder_; holder_ = other.holder_;
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
format_ = other.format_; format_ = other.format_;
...@@ -45,6 +46,7 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) { ...@@ -45,6 +46,7 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
DenseTensor& DenseTensor::operator=(const DenseTensor& other) { DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
meta_ = other.meta(); meta_ = other.meta();
holder_ = other.holder_; holder_ = other.holder_;
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
format_ = other.format_; format_ = other.format_;
#endif #endif
...@@ -54,6 +56,7 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) { ...@@ -54,6 +56,7 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
DenseTensor& DenseTensor::operator=(DenseTensor&& other) { DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
meta_ = std::move(other.meta_); meta_ = std::move(other.meta_);
std::swap(holder_, other.holder_); std::swap(holder_, other.holder_);
std::swap(inplace_version_counter_, other.inplace_version_counter_);
return *this; return *this;
} }
......
...@@ -1527,8 +1527,8 @@ void ReshapeInferMeta(const MetaTensor& x, ...@@ -1527,8 +1527,8 @@ void ReshapeInferMeta(const MetaTensor& x,
void ReshapeWithXShapeInferMeta(const MetaTensor& x, void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) { MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
xshape, xshape,
......
...@@ -238,8 +238,8 @@ void ReshapeInferMeta(const MetaTensor& x, ...@@ -238,8 +238,8 @@ void ReshapeInferMeta(const MetaTensor& x,
void ReshapeWithXShapeInferMeta(const MetaTensor& x, void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void ReverseInferMeta(const MetaTensor& x, void ReverseInferMeta(const MetaTensor& x,
......
...@@ -45,8 +45,8 @@ template <typename Context> ...@@ -45,8 +45,8 @@ template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx, void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* out,
DenseTensor* out) { DenseTensor* xshape) {
ReshapeKernel(dev_ctx, x, shape, out); ReshapeKernel(dev_ctx, x, shape, out);
} }
......
...@@ -31,8 +31,8 @@ template <typename Context> ...@@ -31,8 +31,8 @@ template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx, void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* out,
DenseTensor* out); DenseTensor* xshape);
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor Reshape(const Context& dev_ctx, DenseTensor Reshape(const Context& dev_ctx,
......
...@@ -20,13 +20,13 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -20,13 +20,13 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) { if (ctx.HasOutput("XShape")) {
if (ctx.InputSize("ShapeTensor") > 0) { if (ctx.InputSize("ShapeTensor") > 0) {
return KernelSignature( return KernelSignature(
"reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"XShape", "Out"}); "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});
} else if (ctx.HasInput("Shape")) { } else if (ctx.HasInput("Shape")) {
return KernelSignature( return KernelSignature(
"reshape_with_xshape", {"X"}, {"Shape"}, {"XShape", "Out"}); "reshape_with_xshape", {"X"}, {"Shape"}, {"Out", "XShape"});
} else { } else {
return KernelSignature( return KernelSignature(
"reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"}); "reshape_with_xshape", {"X"}, {"shape"}, {"Out", "XShape"});
} }
} else { } else {
if (ctx.InputSize("ShapeTensor") > 0) { if (ctx.InputSize("ShapeTensor") > 0) {
......
...@@ -6269,8 +6269,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -6269,8 +6269,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
item.numpy().item(0) if isinstance(item, Variable) else item item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape for item in shape
] ]
out, _ = _C_ops.reshape2(x, None, 'shape', shape) out = _C_ops.final_state_reshape(x, shape)
elif isinstance(shape, tmp_tensor_type): elif isinstance(shape, tmp_tensor_type):
# TODO: Tensor shape in final_state_reshape has not been tested
shape.stop_gradient = True shape.stop_gradient = True
out, _ = _C_ops.reshape2(x, shape) out, _ = _C_ops.reshape2(x, shape)
else: else:
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle import paddle
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
# NOTE(pangyoki): Tensor View Strategy. # NOTE(pangyoki): Tensor View Strategy.
...@@ -28,8 +29,13 @@ import paddle ...@@ -28,8 +29,13 @@ import paddle
# View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach` # View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`
class TestDygraphViewReuseAllocation(unittest.TestCase): class TestDygraphViewReuseAllocation(unittest.TestCase):
def setUp(self): def setUp(self):
self.set_flag_to_test_eager_mode()
self.init_shape() self.init_shape()
# some op don't suport eager_final_state in temporary
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False
def init_shape(self): def init_shape(self):
self.input_shape = [2, 3, 1] self.input_shape = [2, 3, 1]
self.output_shape = [2, 3] self.output_shape = [2, 3]
...@@ -37,10 +43,13 @@ class TestDygraphViewReuseAllocation(unittest.TestCase): ...@@ -37,10 +43,13 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
def view_api_processing(self, var): def view_api_processing(self, var):
return paddle.squeeze(var) return paddle.squeeze(var)
def test_view_api(self): def func_test_view_api(self):
var = paddle.rand(self.input_shape) var = paddle.rand(self.input_shape)
view_var = self.view_api_processing(var) view_var = self.view_api_processing(var)
view_var[0] = 2. # setitem don't support inplace in temporary.
# replace setitem with inplace exp_ in temporary.
# view_var[0] = 2.
view_var.exp_()
self.assertEqual(var.shape, self.input_shape) self.assertEqual(var.shape, self.input_shape)
self.assertEqual(view_var.shape, self.output_shape) self.assertEqual(view_var.shape, self.output_shape)
...@@ -48,24 +57,38 @@ class TestDygraphViewReuseAllocation(unittest.TestCase): ...@@ -48,24 +57,38 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
view_var_numpy = view_var.numpy() view_var_numpy = view_var.numpy()
self.assertTrue(np.array_equal(var_numpy, view_var_numpy)) self.assertTrue(np.array_equal(var_numpy, view_var_numpy))
def test_forward_version(self): def test_view_api(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_view_api()
self.func_test_view_api()
def func_test_forward_version(self):
var = paddle.rand(self.input_shape) var = paddle.rand(self.input_shape)
self.assertEqual(var.inplace_version, 0) self.assertEqual(var.inplace_version, 0)
view_var = self.view_api_processing(var) view_var = self.view_api_processing(var)
self.assertEqual(view_var.inplace_version, 0) self.assertEqual(view_var.inplace_version, 0)
var[0] = 2. # var[0] = 2.
var.exp_()
self.assertEqual(var.inplace_version, 1) self.assertEqual(var.inplace_version, 1)
self.assertEqual(view_var.inplace_version, 1) self.assertEqual(view_var.inplace_version, 1)
view_var_2 = self.view_api_processing(var) view_var_2 = self.view_api_processing(var)
self.assertEqual(view_var_2.inplace_version, 1) self.assertEqual(view_var_2.inplace_version, 1)
var[0] = 3. # var[0] = 3.
var.exp_()
self.assertEqual(view_var.inplace_version, 2) self.assertEqual(view_var.inplace_version, 2)
self.assertEqual(view_var_2.inplace_version, 2) self.assertEqual(view_var_2.inplace_version, 2)
def test_backward_error(self): def test_forward_version(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_forward_version()
self.func_test_forward_version()
def func_test_backward_error(self):
# It raises an error because the inplace operator will result # It raises an error because the inplace operator will result
# in incorrect gradient computation. # in incorrect gradient computation.
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
...@@ -77,17 +100,34 @@ class TestDygraphViewReuseAllocation(unittest.TestCase): ...@@ -77,17 +100,34 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
# Here, the gradient computation will use the value of var_b # Here, the gradient computation will use the value of var_b
var_c = var_b**2 var_c = var_b**2
view_var_b = self.view_api_processing(var_b) view_var_b = self.view_api_processing(var_b)
view_var_b[0] = 2. # var_b is modified inplace # view_var_b[0] = 2. # var_b is modified inplace
view_var_b.exp_()
loss = paddle.nn.functional.relu(var_c) loss = paddle.nn.functional.relu(var_c)
with self.assertRaisesRegexp( if in_dygraph_mode():
RuntimeError, with self.assertRaisesRegexp(
"received tensor_version:{} != wrapper_version_snapshot:{}". RuntimeError,
format(1, 0)): "received current_inplace_version:{} != inplace_version_snapshot_:{}".
loss.backward() format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_error(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_backward_error()
self.func_test_backward_error()
class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation): class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False
def init_shape(self): def init_shape(self):
self.input_shape = [2, 3] self.input_shape = [2, 3]
self.output_shape = [2, 3, 1] self.output_shape = [2, 3, 1]
...@@ -97,6 +137,9 @@ class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation): ...@@ -97,6 +137,9 @@ class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation): class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = True
def init_shape(self): def init_shape(self):
self.input_shape = [3, 4] self.input_shape = [3, 4]
self.output_shape = [2, 2, 3] self.output_shape = [2, 2, 3]
...@@ -106,6 +149,9 @@ class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation): ...@@ -106,6 +149,9 @@ class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation): class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False
def init_shape(self): def init_shape(self):
self.input_shape = [3, 4] self.input_shape = [3, 4]
self.output_shape = [12] self.output_shape = [12]
......
...@@ -1126,12 +1126,15 @@ ...@@ -1126,12 +1126,15 @@
- api : reshape - api : reshape
args : (Tensor x, ScalarArray shape) args : (Tensor x, ScalarArray shape)
output : Tensor(out) output : Tensor(out), Tensor(xshape)
infer_meta : infer_meta :
func : ReshapeInferMeta func : ReshapeWithXShapeInferMeta
kernel : kernel :
func : reshape func : reshape_with_xshape
inplace : (x -> out) inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward: reshape_grad
- api : scale - api : scale
args : (Tensor x, Scalar scale, float bias, bool bias_after_scale) args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
......
...@@ -50,7 +50,8 @@ class BaseAPI(object): ...@@ -50,7 +50,8 @@ class BaseAPI(object):
self.support_selected_rows_kernel = False if len(self.kernel[ self.support_selected_rows_kernel = False if len(self.kernel[
'func']) == 1 else True 'func']) == 1 else True
self.data_transform = self.parse_data_transform(api_item_yaml) self.data_transform = self.parse_data_transform(api_item_yaml)
self.inplace_map = self.parse_inplace(api_item_yaml) self.inplace_map, self.view_map = self.parse_inplace_and_view(
api_item_yaml)
def get_api_name(self, api_item_yaml): def get_api_name(self, api_item_yaml):
return api_item_yaml['api'] return api_item_yaml['api']
...@@ -273,24 +274,30 @@ class BaseAPI(object): ...@@ -273,24 +274,30 @@ class BaseAPI(object):
return data_transform return data_transform
def parse_inplace(self, api_item_yaml): def parse_inplace_and_view(self, api_item_yaml):
if 'inplace' in api_item_yaml: inplace_map, view_map = None, None
inplace_map = {} for mode in ['inplace', 'view']:
inplace_list = api_item_yaml['inplace'].split(',') if mode in api_item_yaml:
for item in inplace_list: if mode == 'inplace':
result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item) inplace_map = {}
in_val = result.group('in') else:
out_val = result.group('out') view_map = {}
assert in_val in self.inputs['names'], \ in_out_mapping_list = api_item_yaml[mode].split(',')
f"{self.api} : Inplace input error: the input var name('{in_val}') is not found in the input args of {self.api}." for item in in_out_mapping_list:
assert out_val in self.outputs['names'], \ result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
f"{self.api} : Inplace output error: the output var name('{out_val}') is not found in the output args of {self.api}." in_val = result.group('in')
out_val = result.group('out')
inplace_map[out_val] = in_val assert in_val in self.inputs['names'], \
f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
return inplace_map assert out_val in self.outputs['names'], \
else: f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."
return None
if mode == 'inplace':
inplace_map[out_val] = in_val
else:
view_map[out_val] = in_val
return inplace_map, view_map
# Override by child class # Override by child class
def get_return_type(self, out_type_list): def get_return_type(self, out_type_list):
......
...@@ -17,7 +17,7 @@ import yaml ...@@ -17,7 +17,7 @@ import yaml
import argparse import argparse
import re import re
from api_base import BaseAPI from api_base import BaseAPI, PREFIX_TENSOR_NAME
class ForwardAPI(BaseAPI): class ForwardAPI(BaseAPI):
...@@ -94,6 +94,13 @@ class ForwardAPI(BaseAPI): ...@@ -94,6 +94,13 @@ class ForwardAPI(BaseAPI):
{code_indent} {self.outputs['return_type']} api_output{inplace_assign}; {code_indent} {self.outputs['return_type']} api_output{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
if not inplace_flag and self.view_map is not None and self.outputs[
'names'][0] in self.view_map:
output_create = output_create + f"""
{code_indent} kernel_out->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent} kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} api_output;""" {code_indent} {self.outputs['return_type']} api_output;"""
...@@ -109,6 +116,13 @@ class ForwardAPI(BaseAPI): ...@@ -109,6 +116,13 @@ class ForwardAPI(BaseAPI):
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""
if not inplace_flag and self.view_map is not None and self.outputs[
'names'][i] in self.view_map:
output_create = output_create + f"""
{code_indent} kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
raise ValueError( raise ValueError(
......
...@@ -659,6 +659,20 @@ ...@@ -659,6 +659,20 @@
kernel : kernel :
func : relu_grad func : relu_grad
- backward_api : reshape_grad
forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param : [xshape]
kernel :
func : reshape_grad
param : [out_grad]
data_type: out_grad
backend: out_grad
layout: out_grad
- backward_api : scale_grad - backward_api : scale_grad
forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true) args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册