未验证 提交 3271cf54 编写于 作者: S songyouwei 提交者: GitHub

[cherry pick] Fix layer & dygraph circular dependent (#23417)

test=develop test=release/1.7
上级 fd8e833b
......@@ -17,7 +17,8 @@ from __future__ import print_function
from six.moves import reduce
from .. import core
from ..layers import utils
from ..dygraph import dygraph_utils
from ..layers import nn
from .. import dygraph_utils
from . import layers
from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator
from ..param_attr import ParamAttr
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import core
from ..framework import dygraph_only
from . import core
from .framework import dygraph_only
@dygraph_only
......
......@@ -32,7 +32,7 @@ from . import tensor
from ..initializer import init_on_cpu
from ..framework import default_main_program, Parameter, unique_name, name_scope
from ..framework import Variable
from ..dygraph import base as imperative_base
from ..framework import in_dygraph_mode
from ..dygraph import learning_rate_scheduler as imperate_lr
__all__ = [
......@@ -88,7 +88,7 @@ def noam_decay(d_model, warmup_steps):
warmup_steps)
"""
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.NoamDecay(d_model, warmup_steps)
return decay
else:
......@@ -143,7 +143,7 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.ExponentialDecay(learning_rate, decay_steps,
decay_rate, staircase)
return decay
......@@ -199,7 +199,7 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.NaturalExpDecay(learning_rate, decay_steps,
decay_rate, staircase)
return decay
......@@ -255,7 +255,7 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
staircase=True))
"""
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.InverseTimeDecay(learning_rate, decay_steps,
decay_rate, staircase)
return decay
......@@ -311,7 +311,7 @@ def polynomial_decay(learning_rate,
"""
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.PolynomialDecay(learning_rate, decay_steps,
end_learning_rate, power, cycle)
return decay
......@@ -380,7 +380,7 @@ def piecewise_decay(boundaries, values):
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
return decay
else:
......@@ -444,7 +444,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
"""
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
epochs)
return decay
......@@ -520,7 +520,7 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
linear_step = float(end_lr) - float(start_lr)
with default_main_program()._lr_schedule_guard():
if imperative_base.enabled():
if in_dygraph_mode():
lr = imperate_lr.LinearLrWarmup(learning_rate, warmup_steps,
start_lr, end_lr)
return lr
......
......@@ -25,8 +25,7 @@ import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
from ..dygraph import base
from ..dygraph import dygraph_utils
from .. import dygraph_utils
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat, assign, fill_constant, zeros
......@@ -34,7 +33,6 @@ from . import utils
from .. import unique_name
from functools import reduce
from .. import core
from ..dygraph import layers
from ..data_feeder import convert_dtype, check_type_and_dtype, check_type, check_dtype
__all__ = [
......@@ -10320,9 +10318,6 @@ def _elementwise_op(helper):
op_type = helper.layer_type
x = helper.kwargs.get('x', None)
y = helper.kwargs.get('y', None)
if in_dygraph_mode():
x = base.to_variable(x)
y = base.to_variable(y)
assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册