Unverified commit 770395cb, authored by Leo Chen, committed by GitHub

Split train_mode and has_grad for tracer (#29064)

* split train_mode and has_grad

* fix format

* fix ci problems

* fix sample code
Parent cddc7096
@@ -38,11 +38,20 @@ void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
 }
 
 static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
-  for (const auto& name_pair : outs) {
-    for (const auto& vb : name_pair.second) {
-      VLOG(6) << "Set output: " << vb->Name() << "'s OverridedStopGradient as "
+  for (const auto& pair : outs) {
+    for (const auto& var : pair.second) {
+      // NOTE(zhiqiu): this happens when None outputs are passed from the
+      // python side. For example, fake_quantize_dequantize_moving_average_abs_max
+      // may pass a None OutAccum in eval mode.
+      // It can be refined by generating several different pybind interfaces
+      // for one operator with different function signatures.
+      if (var == nullptr) {
+        VLOG(4) << pair.first << " is NULL";
+        continue;
+      }
+      VLOG(6) << "Set output: " << var->Name() << "'s OverridedStopGradient as "
               << generate_grad;
-      vb->InnerSetOverridedStopGradient(generate_grad);
+      var->InnerSetOverridedStopGradient(generate_grad);
     }
   }
 }
...
@@ -1087,7 +1087,7 @@ void BindImperative(py::module *m_ptr) {
                     &imperative::Tracer::SetEnableProgramDescTracing)
       .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
                     &imperative::Tracer::SetEnableAutoCast)
-      .def_property("_train_mode", &imperative::Tracer::HasGrad,
+      .def_property("_has_grad", &imperative::Tracer::HasGrad,
                     &imperative::Tracer::SetHasGrad)
       .def_property(
           "_expected_place",
...
@@ -190,12 +190,12 @@ def disable_dygraph():
 def _switch_tracer_mode_guard_(is_train=True):
     tracer = framework._dygraph_tracer()
     if tracer:
-        mode = tracer._train_mode
-        tracer._train_mode = is_train
+        has_grad = tracer._has_grad
+        tracer._has_grad = is_train
         try:
             yield
         finally:
-            tracer._train_mode = mode
+            tracer._has_grad = has_grad
     else:
         yield
...
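For orientation, here is a minimal dygraph sketch of what this split means in practice. It assumes the renamed `_has_grad` property from this diff and that the tracer still exposes `_train_mode` for op-level train/eval behavior (the dropout change below still reads it); the printed values reflect the intended semantics rather than verified output.

import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    tracer = fluid.framework._dygraph_tracer()
    # By default the tracer records gradients and runs ops in train mode.
    print(tracer._has_grad, tracer._train_mode)      # expected: True True

    with paddle.no_grad():
        # no_grad now only disables gradient recording; the train/eval mode
        # that ops such as dropout consult is left untouched.
        print(tracer._has_grad, tracer._train_mode)  # expected: False True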
@@ -41,7 +41,7 @@ class Tracer(core.Tracer):
     def trace_op(self, type, inputs, outputs, attrs, stop_gradient=False):
         self.trace(type, inputs, outputs, attrs,
-                   framework._current_expected_place(), self._train_mode and
+                   framework._current_expected_place(), self._has_grad and
                    not stop_gradient)
 
     def train_mode(self):
...
@@ -945,7 +945,7 @@ def cos_sim(X, Y):
 @deprecated(since="2.0.0", update_to="paddle.nn.functional.dropout")
 def dropout(x,
             dropout_prob,
-            is_test=False,
+            is_test=None,
             seed=None,
             name=None,
             dropout_implementation="downgrade_in_infer"):
@@ -965,6 +965,7 @@ def dropout(x,
         x (Variable): The input tensor variable. The data type is float16 or float32 or float64.
         dropout_prob (float): Probability of setting units to zero.
         is_test (bool): A flag indicating whether it is in test phase or not.
+                        Default: None. In dynamic graph mode it uses the global tracer's mode; in static graph mode it means False.
         seed (int): A Python integer used to create random seeds. If this
             parameter is set to None, a random seed is used.
             NOTE: If an integer seed is given, always the same output
@@ -996,7 +997,10 @@ def dropout(x,
         .. code-block:: python
 
+            import paddle
             import paddle.fluid as fluid
+
+            paddle.enable_static()
             x = fluid.data(name="data", shape=[None, 32, 32], dtype="float32")
             dropped = fluid.layers.dropout(x, dropout_prob=0.5)
     """
@@ -1017,9 +1021,10 @@ def dropout(x,
         if (seed is None or
                 seed == 0) and default_main_program().random_seed != 0:
             seed = default_main_program().random_seed
-        _is_test = not _dygraph_tracer()._train_mode
+        if is_test is None:
+            is_test = not _dygraph_tracer()._train_mode
         out, mask = core.ops.dropout(
-            x, 'dropout_prob', dropout_prob, 'is_test', _is_test, 'fix_seed',
+            x, 'dropout_prob', dropout_prob, 'is_test', is_test, 'fix_seed',
             seed is not None, 'seed', seed if seed is not None else 0,
             'dropout_implementation', dropout_implementation)
         return out
...
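A short dygraph sketch of the new is_test=None default (values and shapes are illustrative; with the default "downgrade_in_infer" implementation, test-time dropout scales the output by 1 - dropout_prob instead of zeroing units):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 4], dtype='float32'))

    # is_test left at None: the op follows the global tracer's train mode,
    # so roughly half of the units are zeroed out here.
    y_train = fluid.layers.dropout(x, dropout_prob=0.5)

    # Explicit is_test=True requests inference behavior regardless of the
    # tracer mode: every element is scaled by 1 - dropout_prob = 0.5.
    y_eval = fluid.layers.dropout(x, dropout_prob=0.5, is_test=True)

    print(y_train.numpy())
    print(y_eval.numpy())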
@@ -64,7 +64,7 @@ class PrePostProcessLayer(Layer):
             elif cmd == "d":  # add dropout
                 if dropout_rate:
                     self.functors.append(lambda x: layers.dropout(
-                        x, dropout_prob=dropout_rate, is_test=False))
+                        x, dropout_prob=dropout_rate))
 
     def forward(self, x, residual=None):
         for i, cmd in enumerate(self.process_cmd):
@@ -137,8 +137,7 @@ class MultiHeadAttention(Layer):
             product += attn_bias
         weights = layers.softmax(product)
         if self.dropout_rate:
-            weights = layers.dropout(
-                weights, dropout_prob=self.dropout_rate, is_test=False)
+            weights = layers.dropout(weights, dropout_prob=self.dropout_rate)
         out = layers.matmul(weights, v)
         out = layers.transpose(out, perm=[0, 2, 1, 3])
         out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
@@ -156,8 +155,7 @@ class FFN(Layer):
     def forward(self, x):
         hidden = self.fc1(x)
         if self.dropout_rate:
-            hidden = layers.dropout(
-                hidden, dropout_prob=self.dropout_rate, is_test=False)
+            hidden = layers.dropout(hidden, dropout_prob=self.dropout_rate)
         out = self.fc2(hidden)
         return out
@@ -276,8 +274,8 @@ class WrapEncoder(Layer):
         pos_enc.stop_gradient = True
         emb = word_emb + pos_enc
         enc_input = layers.dropout(
-            emb, dropout_prob=self.emb_dropout,
-            is_test=False) if self.emb_dropout else emb
+            emb,
+            dropout_prob=self.emb_dropout, ) if self.emb_dropout else emb
         enc_output = self.encoder(enc_input, src_slf_attn_bias)
         return enc_output
@@ -407,8 +405,8 @@ class WrapDecoder(Layer):
         pos_enc.stop_gradient = True
         emb = word_emb + pos_enc
         dec_input = layers.dropout(
-            emb, dropout_prob=self.emb_dropout,
-            is_test=False) if self.emb_dropout else emb
+            emb,
+            dropout_prob=self.emb_dropout, ) if self.emb_dropout else emb
         dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
                                   trg_src_attn_bias, caches)
         dec_output = layers.reshape(
...
@@ -287,13 +287,14 @@ class TestImperative(unittest.TestCase):
             with paddle.no_grad():
                 self.assertTrue(l1.weight.stop_gradient is False)
                 tmp = l1.weight * 2
-                self.assertTrue(tmp.stop_gradient)
+                print(tmp)
+                self.assertFalse(tmp.stop_gradient)
             x = fluid.dygraph.to_variable(data)
             y = l0(x) + tmp
             o = l1(y)
             o.backward()
-            self.assertTrue(tmp._grad_ivar() is None)
+            self.assertTrue(tmp._grad_ivar() is not None)
             self.assertTrue(l0.weight._grad_ivar() is not None)
 
     def test_sum_op(self):
...
@@ -30,7 +30,7 @@ class TestTracerMode(unittest.TestCase):
     @fluid.dygraph.no_grad
     def no_grad_func(self, a):
-        self.assertEqual(self.tracer._train_mode, False)
+        self.assertEqual(self.tracer._has_grad, False)
         return a
 
     @framework.dygraph_not_support
...