From ff0ab7566294006969719eecd2e4e578aa568be2 Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Mon, 13 Apr 2020 08:51:58 +0800 Subject: [PATCH] polish tensor.where codes and english document (#23687) --- paddle/fluid/operators/where_op.cc | 2 +- paddle/fluid/operators/where_op.cu | 7 -- .../fluid/tests/unittests/test_where_op.py | 89 ++++++++++--------- python/paddle/tensor/search.py | 77 +++++++++------- 4 files changed, 91 insertions(+), 84 deletions(-) diff --git a/paddle/fluid/operators/where_op.cc b/paddle/fluid/operators/where_op.cc index bdb3fb24ded..7b198cf2240 100644 --- a/paddle/fluid/operators/where_op.cc +++ b/paddle/fluid/operators/where_op.cc @@ -102,7 +102,7 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor), The second input tensor of where op. When the " "corresponding position of condition is false, the output takes " "the element of Y."); - AddOutput("Out", "(Tensor), The output tensor of mul op."); + AddOutput("Out", "(Tensor), The output tensor of where op."); AddComment(R"DOC( Where Operator. Return a tensor of elements selected from either $X$ or $Y$, depending on condition. diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu index 0ec1a3c6fa6..daa7c07840f 100644 --- a/paddle/fluid/operators/where_op.cu +++ b/paddle/fluid/operators/where_op.cu @@ -48,9 +48,6 @@ class WhereKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(context.GetPlace()), true, - platform::errors::PermissionDenied("It must use CUDAPlace.")); auto* condition = context.Input("Condition"); auto* X = context.Input("X"); auto* Y = context.Input("Y"); @@ -78,10 +75,6 @@ class WhereGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(context.GetPlace()), true, - platform::errors::PermissionDenied("It must use CUDAPlace.")); - auto* condition = context.Input("Condition"); const bool* cond_data = condition->data(); auto numel = condition->numel(); diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 1ae311a9c46..16971e435ca 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -68,45 +68,53 @@ class TestWhereAPI(unittest.TestCase): x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") cond_i = np.array([False, False, True, True]).astype("bool") - result = tensor.where(x > 1, X=x, Y=y) + result = tensor.where(x > 1, x=x, y=y) - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run(fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result]) - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + out = exe.run(fluid.default_main_program(), + feed={'x': x_i, + 'y': y_i}, + fetch_list=[result]) + assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) def test_grad(self, use_cuda=False): main_program = Program() - for x_stop_gradient, y_stop_gradient in [[False, False], [True, False], - [False, True]]: - with fluid.program_guard(main_program): - x = fluid.layers.data(name='x', shape=[4], dtype='float32') - y = fluid.layers.data(name='y', shape=[4], dtype='float32') + with fluid.program_guard(main_program): + x = fluid.layers.data(name='x', shape=[4], dtype='float32') + y = fluid.layers.data(name='y', shape=[4], dtype='float32') + for x_stop_gradient, y_stop_gradient in [[False, False], + [True, False], + [False, True]]: x.stop_gradient = x_stop_gradient y.stop_gradient = y_stop_gradient x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") cond_i = np.array([False, False, True, True]).astype("bool") - result = tensor.where(x > 1, X=x, Y=y) + result = tensor.where(x > 1, x=x, y=y) x_mean = layers.mean(x) append_backward(x_mean) y_mean = layers.mean(y) append_backward(y_mean) - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run(fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result, x.grad_name, y.grad_name]) - x_grad = [0.25] * 4 - y_grad = [0.25] * 4 - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) - assert np.array_equal(out[1], x_grad) - assert np.array_equal(out[2], y_grad) + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + out = exe.run( + fluid.default_main_program(), + feed={'x': x_i, + 'y': y_i}, + fetch_list=[result, x.grad_name, y.grad_name]) + x_grad = [0.25] * 4 + y_grad = [0.25] * 4 + assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) + assert np.array_equal(out[1], x_grad) + assert np.array_equal(out[2], y_grad) def test_api_broadcast(self, use_cuda=False): main_program = Program() @@ -114,25 +122,22 @@ class TestWhereAPI(unittest.TestCase): x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32') y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32') x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") - y_i = np.array( - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32") + y_i = np.array([[1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0]]).astype("float32") cond_i = np.array([[False, False, True, True], [False, False, True, True]]).astype("bool") - result = tensor.where(x > 1, X=x, Y=y) - - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run(fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result]) - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) - - def test_fw_bw(self): - if core.is_compiled_with_cuda(): - self.test_api(use_cuda=True) - self.test_api_broadcast(use_cuda=True) - self.test_grad(use_cuda=True) + result = tensor.where(x > 1, x=x, y=y) + + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + out = exe.run(fluid.default_main_program(), + feed={'x': x_i, + 'y': y_i}, + fetch_list=[result]) + assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) class TestWhereDygraphAPI(unittest.TestCase): diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index c66effd1d75..a4a1304acfa 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -13,17 +13,8 @@ # limitations under the License. from __future__ import print_function import numpy as np -import warnings -import six -import os -import inspect from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype -from ..fluid.initializer import Normal, Constant, NumpyArrayInitializer -from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program -from ..fluid import dygraph_utils -from ..fluid.param_attr import ParamAttr -from ..fluid import unique_name from ..fluid import core, layers # TODO: define searching & indexing functions of a tensor @@ -224,43 +215,61 @@ def sort(input, axis=-1, descending=False, out=None, name=None): return out, ids -def where(Condition, X, Y): +def where(condition, x, y, name=None): """ - Return a tensor of elements selected from either $X$ or $Y$, depending on $Condition$. + Return a tensor of elements selected from either $x$ or $y$, depending on $condition$. + + .. math:: + + out_i = + \\begin{cases} + x_i, \quad \\text{if} \\ condition_i \\ is \\ True \\\\ + y_i, \quad \\text{if} \\ condition_i \\ is \\ False \\\\ + \\end{cases} + + Args: - Condition(Variable): A bool tensor with rank at least 1, the data type is bool. - X(Variable): X is a Tensor Variable. - Y(Variable): Y is a Tensor Variable. + condition(Variable): The condition to choose x or y. + x(Variable): x is a Tensor Variable with data type float32, float64, int32, int64. + y(Variable): y is a Tensor Variable with data type float32, float64, int32, int64. + + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + Returns: - out : The tensor. + Variable: A Tensor with the same data dype as x. + Examples: .. code-block:: python import numpy as np - import paddle as paddle import paddle.fluid as fluid + import paddle.tensor as paddle + + x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") + y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") with fluid.dygraph.guard(): - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64") x = fluid.dygraph.to_variable(x_i) y = fluid.dygraph.to_variable(y_i) out = paddle.where(x>1, x, y) - print(out.numpy()) - #out: [1.0, 1.0, 3.2, 1.2] + + print(out.numpy()) + #out: [1.0, 1.0, 3.2, 1.2] """ if not in_dygraph_mode(): - check_variable_and_dtype(Condition, 'Condition', ['bool'], 'where') + check_variable_and_dtype(condition, 'condition', ['bool'], 'where') check_variable_and_dtype( - X, 'X', ['float32', 'float64', 'int32', 'int64'], 'where') + x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where') check_variable_and_dtype( - Y, 'Y', ['float32', 'float64', 'int32', 'int64'], 'where') + y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where') - X_shape = list(X.shape) - Y_shape = list(Y.shape) - if X_shape == Y_shape: + x_shape = list(x.shape) + y_shape = list(y.shape) + if x_shape == y_shape: if in_dygraph_mode(): - return core.ops.where(Condition, X, Y) + return core.ops.where(condition, x, y) else: helper = LayerHelper("where", **locals()) dtype = helper.input_dtype() @@ -268,16 +277,16 @@ def where(Condition, X, Y): helper.append_op( type='where', - inputs={'Condition': Condition, - 'X': X, - 'Y': Y}, + inputs={'Condition': condition, + 'X': x, + 'Y': y}, outputs={'Out': [out]}) return out else: - cond_int = layers.cast(Condition, X.dtype) - cond_not_int = layers.cast(layers.logical_not(Condition), X.dtype) - out1 = layers.elementwise_mul(X, cond_int) - out2 = layers.elementwise_mul(Y, cond_not_int) + cond_int = layers.cast(condition, x.dtype) + cond_not_int = layers.cast(layers.logical_not(condition), x.dtype) + out1 = layers.elementwise_mul(x, cond_int) + out2 = layers.elementwise_mul(y, cond_not_int) out = layers.elementwise_add(out1, out2) return out -- GitLab