提交 df7c29e5 编写于 作者: Q qiaolongfei

override comparison operators in Python for Variable

上级 87b8c620
......@@ -151,7 +151,11 @@ def monkey_patch_variable():
("__div__", "elementwise_div", False),
("__rdiv__", "elementwise_div", True),
("__pow__", "elementwise_pow", False),
("__rpow__", "elementwise_pow", True)):
("__rpow__", "elementwise_pow", True),
# for logical compare
("__eq__", "equal", False),
("__lt__", "less_then", False),
("__le__", "less_equal", False), ):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))
......
......@@ -179,7 +179,7 @@ def polynomial_decay(learning_rate,
shape=[1], dtype='float32', value=1.0)
with layers.Switch() as switch:
with switch.case(layers.equal(x=global_step, y=zero_var)):
with switch.case(global_step == zero_var):
layers.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
......@@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values):
shape=[1], dtype='float32', value=float(boundaries[i]))
value_var = layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(layers.less_than(global_step, boundary_val)):
with switch.case(global_step < boundary_val):
layers.assign(value_var, lr)
last_value_var = layers.fill_constant(
shape=[1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册