Unverified commit 2f1c1ae5, authored by pangyoki and committed by GitHub

support view strategy in eager_fluid state (#40830)

* support view strategy in eager_fluid state

* little change

* little change

* optimize unittest

* fix
Parent 56493c9e
@@ -1707,10 +1707,31 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
       }
     }
   }
+  generated_function_body += "\n";
   VLOG(6) << "Generated Outs Map";

+  // [Generation] Apply View Strategy (Tensor)
+  if (inplace_map.empty() && view_op_map.count(op_type)) {
+    const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT =
+        "  if (ins.count(\"%s\") && outs.count(\"%s\")) {\n"
+        "    egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], "
+        "outs[\"%s\"][0]);\n"
+        "  };\n";
+    std::string view_strategy_str = "";
+    std::string view_input_name = view_op_map[op_type].first;
+    std::string view_output_name = view_op_map[op_type].second;
+    view_strategy_str += paddle::string::Sprintf(
+        HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT, view_input_name,
+        view_output_name, view_input_name, view_output_name);
+    generated_function_body += view_strategy_str;
+    generated_function_body += "\n";
+    VLOG(6) << "Generated View Strategy";
+  }
+  generated_function_body += "\n";

   // [Generation] Get Attrs
   dygraph_function_args_str +=
       ", const paddle::framework::AttributeMap& attr_map";
...
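To make the generated snippet concrete, here is a standalone sketch (plain C++, not Paddle code) that instantiates the HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT template for a hypothetical view op whose (input, output) pair in view_op_map is ("X", "Out"); the slot names are assumptions for illustration.

```cpp
// Standalone sketch: print the code the generator would append to a forward
// function, for an assumed view pair ("X", "Out") taken from view_op_map.
#include <cstdio>

int main() {
  // Same format string as HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT above.
  const char* kTemplate =
      "  if (ins.count(\"%s\") && outs.count(\"%s\")) {\n"
      "    egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], "
      "outs[\"%s\"][0]);\n"
      "  };\n";
  std::printf(kTemplate, "X", "Out", "X", "Out");  // "X"/"Out" are assumed
  return 0;
}
```

Running it prints the guarded HandleViewBetweenInputAndOutput call that, per the branch above, is emitted only when the op has no inplace mapping.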
@@ -244,6 +244,33 @@ std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
   return res;
 }

+void EagerUtils::HandleViewBetweenInputAndOutput(
+    const std::shared_ptr<EagerVariable>& input_var,
+    const std::shared_ptr<EagerVariable>& view_output_var) {
+  PADDLE_ENFORCE_EQ(
+      input_var->Var().IsInitialized(), true,
+      paddle::platform::errors::InvalidArgument(
+          "Tensor %s has not been initialized!", input_var->name()));
+
+  if (phi::DenseTensor::classof(input_var->GetTensorBase().get())) {
+    auto input_dense_tensor =
+        std::dynamic_pointer_cast<phi::DenseTensor>(input_var->GetTensorBase());
+    PADDLE_ENFORCE_EQ(
+        input_dense_tensor->IsInitialized(), true,
+        paddle::platform::errors::InvalidArgument(
+            "DenseTensor %s has not been initialized!", input_var->name()));
+
+    auto* view_output_tensor =
+        view_output_var->MutableVar()->GetMutable<phi::DenseTensor>();
+    view_output_tensor->ShareBufferWith(*input_dense_tensor);
+    view_output_tensor->ShareInplaceVersionCounterWith(*input_dense_tensor);
+
+    VLOG(3) << "Perform View between Output Var(" << view_output_var->name()
+            << ") and Input Var(" << input_var->name()
+            << "), share allocation and inplace version.";
+  }
+}
+
 void EagerUtils::ModifyInplaceInput(
     const std::shared_ptr<EagerVariable>& inplace_variable,
     paddle::experimental::Tensor* inplace_tensor) {
...
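The aliasing behavior this function sets up — one allocation and one inplace version counter shared between the input and its view output — can be sketched with a toy model. This is not Paddle's phi::DenseTensor API; the types below are stand-ins invented for illustration.

```cpp
// Toy model of ShareBufferWith + ShareInplaceVersionCounterWith: the view
// aliases the input's storage and its inplace version counter, so an
// in-place write through either tensor is visible to, and counted for, both.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct ToyTensor {
  std::shared_ptr<std::vector<float>> buffer;  // shared allocation
  std::shared_ptr<int> inplace_version = std::make_shared<int>(0);

  void ShareBufferWith(const ToyTensor& other) { buffer = other.buffer; }
  void ShareInplaceVersionCounterWith(const ToyTensor& other) {
    inplace_version = other.inplace_version;
  }
  void InplaceWrite(std::size_t i, float v) {
    (*buffer)[i] = v;
    ++*inplace_version;
  }
};

int main() {
  ToyTensor input;
  input.buffer = std::make_shared<std::vector<float>>(6, 1.0f);

  ToyTensor view;  // e.g. the output of a reshape-like view op
  view.ShareBufferWith(input);
  view.ShareInplaceVersionCounterWith(input);

  view.InplaceWrite(0, 2.0f);           // mutate through the view
  assert((*input.buffer)[0] == 2.0f);   // change is visible via the input
  assert(*input.inplace_version == 1);  // one version bump, seen by both
  return 0;
}
```

This is the behavior the Python tests below assert through `inplace_version`.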
@@ -168,6 +168,11 @@ class EagerUtils {
     }
   }

+  // View Strategy
+  static void HandleViewBetweenInputAndOutput(
+      const std::shared_ptr<EagerVariable>& input_var,
+      const std::shared_ptr<EagerVariable>& view_output_var);
+
   // TensorWrapper Utils
   static paddle::experimental::Tensor RecoverTensorWrapper(
       TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
...
@@ -29,13 +29,8 @@ from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
 # View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`
 class TestDygraphViewReuseAllocation(unittest.TestCase):
     def setUp(self):
-        self.set_flag_to_test_eager_mode()
         self.init_shape()

-    # some ops don't support eager_final_state for now
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = False
-
     def init_shape(self):
         self.input_shape = [2, 3, 1]
         self.output_shape = [2, 3]
@@ -46,10 +41,7 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
     def func_test_view_api(self):
         var = paddle.rand(self.input_shape)
         view_var = self.view_api_processing(var)
-        # setitem doesn't support inplace for now;
-        # use inplace exp_ instead of setitem as a workaround.
-        # view_var[0] = 2.
-        view_var.exp_()
+        view_var[0] = 2.

         self.assertEqual(var.shape, self.input_shape)
         self.assertEqual(view_var.shape, self.output_shape)
@@ -58,7 +50,6 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
         self.assertTrue(np.array_equal(var_numpy, view_var_numpy))

     def test_view_api(self):
-        if self.flag_test_eager_mode:
-            with _test_eager_guard():
-                self.func_test_view_api()
+        with _test_eager_guard():
+            self.func_test_view_api()
         self.func_test_view_api()
@@ -69,21 +60,18 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
         view_var = self.view_api_processing(var)
         self.assertEqual(view_var.inplace_version, 0)

-        # var[0] = 2.
-        var.exp_()
+        var[0] = 2.
         self.assertEqual(var.inplace_version, 1)
         self.assertEqual(view_var.inplace_version, 1)

         view_var_2 = self.view_api_processing(var)
         self.assertEqual(view_var_2.inplace_version, 1)

-        # var[0] = 3.
-        var.exp_()
+        var[0] = 3.
         self.assertEqual(view_var.inplace_version, 2)
         self.assertEqual(view_var_2.inplace_version, 2)

     def test_forward_version(self):
-        if self.flag_test_eager_mode:
-            with _test_eager_guard():
-                self.func_test_forward_version()
+        with _test_eager_guard():
+            self.func_test_forward_version()
         self.func_test_forward_version()
@@ -100,8 +88,7 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
         # Here, the gradient computation will use the value of var_b
         var_c = var_b**2
         view_var_b = self.view_api_processing(var_b)
-        # view_var_b[0] = 2.  # var_b is modified inplace
-        view_var_b.exp_()
+        view_var_b[0] = 2.  # var_b is modified inplace

         loss = paddle.nn.functional.relu(var_c)
         if in_dygraph_mode():
@@ -118,16 +105,12 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
             loss.backward()

     def test_backward_error(self):
-        if self.flag_test_eager_mode:
-            with _test_eager_guard():
-                self.func_test_backward_error()
+        with _test_eager_guard():
+            self.func_test_backward_error()
         self.func_test_backward_error()


 class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = False
-
     def init_shape(self):
         self.input_shape = [2, 3]
         self.output_shape = [2, 3, 1]
@@ -137,9 +120,6 @@ class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
 class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = True
-
     def init_shape(self):
         self.input_shape = [3, 4]
         self.output_shape = [2, 2, 3]
@@ -149,9 +129,6 @@ class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
 class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = False
-
     def init_shape(self):
         self.input_shape = [3, 4]
         self.output_shape = [12]
...
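For orientation, the view APIs exercised above (squeeze, unsqueeze, reshape, flatten) correspond to entries the generator looks up in view_op_map. The following is a hypothetical sketch of such a table; the op and slot names are illustrative assumptions, not the PR's actual contents.

```cpp
// Hypothetical view_op_map entries: op type -> (input slot, output slot).
// All names here are assumptions; the real table lives in the generator.
#include <map>
#include <string>
#include <utility>

int main() {
  std::map<std::string, std::pair<std::string, std::string>> view_op_map = {
      {"squeeze2", {"X", "Out"}},
      {"unsqueeze2", {"X", "Out"}},
      {"reshape2", {"X", "Out"}},
      {"flatten_contiguous_range", {"X", "Out"}},
  };
  (void)view_op_map;  // lookup pattern: view_op_map[op_type].first / .second
  return 0;
}
```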