Unverified · Commit 38d233d9, authored by Aurelius84, committed by GitHub

[Framework]Support deliver stop_gradient in static mode (#49013)

* [framework]support pass stop_gradient in static mode

* fix control_flow op stop_gradient
Parent: 03f72de4
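In short, after this change static-mode `append_op` propagates `stop_gradient` from inputs to outputs: an output is marked `stop_gradient=True` only when every input already stops gradients. A minimal sketch of the new behaviour, mirroring the unit test added in this commit (the `print` calls are illustrative only):

import paddle

paddle.enable_static()

x = paddle.randn([2, 4])
x.stop_gradient = True       # no gradient requested for x
y = paddle.randn([2, 4])
y.stop_gradient = False      # gradient requested for y

out1 = x.reshape([4, -1])    # every input stops gradient -> output does too
out2 = x + y                 # at least one input needs gradient -> flag stays False

print(out1.stop_gradient)    # True  (propagated by append_op after this change)
print(out2.stop_gradient)    # False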
@@ -758,7 +758,8 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
     }
     PADDLE_ENFORCE_EQ(
-        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)),
+        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs),
+                        /*allow_null=*/true),
         true,
         platform::errors::InvalidArgument(
             "The output of(%s) should not be empty.",
@@ -64,7 +64,9 @@ class SelectOutputOp : public framework::OperatorBase {
     const framework::Variable *x = scope.FindVar(Input("X"));
     framework::Variable *selected_out = scope.FindVar(out_names[output_branch]);
-    framework::VisitVarType(*x, AssignFunctor(selected_out, dev_ctx));
+    if (nullptr != selected_out) {
+      framework::VisitVarType(*x, AssignFunctor(selected_out, dev_ctx));
+    }
   }
 };
@@ -95,8 +97,10 @@ class SelectOutputInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
     OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
-    OP_INOUT_CHECK(
-        context->HasOutputs("Out", true), "Output", "Out", "SelectOutput");
+    OP_INOUT_CHECK(context->HasOutputs("Out", /*allow_null=*/true),
+                   "Output",
+                   "Out",
+                   "SelectOutput");
   }
 };
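These C++ hunks relax the output checks because, once stop_gradient propagates in static mode, some gradient slots of a control-flow op may legitimately never be created. A hedged Python sketch of the situation this tolerates (the variables x, y and the cond below are illustrative, not part of the commit, and this is a sketch rather than a guaranteed repro):

import paddle

paddle.enable_static()

x = paddle.full([1], 2.0)
x.stop_gradient = False      # gradient wanted for x
y = paddle.full([1], 5.0)
y.stop_gradient = True       # no gradient for y, so its grad var is never created

out = paddle.static.nn.cond(x < 3.0, lambda: x * y, lambda: x + y)
paddle.static.append_backward(paddle.mean(out))
# The control-flow grad ops must tolerate the missing gradient slot for y,
# hence HasOutputs(..., /*allow_null=*/true) and the nullptr guard above.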
@@ -4074,6 +4074,9 @@ class Block:
             param = EagerParamBase(*args, **kwargs)
         else:
             param = Parameter(global_block, *args, **kwargs)
+        # NOTE(Aurelius84): we deliver stop_gradient in append_op, so we
+        # need to record its state here and reset it back after calling this API.
+        stop_gradient = param.stop_gradient
         if 'initializer' in kwargs:
@@ -4109,6 +4112,7 @@ class Block:
                 pass
             else:
                 initializer(param, self)
+        param.stop_gradient = stop_gradient
         return param

     def append_op(self, *args, **kwargs):
@@ -4118,10 +4122,10 @@ class Block:
         Returns:
             Operator: the append Operator.
         """
+        op_type = kwargs.get("type", None)
         if _non_static_mode():
             attrs = kwargs.get("attrs", {})
             inplace_map = kwargs.get("inplace_map", None)
-            type = kwargs.get("type", None)
             warnings.warn(
                 "Op `%s` is executed through `append_op` under the dynamic mode, "
                 "the corresponding API implementation needs to be upgraded to "
@@ -4131,7 +4135,7 @@ class Block:
             op = Operator(
                 block=self,
                 desc=None,
-                type=type,
+                type=op_type,
                 inputs=None,
                 outputs=None,
                 attrs=attrs,
@@ -4143,7 +4147,7 @@ class Block:
             # currently, we only support stop_gradient in dygraph mode.
             _dygraph_tracer().trace_op(
-                type,
+                op_type,
                 kwargs.get("inputs", {}),
                 kwargs.get("outputs", {}),
                 attrs if attrs else {},
@@ -4152,18 +4156,43 @@ class Block:
             )
         else:
             from paddle.fluid.dygraph.base import param_guard
+            from .layers.utils import flatten
+
+            def pass_stop_gradient(ins, outs):
+                """
+                Set out.stop_gradient = True if all inputs stop_gradient is True.
+                """
+                need_reset = True
+                for var in flatten(ins):
+                    if getattr(var, 'stop_gradient', None) is False:
+                        need_reset = False
+                        break
+                if need_reset:
+                    for var in flatten(outs):
+                        if isinstance(var, Variable):
+                            var.stop_gradient = True

             op_desc = self.desc.append_op()
+            inputs = kwargs.get("inputs", None)
+            outputs = kwargs.get("outputs", None)
             # NOTE(Aurelius84): In case of @to_static, all VarBase(s) should
             # be converted into Variable(s) with same name and block location.
             # This is ONE and ONLY logic of type transformation of dy2static.
-            inputs = kwargs.get("inputs", None)
-            outputs = kwargs.get("outputs", None)
+            ignore_ops = {
+                'conditional_block',
+                'conditional_block_grad',
+                'recurrent',
+                'recurrent_grad',
+                'while',
+                'while_grad',
+            }
+            if op_type not in ignore_ops:
+                pass_stop_gradient(inputs, outputs)
             with param_guard(inputs), param_guard(outputs):
                 op = Operator(
                     block=self,
                     desc=op_desc,
-                    type=kwargs.get("type", None),
+                    type=op_type,
                     inputs=inputs,
                     outputs=outputs,
                     attrs=kwargs.get("attrs", None),
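For reference, the rule added above boils down to: an output keeps requiring gradients as long as at least one input does, except for the listed control-flow ops whose gradient wiring is handled separately. A dependency-free sketch of the same rule (FakeVar stands in for framework.Variable and is illustrative only):

# Minimal standalone sketch of the stop_gradient propagation rule.
class FakeVar:
    def __init__(self, stop_gradient=True):
        self.stop_gradient = stop_gradient


def pass_stop_gradient(ins, outs):
    """Set out.stop_gradient = True only if every input stops gradient."""
    need_reset = all(
        getattr(var, 'stop_gradient', None) is not False for var in ins
    )
    if need_reset:
        for var in outs:
            var.stop_gradient = True


a, b = FakeVar(True), FakeVar(True)
c = FakeVar(False)

out1 = FakeVar(False)
pass_stop_gradient([a, b], [out1])
assert out1.stop_gradient            # all inputs stop gradient -> output does too

out2 = FakeVar(False)
pass_stop_gradient([a, c], [out2])
assert not out2.stop_gradient        # c still needs a gradient -> flag untouched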
@@ -33,6 +33,7 @@ class FusionGroupPassTest(PassTest):
             # subgraph with only 1 op node
             tmp_0 = self.feed_vars[0] * self.feed_vars[1]
+            tmp_0.stop_gradient = False
             tmp_1 = paddle.matmul(tmp_0, self.feed_vars[2])
             # subgraph with 2 op nodes
             tmp_2 = paddle.nn.functional.relu(tmp_0 + tmp_1)
@@ -48,10 +49,11 @@ class FusionGroupPassTest(PassTest):
         self.pass_names = "fusion_group_pass"
         self.fused_op_type = "fusion_group"

-    def _prepare_feed_vars(self, shape, dtype, num_data):
+    def _prepare_feed_vars(self, shape, dtype, num_data, stop_gradient=True):
         feed_vars = []
         for i in range(num_data):
             var = fluid.data(name=("data" + str(i)), shape=shape, dtype=dtype)
+            var.stop_gradient = stop_gradient
             feed_vars.append(var)
         return feed_vars
@@ -82,7 +84,7 @@ class FusionGroupPassTest(PassTest):
 class FusionGroupPassComplicatedTest(FusionGroupPassTest):
     def build_program(self, dtype):
         with fluid.program_guard(self.main_program, self.startup_program):
-            self.feed_vars = self._prepare_feed_vars([32, 64], dtype, 5)
+            self.feed_vars = self._prepare_feed_vars([32, 64], dtype, 5, False)

             one = layers.fill_constant(shape=[1], dtype=dtype, value=1.0)
             tmp_0 = one * self.feed_vars[0]
@@ -138,6 +140,7 @@ class FusionGroupPassTestCastAndFP16(FusionGroupPassTest):
             # subgraph with 2 op nodes
             tmp_0 = self.feed_vars[0] * self.feed_vars[1]
+            tmp_0.stop_gradient = False
             tmp_1 = paddle.cast(tmp_0, dtype="float16")
             zero = layers.fill_constant(shape=[128], dtype="float16", value=0)
             # TODO(xreki): fix precision problem when using softmax of float16.
@@ -168,6 +171,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
             tmp_0 = paddle.add_n(
                 [self.feed_vars[0], self.feed_vars[1], self.feed_vars[2]]
             )
+            tmp_0.stop_gradient = False
             tmp_1 = paddle.sqrt(tmp_0)
             tmp_2 = paddle.matmul(tmp_0, self.feed_vars[3])
             # subgraph with 2 op nodes
@@ -185,6 +189,7 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
             self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)

             tmp_0 = paddle.add(self.feed_vars[0], self.feed_vars[1])
+            tmp_0.stop_gradient = False
             tmp_1 = paddle.cast(tmp_0, dtype="float64")
             tmp_2 = paddle.cast(tmp_1, dtype="float32")
@@ -206,6 +211,7 @@ class FusionGroupPassFillConstantTest(FusionGroupPassTest):
             self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)

             tmp_0 = paddle.add(self.feed_vars[0], self.feed_vars[1])
+            tmp_0.stop_gradient = False
             tmp_1 = layers.fill_constant(shape=[2, 2], dtype=dtype, value=2.0)
             tmp_2 = paddle.scale(
                 tmp_1, scale=3.0, bias=1.0, bias_after_scale=True
@@ -464,6 +464,7 @@ class TestCondNestedControlFlow(unittest.TestCase):
         startup_program = Program()
         with program_guard(main_program, startup_program):
             i = fluid.data(name="i", shape=[1], dtype='float32')
+            i.stop_gradient = False
             a = 2.0 * i
             out = paddle.static.nn.cond(
                 i < 5.0,
@@ -23,6 +23,7 @@ from paddle.fluid import core
 SEED = 1
 DTYPE = "float32"
 paddle.dataset.mnist.fetch()
+paddle.enable_static()

 # random seed must set before configuring the network.
@@ -207,7 +208,7 @@ class TestCloneWithStopGradient(unittest.TestCase):
             test_program.block(0).var(hidden1.name).stop_gradient, True
         )
         self.assertEqual(
-            test_program.block(0).var(hidden2.name).stop_gradient, False
+            test_program.block(0).var(hidden2.name).stop_gradient, True
         )
@@ -120,7 +120,7 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase):
         gc_vars = core._get_eager_deletion_vars(
             fluid.default_main_program().desc, [loss.name]
         )
-        self.assertEqual(len(gc_vars), 5)
+        self.assertEqual(len(gc_vars), 3)

         exe = Executor(self.place)
         exe.run(fluid.default_startup_program())
@@ -81,7 +81,9 @@ class TestRecurrentFeed(unittest.TestCase):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
             in1 = paddle.static.data(name="inp1", shape=[2, 2])
+            in1.stop_gradient = False
             in2 = paddle.static.data(name="inp2", shape=[2, 2])
+            in2.stop_gradient = False
             rt1 = RecurrentTest("RecurrentTest")
             static_sum_out, static_out = rt1(in1, in2)
             fluid.backward.append_backward(static_sum_out)
@@ -349,6 +349,7 @@ class TestConstantPadDoubleGradCheck(unittest.TestCase):
         x = paddle.static.data('x', x_shape, dtype)
         x.persistable = True
+        x.stop_gradient = False
         out = paddle.nn.functional.pad(x, pad)
         x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
@@ -1167,6 +1167,7 @@ class TestRecomputeOptimizer(unittest.TestCase):
         prediction = paddle.static.nn.fc(
             x=[drop_res], size=2, activation='softmax'
         )
+        drop_res.stop_gradient = False
         cost = paddle.nn.functional.cross_entropy(
             input=prediction,
             label=input_y,
@@ -1231,6 +1232,7 @@ class TestRecomputeOptimizerCUDA(unittest.TestCase):
         prediction = paddle.static.nn.fc(
             x=[drop_res], size=2, activation='softmax'
         )
+        drop_res.stop_gradient = False
         cost = paddle.nn.functional.cross_entropy(
             input=prediction,
             label=input_y,
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle


class TestStopGradient(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()

    def tearDown(self):
        paddle.disable_static()

    def create_var(self, stop_gradient):
        x = paddle.randn([2, 4])
        x.stop_gradient = stop_gradient
        return x

    def test_unary(self):
        x = self.create_var(True)
        out = x.reshape([4, -1])
        self.assertTrue(out.stop_gradient)

    def test_binary(self):
        x = self.create_var(True)
        y = self.create_var(True)
        out = x + y
        self.assertTrue(out.stop_gradient)

    def test_binary2(self):
        x = self.create_var(True)
        y = self.create_var(False)
        out = x + y
        self.assertFalse(out.stop_gradient)


if __name__ == '__main__':
    unittest.main()
@@ -996,6 +996,8 @@ class TestBackward(unittest.TestCase):
         with paddle.static.program_guard(main_program, startup_program):
             x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
             y = paddle.static.data(name="y", shape=[4, 4], dtype='float32')
+            x.stop_gradient = False
+            y.stop_gradient = False
             label = paddle.static.data(
                 name="label", shape=[4, 1], dtype='int64'
@@ -139,6 +139,7 @@ class TestWhereAPI(unittest.TestCase):
             y.stop_gradient = y_stop_gradient
             y.desc.set_need_check_feed(False)
             result = paddle.where(cond, x, y)
+            result.stop_gradient = False
             append_backward(paddle.mean(result))

             for use_cuda in [False, True]:
                 if use_cuda and (
@@ -353,6 +353,7 @@ class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
             init = layers.zeros(shape=[10], dtype='float32')
             mem_array = paddle.tensor.array_write(x=init, i=i)
             data_array = paddle.tensor.array_write(x=d0, i=i)
+            mem_array.stop_gradient = False
             i = paddle.increment(i)
             paddle.tensor.array_write(d1, i, array=data_array)
             i = paddle.increment(i)
@@ -3063,8 +3063,8 @@ class TestSundryAPIStatic(unittest.TestCase):
     @prog_scope()
     def test_unsqueeze(self):
         x1 = paddle.full([], 2)
-        out1 = paddle.unsqueeze(x1, axis=0)
         x1.stop_gradient = False
+        out1 = paddle.unsqueeze(x1, axis=0)
         paddle.static.append_backward(out1.sum())

         x2 = paddle.full([], 3)