From df7c29e5165847f9958903f8c492f566b3df63fc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Feb 2018 20:54:50 +0800 Subject: [PATCH] override comparison operators in Python for Variable --- python/paddle/v2/fluid/layers/math_op_patch.py | 6 +++++- python/paddle/v2/fluid/learning_rate_decay.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 9b5f22759cf..4cf995ec855 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -151,7 +151,11 @@ def monkey_patch_variable(): ("__div__", "elementwise_div", False), ("__rdiv__", "elementwise_div", True), ("__pow__", "elementwise_pow", False), - ("__rpow__", "elementwise_pow", True)): + ("__rpow__", "elementwise_pow", True), + # for logical compare + ("__eq__", "equal", False), + ("__lt__", "less_then", False), + ("__le__", "less_equal", False), ): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/learning_rate_decay.py b/python/paddle/v2/fluid/learning_rate_decay.py index 2a2a29fd9cb..0826d3da79a 100644 --- a/python/paddle/v2/fluid/learning_rate_decay.py +++ b/python/paddle/v2/fluid/learning_rate_decay.py @@ -179,7 +179,7 @@ def polynomial_decay(learning_rate, shape=[1], dtype='float32', value=1.0) with layers.Switch() as switch: - with switch.case(layers.equal(x=global_step, y=zero_var)): + with switch.case(global_step == zero_var): layers.assign(input=one_var, output=div_res) decay_steps = decay_steps * div_res else: @@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values): shape=[1], dtype='float32', value=float(boundaries[i])) value_var = layers.fill_constant( shape=[1], dtype='float32', value=float(values[i])) - with switch.case(layers.less_than(global_step, boundary_val)): + with switch.case(global_step < boundary_val): layers.assign(value_var, lr) last_value_var = layers.fill_constant( shape=[1], -- GitLab