Unverified commit c86e3a11, authored by pangyoki, committed by GitHub

replace append_op with C_ops for assign op (#41118)

* support C_ops assign

* enable unittest

* fix clone
Parent bcdffe66
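
In user-facing terms, this keeps `paddle.assign(input, output)` an inplace operation on `output` in dynamic graph mode while routing it through `_C_ops` instead of `append_op`. A minimal sketch of the intended behavior (dygraph mode assumed; the slice write and version numbers mirror the updated `TestInplace` unit test below):

```python
import paddle

var = paddle.zeros(shape=[3])
var[0] = 1.1                                # slice write bumps the version
assert var.inplace_version == 1

# assign(input, output) modifies output in place, so the patched API
# must bump output's inplace version itself.
paddle.assign(paddle.ones(shape=[3]), var)
assert var.inplace_version == 2
```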
@@ -2803,7 +2803,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
   // Inplace Function Generator.
   // `sum` op has duplicate input. Don't consider adding an inplace strategy
   // for `sum` for now.
-  if (op_type != "sum" && infer_inplace) {
+  if (infer_inplace && !special_inplace_op_set.count(op_type)) {
     auto in_to_outs = infer_inplace(true);
     for (auto& inplace_pair : in_to_outs) {
       inplace_map[inplace_pair.second] = inplace_pair.first;
...
@@ -94,9 +94,9 @@ class TensorWrapper {
       return paddle::experimental::Tensor();
     }
+    check_inplace_version();
     // if it's full_reserved just return the full copy of tensor
     if (full_reserved_) {
-      check_inplace_version();
       return intermidiate_tensor_;
     } else {
       std::shared_ptr<GradNodeBase> new_grad_node = grad_node;
@@ -105,7 +105,6 @@ class TensorWrapper {
       intermidiate_tensor_.set_autograd_meta(
           std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
               p_ab_autograd_meta));
-      check_inplace_version();
       return intermidiate_tensor_;
     }
   }
...
@@ -433,7 +433,7 @@ GenerateOpFunctions() {
   std::map<std::string, std::string> inplace_map;
   // `sum` op has duplicate input. Don't consider adding an inplace strategy
   // for `sum` for now.
-  if (op_type != "sum" && infer_inplace) {
+  if (infer_inplace && !special_inplace_op_set.count(op_type)) {
     // Inplace OP: op_type_.
     // The inplace OP needs a new implementation method.
     auto in_to_outs = infer_inplace(true);
...
@@ -222,6 +222,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"c_reduce", {"Out"}},
     {"c_scatter", {"Out"}},
     {"barrier", {"Out"}},
+    {"assign", {"Out"}},
     {"fake_quantize_dequantize_moving_average_abs_max",
      {"Out", "OutScale", "OutAccum", "OutState"}},
     {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
@@ -249,3 +250,12 @@ std::map<std::string, std::pair<std::string, std::string>> view_op_map = {
     {"reshape2", {"X", "Out"}},
     {"flatten_contiguous_range", {"X", "Out"}},
 };
+
+// NOTE(pangyoki): Special inplace ops that are not supported for now.
+// The input and output of some inplace ops are special (e.g. duplicate
+// input), and these ops have no usage scenarios yet, so no inplace
+// strategy is generated for them.
+std::set<std::string> special_inplace_op_set = {
+    "sum",     // `sum` op has duplicate input
+    "assign",  // output of `assign` op is in `op_passing_outs_map`
+};
@@ -20,7 +20,7 @@ import sys
 import paddle
 from .. import framework
-from ..framework import convert_np_dtype_to_dtype_
+from ..framework import convert_np_dtype_to_dtype_, _in_legacy_dygraph
 from .. import core
 from .. import unique_name
 from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase
@@ -798,7 +798,11 @@ def monkey_patch_varbase():
     @framework.dygraph_only
     def clone(self):
-        return _C_ops.assign(self)
+        if _in_legacy_dygraph():
+            output = core.VarBase()
+        else:
+            output = core.eager.Tensor()
+        return _C_ops.assign(self, output)

     @framework.dygraph_only
     def value(self):
...
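
A quick usage sketch for the patched `clone` (behavior inferred from the diff above: the output is pre-created as a `VarBase` or eager `Tensor` and filled via `_C_ops.assign(self, output)`):

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = x.clone()            # copy produced through the C_ops assign path

y[0] = 9.0               # writing to the clone must leave the source intact
assert float(x[0]) == 1.0
assert float(y[0]) == 9.0
```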
@@ -606,15 +606,24 @@ def assign(input, output=None):
     # isinstance(VarBase, Variable) == False. It will cause return None
     # after this api.
     if isinstance(input, (Variable, core.VarBase)):
-        check_dtype(input.dtype, 'input', [
-            'float16', 'uint16', 'float32', 'float64', 'int32', 'int64',
-            'uint8', 'bool'
-        ], 'assign', '(When the type of input in assign is Variable.)')
-        if output is None:
-            output = helper.create_variable_for_type_inference(
-                dtype=input.dtype)
-        helper.append_op(
-            type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
+        if _non_static_mode():
+            if output is None:
+                if _in_legacy_dygraph():
+                    output = core.VarBase()
+                else:
+                    output = core.eager.Tensor()
+            _C_ops.assign(input, output)
+        else:
+            check_dtype(input.dtype, 'input', [
+                'float16', 'uint16', 'float32', 'float64', 'int32', 'int64',
+                'uint8', 'bool'
+            ], 'assign', '(When the type of input in assign is Variable.)')
+            if output is None:
+                output = helper.create_variable_for_type_inference(
+                    dtype=input.dtype)
+            helper.append_op(
+                type='assign', inputs={'X': [input]},
+                outputs={'Out': [output]})
     elif isinstance(input, numpy.ndarray):
         # Not support [var, var, ...] currently.
         if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
@@ -663,9 +672,7 @@ def assign(input, output=None):
         })
     if is_inplace and _non_static_mode():
-        # TODO(jiabin): Remove this when we support inplace
-        if _in_legacy_dygraph():
-            output._bump_inplace_version()
+        output._bump_inplace_version()
     return output
...
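
The new dispatch can be exercised as follows (a sketch, dygraph mode assumed; with `output=None` the API now allocates a fresh `VarBase` or eager `Tensor` itself instead of appending an `assign` op to a program):

```python
import paddle

src = paddle.full(shape=[2], fill_value=3.0)

out = paddle.assign(src)      # output=None: a new tensor is allocated
dst = paddle.zeros(shape=[2])
paddle.assign(src, dst)       # explicit output: dst is updated in place

assert float(out[0]) == 3.0
assert float(dst[1]) == 3.0
```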
@@ -31,11 +31,7 @@ class TestInplace(unittest.TestCase):
             var[0] = 1.1
             self.assertEqual(var.inplace_version, 1)

-            # TODO1: assign don't support inplace in temporary
-            if in_dygraph_mode():
-                var[0] = 2
-            else:
-                paddle.assign(paddle.ones(shape=[3]), var)
+            paddle.assign(paddle.ones(shape=[3]), var)

             # NOTE(liym27): assign(input, output) is an inplace operation for output.
             # There is inplace-related processing for api assign, var.inplace_version should be 2 not 1.
@@ -122,7 +118,7 @@ class TestInplace(unittest.TestCase):
         loss.backward()

     def test_backward_success_2(self):
-        # TODO2: need to process no_need_buffer in eager mode
+        # TODO: need to process no_need_buffer in eager mode
         # with _test_eager_guard():
         #     self.func_test_backward_success_2()
         self.func_test_backward_success_2()
...