diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu index daa7c07840fc2faf8d7d7cbc913f0e57bf91890c..39983e7de03b5404e20ba36378cbc552337a268e 100644 --- a/paddle/fluid/operators/where_op.cu +++ b/paddle/fluid/operators/where_op.cu @@ -30,15 +30,15 @@ __global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x, } template -__global__ void WhereGradCUDAKernel(const int N, const T* out, const bool* cond, - T* x, T* y) { +__global__ void WhereGradCUDAKernel(const int N, const T* dout, + const bool* cond, T* dx, T* dy) { int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < N; idx += blockDim.x * gridDim.x) { - if (x != nullptr) { - x[idx] = out[idx] * (cond[idx] ? 1. : 0.); + if (dx != nullptr) { + dx[idx] = cond[idx] ? dout[idx] : 0.; } - if (y != nullptr) { - y[idx] = out[idx] * (cond[idx] ? 0. : 1.); + if (dy != nullptr) { + dy[idx] = cond[idx] ? 0. : dout[idx]; } } } diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9186288aab616f591ec9aa4e06d97dcd7e2d4a43..05bebd759d39379f23071d6315234c7ce44e6a8e 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -191,7 +191,7 @@ from .tensor.search import argmax #DEFINE_ALIAS # from .tensor.search import has_nan #DEFINE_ALIAS # from .tensor.search import masked_select #DEFINE_ALIAS # from .tensor.search import topk #DEFINE_ALIAS -# from .tensor.search import where #DEFINE_ALIAS +from .tensor.search import where #DEFINE_ALIAS from .tensor.search import index_select #DEFINE_ALIAS from .tensor.search import index_sample #DEFINE_ALIAS from .tensor.search import nonzero #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 16971e435ca43d5e3425b6c0668888ef8262bc08..5eaf140461bce8442c0f11b4e6ae007f7907549a 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -16,9 +16,9 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers -import paddle.tensor as tensor import paddle.fluid.core as core from op_test import OpTest from paddle.fluid import compiler, Program, program_guard @@ -60,61 +60,64 @@ class TestWhereOp3(TestWhereOp): class TestWhereAPI(unittest.TestCase): - def test_api(self, use_cuda=False): - main_program = Program() - with fluid.program_guard(main_program): - x = fluid.layers.data(name='x', shape=[4], dtype='float32') - y = fluid.layers.data(name='y', shape=[4], dtype='float32') - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") - cond_i = np.array([False, False, True, True]).astype("bool") - result = tensor.where(x > 1, x=x, y=y) + def setUp(self): + self.init_data() - for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run(fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result]) - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) + def init_data(self): + self.shape = [10, 15] + self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool) + self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.out = np.where(self.cond, self.x, self.y) - def test_grad(self, use_cuda=False): - main_program = Program() - with fluid.program_guard(main_program): - x = fluid.layers.data(name='x', shape=[4], dtype='float32') - y = fluid.layers.data(name='y', shape=[4], dtype='float32') - for x_stop_gradient, y_stop_gradient in [[False, False], - [True, False], - [False, True]]: - x.stop_gradient = x_stop_gradient - y.stop_gradient = y_stop_gradient - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") - cond_i = np.array([False, False, True, True]).astype("bool") - result = tensor.where(x > 1, x=x, y=y) - x_mean = layers.mean(x) - append_backward(x_mean) - y_mean = layers.mean(y) - append_backward(y_mean) - - for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run( - fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result, x.grad_name, y.grad_name]) - x_grad = [0.25] * 4 - y_grad = [0.25] * 4 - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) - assert np.array_equal(out[1], x_grad) - assert np.array_equal(out[2], y_grad) + def ref_x_backward(self, dout): + return np.where(self.cond == True, dout, 0) + + def ref_y_backward(self, dout): + return np.where(self.cond == False, dout, 0) + + def test_api(self, use_cuda=False): + for x_stop_gradient in [False, True]: + for y_stop_gradient in [False, True]: + with fluid.program_guard(Program(), Program()): + cond = fluid.layers.data( + name='cond', shape=self.shape, dtype='bool') + x = fluid.layers.data( + name='x', shape=self.shape, dtype='float32') + y = fluid.layers.data( + name='y', shape=self.shape, dtype='float32') + x.stop_gradient = x_stop_gradient + y.stop_gradient = y_stop_gradient + result = paddle.where(cond, x, y) + append_backward(layers.mean(result)) + + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + break + place = fluid.CUDAPlace( + 0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + fetch_list = [result, result.grad_name] + if x_stop_gradient is False: + fetch_list.append(x.grad_name) + if y_stop_gradient is False: + fetch_list.append(y.grad_name) + out = exe.run( + fluid.default_main_program(), + feed={'cond': self.cond, + 'x': self.x, + 'y': self.y}, + fetch_list=fetch_list) + assert np.array_equal(out[0], self.out) + if x_stop_gradient is False: + assert np.array_equal(out[2], + self.ref_x_backward(out[1])) + if y.stop_gradient is False: + assert np.array_equal( + out[3], self.ref_y_backward(out[1])) + elif y.stop_gradient is False: + assert np.array_equal(out[2], + self.ref_y_backward(out[1])) def test_api_broadcast(self, use_cuda=False): main_program = Program() @@ -124,9 +127,7 @@ class TestWhereAPI(unittest.TestCase): x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32") - cond_i = np.array([[False, False, True, True], - [False, False, True, True]]).astype("bool") - result = tensor.where(x > 1, x=x, y=y) + result = paddle.where(x > 1, x=x, y=y) for use_cuda in [False, True]: if use_cuda and not fluid.core.is_compiled_with_cuda(): @@ -137,7 +138,7 @@ class TestWhereAPI(unittest.TestCase): feed={'x': x_i, 'y': y_i}, fetch_list=[result]) - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) + assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i)) class TestWhereDygraphAPI(unittest.TestCase): @@ -149,7 +150,7 @@ class TestWhereDygraphAPI(unittest.TestCase): x = fluid.dygraph.to_variable(x_i) y = fluid.dygraph.to_variable(y_i) cond = fluid.dygraph.to_variable(cond_i) - out = tensor.where(cond, x, y) + out = paddle.where(cond, x, y) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) @@ -161,7 +162,7 @@ class TestWhereOpError(unittest.TestCase): cond_i = np.array([False, False, True, True]).astype("bool") def test_Variable(): - tensor.where(cond_i, x_i, y_i) + paddle.where(cond_i, x_i, y_i) self.assertRaises(TypeError, test_Variable) @@ -169,7 +170,7 @@ class TestWhereOpError(unittest.TestCase): x = fluid.layers.data(name='x', shape=[4], dtype='bool') y = fluid.layers.data(name='y', shape=[4], dtype='float16') cond = fluid.layers.data(name='cond', shape=[4], dtype='int32') - tensor.where(cond, x, y) + paddle.where(cond, x, y) self.assertRaises(TypeError, test_type) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index c415c9189a5609103b3bcaf4b8acb589ac79e8c5..de02c412b9f5cbc2de359a52c3978ff2bea6a449 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -388,9 +388,9 @@ def where(condition, x, y, name=None): Examples: .. code-block:: python + import paddle import numpy as np import paddle.fluid as fluid - import paddle.tensor as paddle x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") @@ -417,8 +417,7 @@ def where(condition, x, y, name=None): return core.ops.where(condition, x, y) else: helper = LayerHelper("where", **locals()) - dtype = helper.input_dtype() - out = helper.create_variable_for_type_inference(dtype) + out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type='where',