Unverified commit 3271cf54, authored by songyouwei, committed by GitHub

[cherry pick] Fix layer & dygraph circular dependency (#23417)

test=develop test=release/1.7
Parent commit fd8e833b
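What the cherry-pick unwinds, in outline: fluid.layers.nn imported fluid.dygraph (for base and dygraph_utils) while fluid.dygraph.nn imported fluid.layers, so each package could only finish importing if the other already had. The commit breaks the cycle by promoting dygraph_utils to the fluid package root and by checking dygraph mode through fluid.framework, which both packages already depend on. A minimal sketch of the failure mode and the fix pattern, using hypothetical modules pkg/layers.py and pkg/dygraph.py rather than the real Paddle files:

# pkg/layers.py  (hypothetical)
from pkg import dygraph           # layers needs a helper owned by dygraph

def add(x, y):
    return dygraph.helper(x, y)

# pkg/dygraph.py  (hypothetical)
from pkg import layers            # dygraph needs layers too -> cycle:
                                  # whichever module is imported first sees
                                  # a half-initialized sibling and can fail
                                  # with ImportError/AttributeError.

# Fix pattern used by this commit: move the shared helper into a module
# that neither side owns (here, fluid/dygraph_utils.py), so both packages
# can do `from .. import dygraph_utils` without importing each other.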
@@ -17,7 +17,8 @@ from __future__ import print_function
 from six.moves import reduce
 from .. import core
 from ..layers import utils
-from ..dygraph import dygraph_utils
+from ..layers import nn
+from .. import dygraph_utils
 from . import layers
 from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator
 from ..param_attr import ParamAttr
...
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .. import core
-from ..framework import dygraph_only
+from . import core
+from .framework import dygraph_only
 
 @dygraph_only
...
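The import flip from parent-relative to sibling-relative is consistent with dygraph_utils.py moving up one level, out of fluid/dygraph/ and directly into fluid/ (the paths are inferred from the imports, not shown in this view):

# Old location (inferred): python/paddle/fluid/dygraph/dygraph_utils.py
#     from .. import core            # '..' resolves to paddle.fluid
# New location (inferred): python/paddle/fluid/dygraph_utils.py
#     from . import core             # '.' now resolves to paddle.fluid
# The target module paddle.fluid.core is unchanged; only the position of
# dygraph_utils in the package tree moved, which is what lets fluid.layers
# import it without touching fluid.dygraph.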
@@ -32,7 +32,7 @@ from . import tensor
 from ..initializer import init_on_cpu
 from ..framework import default_main_program, Parameter, unique_name, name_scope
 from ..framework import Variable
-from ..dygraph import base as imperative_base
+from ..framework import in_dygraph_mode
 from ..dygraph import learning_rate_scheduler as imperate_lr
 
 __all__ = [
@@ -88,7 +88,7 @@ def noam_decay(d_model, warmup_steps):
              warmup_steps)
     """
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.NoamDecay(d_model, warmup_steps)
             return decay
         else:
@@ -143,7 +143,7 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.ExponentialDecay(learning_rate, decay_steps,
                                                  decay_rate, staircase)
             return decay
@@ -199,7 +199,7 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.NaturalExpDecay(learning_rate, decay_steps,
                                                 decay_rate, staircase)
             return decay
@@ -255,7 +255,7 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
             staircase=True))
     """
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.InverseTimeDecay(learning_rate, decay_steps,
                                                  decay_rate, staircase)
             return decay
@@ -311,7 +311,7 @@ def polynomial_decay(learning_rate,
     """
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.PolynomialDecay(learning_rate, decay_steps,
                                                 end_learning_rate, power, cycle)
             return decay
@@ -380,7 +380,7 @@ def piecewise_decay(boundaries, values):
         if len(values) - len(boundaries) != 1:
             raise ValueError("len(values) - len(boundaries) should be 1")
 
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
             return decay
         else:
@@ -444,7 +444,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
     """
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
                                             epochs)
             return decay
@@ -520,7 +520,7 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
     linear_step = float(end_lr) - float(start_lr)
     with default_main_program()._lr_schedule_guard():
-        if imperative_base.enabled():
+        if in_dygraph_mode():
             lr = imperate_lr.LinearLrWarmup(learning_rate, warmup_steps,
                                             start_lr, end_lr)
             return lr
...
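Every scheduler hunk above makes the same substitution: the dygraph-mode test now comes from fluid.framework, which the layers package already imports, instead of fluid.dygraph.base, so the layers -> dygraph edge disappears. In this release line, base.enabled() appears to be a thin wrapper over the framework check, so behavior should be unchanged. A sketch of the replacement as it would look at a call site:

# Before (kept the layers -> dygraph import alive):
#   from ..dygraph import base as imperative_base
#   if imperative_base.enabled():
#       ...

# After (fluid.framework is a shared, cycle-free dependency):
from paddle.fluid.framework import in_dygraph_mode

if in_dygraph_mode():
    pass  # construct the imperative LR scheduler object here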
@@ -25,8 +25,7 @@ import inspect
 from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant, NumpyArrayInitializer
 from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
-from ..dygraph import base
-from ..dygraph import dygraph_utils
+from .. import dygraph_utils
 from ..param_attr import ParamAttr
 from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
 from .tensor import concat, assign, fill_constant, zeros
@@ -34,7 +33,6 @@ from . import utils
 from .. import unique_name
 from functools import reduce
 from .. import core
-from ..dygraph import layers
 from ..data_feeder import convert_dtype, check_type_and_dtype, check_type, check_dtype
 
 __all__ = [
@@ -10320,9 +10318,6 @@ def _elementwise_op(helper):
     op_type = helper.layer_type
     x = helper.kwargs.get('x', None)
     y = helper.kwargs.get('y', None)
-    if in_dygraph_mode():
-        x = base.to_variable(x)
-        y = base.to_variable(y)
 
     assert x is not None, 'x cannot be None in {}'.format(op_type)
     assert y is not None, 'y cannot be None in {}'.format(op_type)
...
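Deleting the base.to_variable() calls removes the last layers -> dygraph import, and it also means _elementwise_op no longer silently wraps raw inputs in dygraph mode; callers that relied on that conversion should convert explicitly. A sketch under Paddle 1.7-era APIs (the numpy values are invented for illustration):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Convert at the call site instead of relying on _elementwise_op.
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    y = fluid.dygraph.to_variable(np.full([2, 2], 2.0, dtype='float32'))
    out = fluid.layers.elementwise_add(x, y)
    print(out.numpy())  # a 2x2 array of 3.0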