未验证 提交 e9d30991 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #8357 from jacquesqiao/override-compare-op-in-python

override comparison operators in Python for Variable
...@@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); ...@@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
...@@ -17,3 +17,4 @@ limitations under the License. */ ...@@ -17,3 +17,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
...@@ -48,6 +48,14 @@ struct EqualFunctor { ...@@ -48,6 +48,14 @@ struct EqualFunctor {
} }
}; };
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return !EqualFunctor<T>()(a, b);
}
};
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class CompareOpKernel class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
......
...@@ -152,7 +152,12 @@ def monkey_patch_variable(): ...@@ -152,7 +152,12 @@ def monkey_patch_variable():
("__div__", "elementwise_div", False), ("__div__", "elementwise_div", False),
("__rdiv__", "elementwise_div", True), ("__rdiv__", "elementwise_div", True),
("__pow__", "elementwise_pow", False), ("__pow__", "elementwise_pow", False),
("__rpow__", "elementwise_pow", True)): ("__rpow__", "elementwise_pow", True),
# for logical compare
("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False),
("__le__", "less_equal", False)):
setattr(Variable, method_name, setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse)) _elemwise_method_creator_(method_name, op_type, reverse))
......
...@@ -179,7 +179,7 @@ def polynomial_decay(learning_rate, ...@@ -179,7 +179,7 @@ def polynomial_decay(learning_rate,
shape=[1], dtype='float32', value=1.0) shape=[1], dtype='float32', value=1.0)
with layers.Switch() as switch: with layers.Switch() as switch:
with switch.case(layers.equal(x=global_step, y=zero_var)): with switch.case(global_step == zero_var):
layers.assign(input=one_var, output=div_res) layers.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res decay_steps = decay_steps * div_res
else: else:
...@@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values): ...@@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values):
shape=[1], dtype='float32', value=float(boundaries[i])) shape=[1], dtype='float32', value=float(boundaries[i]))
value_var = layers.fill_constant( value_var = layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i])) shape=[1], dtype='float32', value=float(values[i]))
with switch.case(layers.less_than(global_step, boundary_val)): with switch.case(global_step < boundary_val):
layers.assign(value_var, lr) layers.assign(value_var, lr)
last_value_var = layers.fill_constant( last_value_var = layers.fill_constant(
shape=[1], shape=[1],
......
...@@ -161,8 +161,8 @@ class TestBook(unittest.TestCase): ...@@ -161,8 +161,8 @@ class TestBook(unittest.TestCase):
label=label, label=label,
chunk_scheme="IOB", chunk_scheme="IOB",
num_chunk_types=(label_dict_len - 1) / 2) num_chunk_types=(label_dict_len - 1) / 2)
self.assertNotEqual(crf, None) self.assertFalse(crf is None)
self.assertNotEqual(crf_decode, None) self.assertFalse(crf_decode is None)
print(str(program)) print(str(program))
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid as fluid
class TestPythonOperatorOverride(unittest.TestCase):
def check_result(self, fn, place, dtype):
shape = [9, 10]
x_data = np.random.random(size=shape).astype(dtype)
y_data = np.random.random(size=shape).astype(dtype)
python_out = fn(x_data, y_data)
x_var = layers.create_global_var(
name='x', shape=shape, value=0.0, dtype=dtype, persistable=True)
y_var = layers.create_global_var(
name='y', shape=shape, value=0.0, dtype=dtype, persistable=True)
out = fn(x_var, y_var)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid_out = exe.run(fluid.default_main_program(),
feed={'x': x_data,
'y': y_data},
fetch_list=[out])
np.testing.assert_array_equal(python_out, fluid_out[0])
def test_override(self):
# compare func to check
compare_fns = [
lambda _a, _b: _a == _b,
lambda _a, _b: _a != _b,
lambda _a, _b: _a < _b,
lambda _a, _b: _a <= _b,
lambda _a, _b: _a > _b,
lambda _a, _b: _a >= _b,
]
# places to check
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
# dtypes to check
dtypes = ['int32', 'float32']
for place in places:
for dtype in dtypes:
for compare_fn in compare_fns:
with framework.program_guard(framework.Program(),
framework.Program()):
self.check_result(compare_fn, place, dtype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册