未验证 提交 c86e3a11 编写于 作者: P pangyoki 提交者: GitHub

replace append_op with C_ops for assign op (#41118)

* support C_ops assign

* open unittest

* fix clone
上级 bcdffe66
......@@ -2803,7 +2803,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
// Inplace Function Generator.
// `sum` op has duplicate input. Don't consider adding inplace strategy
// for `sum` in temporary.
if (op_type != "sum" && infer_inplace) {
if (infer_inplace && !special_inplace_op_set.count(op_type)) {
auto in_to_outs = infer_inplace(true);
for (auto& inplace_pair : in_to_outs) {
inplace_map[inplace_pair.second] = inplace_pair.first;
......
......@@ -94,9 +94,9 @@ class TensorWrapper {
return paddle::experimental::Tensor();
}
check_inplace_version();
// if it's full_reserved just return the full copy of tensor
if (full_reserved_) {
check_inplace_version();
return intermidiate_tensor_;
} else {
std::shared_ptr<GradNodeBase> new_grad_node = grad_node;
......@@ -105,7 +105,6 @@ class TensorWrapper {
intermidiate_tensor_.set_autograd_meta(
std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
p_ab_autograd_meta));
check_inplace_version();
return intermidiate_tensor_;
}
}
......
......@@ -433,7 +433,7 @@ GenerateOpFunctions() {
std::map<std::string, std::string> inplace_map;
// `sum` op has duplicate input. Don't consider adding inplace strategy
// for `sum` in temporary.
if (op_type != "sum" && infer_inplace) {
if (infer_inplace && !special_inplace_op_set.count(op_type)) {
// Inplace OP: op_type_.
// The inplace OP needs a new implementation method.
auto in_to_outs = infer_inplace(true);
......
......@@ -222,6 +222,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"c_reduce", {"Out"}},
{"c_scatter", {"Out"}},
{"barrier", {"Out"}},
{"assign", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
......@@ -249,3 +250,12 @@ std::map<std::string, std::pair<std::string, std::string>> view_op_map = {
{"reshape2", {"X", "Out"}},
{"flatten_contiguous_range", {"X", "Out"}},
};
// NOTE(pangyoki): Special inplace ops that are not supported in temporary.
// The input and output of some inplace ops are special, such as
// duplicate input. These inplace ops have no usage scenarios and
// are not supported in temporary.
std::set<std::string> special_inplace_op_set = {
"sum", // `sum` op has duplicate input
"assign", // output of `assign` op is in `op_passing_outs_map`
};
......@@ -20,7 +20,7 @@ import sys
import paddle
from .. import framework
from ..framework import convert_np_dtype_to_dtype_
from ..framework import convert_np_dtype_to_dtype_, _in_legacy_dygraph
from .. import core
from .. import unique_name
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase
......@@ -798,7 +798,11 @@ def monkey_patch_varbase():
@framework.dygraph_only
def clone(self):
return _C_ops.assign(self)
if _in_legacy_dygraph():
output = core.VarBase()
else:
output = core.eager.Tensor()
return _C_ops.assign(self, output)
@framework.dygraph_only
def value(self):
......
......@@ -606,6 +606,14 @@ def assign(input, output=None):
# isinstance(VarBase, Variable) == False. It will cause return None
# after this api.
if isinstance(input, (Variable, core.VarBase)):
if _non_static_mode():
if output is None:
if _in_legacy_dygraph():
output = core.VarBase()
else:
output = core.eager.Tensor()
_C_ops.assign(input, output)
else:
check_dtype(input.dtype, 'input', [
'float16', 'uint16', 'float32', 'float64', 'int32', 'int64',
'uint8', 'bool'
......@@ -614,7 +622,8 @@ def assign(input, output=None):
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
type='assign', inputs={'X': [input]},
outputs={'Out': [output]})
elif isinstance(input, numpy.ndarray):
# Not support [var, var, ...] currently.
if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
......@@ -663,8 +672,6 @@ def assign(input, output=None):
})
if is_inplace and _non_static_mode():
# TODO(jiabin): Remove this when we support inplace
if _in_legacy_dygraph():
output._bump_inplace_version()
return output
......
......@@ -31,10 +31,6 @@ class TestInplace(unittest.TestCase):
var[0] = 1.1
self.assertEqual(var.inplace_version, 1)
# TODO1: assign don't support inplace in temporary
if in_dygraph_mode():
var[0] = 2
else:
paddle.assign(paddle.ones(shape=[3]), var)
# NOTE(liym27): assign(input, output) is an inplace operation for output.
......@@ -122,7 +118,7 @@ class TestInplace(unittest.TestCase):
loss.backward()
def test_backward_success_2(self):
# TODO2: need to process no_need_buffer in eager mode
# TODO: need to process no_need_buffer in eager mode
# with _test_eager_guard():
# self.func_test_backward_success_2()
self.func_test_backward_success_2()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册