From b556b0f107258a2e2cefb9352def1600e9c7ca75 Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Wed, 22 Apr 2020 11:24:25 +0800 Subject: [PATCH] [Cherry-Pick] [2.0-beta] add paddle.where interface and error enhancement (#23972) --- paddle/fluid/operators/concat_op.h | 30 +++-- paddle/fluid/operators/lod_reset_op.cc | 6 +- paddle/fluid/operators/lod_reset_op.h | 29 ++--- paddle/fluid/operators/where_op.cu | 12 +- python/paddle/__init__.py | 2 +- python/paddle/fluid/layers/nn.py | 31 +++-- .../fluid/tests/unittests/test_layers.py | 2 +- .../tests/unittests/test_lod_append_op.py | 80 ++++++++++++ .../tests/unittests/test_lod_reset_op.py | 37 +++--- .../fluid/tests/unittests/test_where_op.py | 123 +++++++++--------- python/paddle/tensor/search.py | 5 +- 11 files changed, 221 insertions(+), 136 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_lod_append_op.py diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 6f4b9147c66..eceb68815e7 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -47,13 +47,13 @@ static inline framework::DDim ComputeAndCheckShape( is_runtime || (out_dims[j] > 0 && inputs_dims[i][j] > 0); if (check_shape) { // check all shape in run time - PADDLE_ENFORCE_EQ( - inputs_dims[0][j], inputs_dims[i][j], - platform::errors::InvalidArgument( - "The shape of input[%d] must be equal to input[0]. " - "But received input[0]'s shape = " - "[%s], input[%d]'s shape = [%s].", - i, inputs_dims[0], i, inputs_dims[i])); + PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j], + platform::errors::InvalidArgument( + "The %d-th dimension of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + j, i, inputs_dims[0], i, inputs_dims[i])); } } } @@ -79,9 +79,9 @@ class ConcatKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput("X"); framework::LoDTensor* out = ctx.Output("Out"); - PADDLE_ENFORCE_NOT_NULL( - ins[0], platform::errors::NotFound( - " The first input of concat should not be null.")); + PADDLE_ENFORCE_NOT_NULL(ins[0], + platform::errors::NotFound( + "The first input tensor is not initalized.")); auto axis = ctx.Attr("axis"); bool need_resize_out_dims = false; if (ctx.HasInput("AxisTensor")) { @@ -116,7 +116,9 @@ class ConcatKernel : public framework::OpKernel { platform::errors::Unimplemented( "The lod level of all input LoDTensors should be same. " "Maybe different lod level of input LoDTensors can concat," - " it is not supported currently.")); + "it is not supported currently. The lod level of %dth input " + "is %d and first input is %d.", + i, ins[i]->lod().size(), lod_size_0)); } else { lod_size = 0; break; @@ -181,9 +183,9 @@ class ConcatGradKernel : public framework::OpKernel { } } } - PADDLE_ENFORCE_NOT_NULL( - ins[0], platform::errors::NotFound( - "The first input of concat should not be null.")); + PADDLE_ENFORCE_NOT_NULL(ins[0], + platform::errors::NotFound( + "The first input tensor is not initalized.")); auto axis = ctx.Attr("axis"); if (ctx.HasInput("AxisTensor")) { diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 66c6de45077..7adcc678f5c 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -32,9 +32,9 @@ class LoDResetOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GT( static_cast(level0.size()), 0, platform::errors::InvalidArgument( - "If Input(Y) not provided, the target lod should be " - "specified by attribute `target_lod`. But the size of " - "`target_lod` is 0.")); + "If Input(Y) is not provided, the output's LoD should be " + "specified by attribute 'target_lod'. But the size of " + "'target_lod' is 0.")); } else if (ctx->IsRuntime()) { ctx->ShareLoD("Y", "Out"); } diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index 8a809ece49b..1318bf2385c 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -41,10 +41,10 @@ class LoDResetKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( static_cast(last_level.back()), in->dims()[0], platform::errors::InvalidArgument( - "The last value of `Y`'s last level LoD should be equal " - "to the first dimension of `X`. But received the last value of " - "`Y`'s last level LoD is %d, the first dimension of `X` is " - "%d. ", + "The last value of Input(Y)'s last level LoD should be equal " + "to the first dimension of Input(X). But received the last " + "value of Input(Y)'s last level LoD is %d, the first dimension " + "of Input(X) is %d.", static_cast(last_level.back()), in->dims()[0])); out->set_lod(y_lod); return; // early return, since lod already set @@ -75,19 +75,16 @@ class LoDResetKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( static_cast(level0.back()), in->dims()[0], platform::errors::InvalidArgument( - "The last value of `Target LoD`'s last level LoD should be equal " - "to the first dimension of `X`. But received the last value of " - "`Target LoD`'s last level LoD is %d, the first dimension of `X` " - "is " - "%d. ", - static_cast(level0.back()), in->dims()[0])); + "The last value of 'Target LoD''s last level LoD should be equal " + "to the first dimension of Input(X). But received the 'Target LoD' " + "is %s, Input(X)'s shape is is %s.", + framework::make_ddim(level0), in->dims())); for (size_t i = 0; i < level0.size() - 1; ++i) { - PADDLE_ENFORCE_GE( - level0[i + 1], level0[i], - platform::errors::InvalidArgument( - "Target LoD should be an ascending vector. But the %s element is " - "%s and the %s element of Target LoD is %s.", - i + 1, level0[i + 1], i, level0[i])); + PADDLE_ENFORCE_GE(level0[i + 1], level0[i], + platform::errors::InvalidArgument( + "'Target LoD' should be an ascending " + "vector. But received the Target LoD is %s.", + framework::make_ddim(level0))); } // cast level0 to size_t diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu index daa7c07840f..39983e7de03 100644 --- a/paddle/fluid/operators/where_op.cu +++ b/paddle/fluid/operators/where_op.cu @@ -30,15 +30,15 @@ __global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x, } template -__global__ void WhereGradCUDAKernel(const int N, const T* out, const bool* cond, - T* x, T* y) { +__global__ void WhereGradCUDAKernel(const int N, const T* dout, + const bool* cond, T* dx, T* dy) { int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < N; idx += blockDim.x * gridDim.x) { - if (x != nullptr) { - x[idx] = out[idx] * (cond[idx] ? 1. : 0.); + if (dx != nullptr) { + dx[idx] = cond[idx] ? dout[idx] : 0.; } - if (y != nullptr) { - y[idx] = out[idx] * (cond[idx] ? 0. : 1.); + if (dy != nullptr) { + dy[idx] = cond[idx] ? 0. : dout[idx]; } } } diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 696a97c16d3..358933bd352 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -194,7 +194,7 @@ from .tensor.search import argmax #DEFINE_ALIAS # from .tensor.search import has_nan #DEFINE_ALIAS # from .tensor.search import masked_select #DEFINE_ALIAS # from .tensor.search import topk #DEFINE_ALIAS -# from .tensor.search import where #DEFINE_ALIAS +from .tensor.search import where #DEFINE_ALIAS from .tensor.search import index_select #DEFINE_ALIAS from .tensor.search import index_sample #DEFINE_ALIAS from .tensor.search import nonzero #DEFINE_ALIAS diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 5338a276655..35aa4e76448 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6196,10 +6196,12 @@ def lod_reset(x, y=None, target_lod=None): out.dims = [6, 1] Args: - x (Variable): Input variable which could be a Tensor or LoDTensor. - y (Variable|None): If provided, output's LoD would be derived - from :attr:`y`. - target_lod (list|tuple|None): One level LoD which should be considered + x (Variable): Input variable which could be a Tensor or LoDTensor. + The data type should be int32, int64, float32 or float64. + y (Variable, optional): If provided, output's LoD would be derived from :attr:`y`. + If y's lod level>0, the data type can be any type. + If y's lod level=0, the data type should be int32. + target_lod (list|tuple, optional): One level LoD which should be considered as target LoD when :attr:`y` not provided. Returns: @@ -6221,11 +6223,9 @@ def lod_reset(x, y=None, target_lod=None): helper = LayerHelper("lod_reset", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) if y is not None: - if y.lod_level > 0: - check_variable_and_dtype( - y, 'y', ['float32', 'float64', 'int32', 'int64'], 'lod_reset') - else: - check_variable_and_dtype(y, 'y', ['int32', 'int64'], 'lod_reset') + check_type(y, 'y', (Variable), 'lod_reset') + if y.lod_level == 0: + check_variable_and_dtype(y, 'y', ['int32'], 'lod_reset') helper.append_op( type="lod_reset", inputs={'X': x, 'Y': y}, outputs={'Out': out}) @@ -6261,9 +6261,11 @@ def lod_append(x, level): x.dims = [6, 1] Args: - x (Variable): Input variable which could be a tensor or LoDTensor. - level (list|tuple|Variable): The LoD level to be appended into LoD of x. - + x (Variable): Input variable which could be a tensor or LoDTensor. + The data type should be int32, int64, float32 or float64. + level (list|tuple|Variable, optional): The LoD level to be appended into LoD of x. + If level is variable and its lod level>0, the data type can be any type. + If level is variable and its lod level=0, the data type should be int32. Returns: Variable: Output variable with new LoD level. @@ -6283,6 +6285,9 @@ def lod_append(x, level): if (not isinstance(level, Iterable)) and (not isinstance(level, Variable)): raise ValueError("Input(level) must be list, tuple or Variable.") + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + 'lod_append') + helper = LayerHelper("lod_append", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -6291,6 +6296,8 @@ def lod_append(x, level): if isinstance(level, Variable): inputs['Y'] = level + if level.lod_level == 0: + check_variable_and_dtype(level, 'level', ['int32'], 'lod_append') else: attrs['target_lod'] = level helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 4259fed5dd3..1df1f34e761 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3033,7 +3033,7 @@ class TestBook(LayerTest): z = layers.lod_reset(x=x, y=y) self.assertTrue(z.lod_level == 2) # case 2 - lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int64') + lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int32') z = layers.lod_reset(x=x, y=lod_tensor_in) self.assertTrue(z.lod_level == 1) # case 3 diff --git a/python/paddle/fluid/tests/unittests/test_lod_append_op.py b/python/paddle/fluid/tests/unittests/test_lod_append_op.py new file mode 100644 index 00000000000..82cf4318098 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lod_append_op.py @@ -0,0 +1,80 @@ +#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.op import Operator +from paddle.fluid.backward import append_backward + + +class TestLoDAppendAPI(unittest.TestCase): + def test_api(self, use_cuda=False): + main_program = Program() + with fluid.program_guard(main_program): + x = fluid.layers.data(name='x', shape=[6], dtype='float32') + level = fluid.layers.data( + name='level', shape=[3], dtype='int32', lod_level=0) + result = fluid.layers.lod_append(x, level) + + x_i = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).astype("float32") + level_i = np.array([0, 2, 6]).astype("int32") + + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + [out] = exe.run(fluid.default_main_program(), + feed={'x': x_i, + 'level': level_i}, + fetch_list=[result], + return_numpy=False) + self.assertEqual(out.recursive_sequence_lengths(), [[2, 4]]) + + +class TestLodAppendOpError(unittest.TestCase): + def test_error(self): + # The input(x) must be Variable. + x1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64") + level1 = [0, 2, 4] + self.assertRaises(TypeError, fluid.layers.lod_append, x1, level1) + + #The input(level) must be Variable or list. + x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32') + self.assertRaises(ValueError, fluid.layers.lod_append, x2, 2) + + # Input(x) dtype must be float32 or float64 or int32 or int64 + for dtype in ["bool", "float16"]: + x3 = fluid.layers.data(name='x3_' + dtype, shape=[4], dtype=dtype) + level3 = fluid.layers.data( + name='level3' + dtype, shape=[4], dtype='int32', lod_level=2) + self.assertRaises(TypeError, fluid.layers.lod_append, x3, level3) + + # Input(level) dtype must be int32 when lod_level=0 + for dtype in ["bool", "float16", "float32", "float64", "int64"]: + x4 = fluid.layers.data( + name='x4' + dtype, shape=[4], dtype='float32') + level4 = fluid.layers.data( + name='level4_' + dtype, shape=[4], dtype=dtype, lod_level=0) + self.assertRaises(TypeError, fluid.layers.lod_append, x4, level4) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py index 236c9944145..ac2cd1be27f 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py +++ b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid as fluid from op_test import OpTest from paddle.fluid import Program, program_guard @@ -136,28 +137,26 @@ class TestLodAppendOpByAttr(OpTest): class TestLodResetOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): - - def test_Variable(): - # The input must be Variable. - x1 = fluid.create_lod_tensor( - np.ones([6]), [3, 3], fluid.CPUPlace()) - y1 = fluid.create_lod_tensor( - np.ones([6]), [2, 2, 2], fluid.CPUPlace()) - self.assertRaises(TypeError, fluid.layers.lod_reset, [x1, y1]) - - def test_type(): - # dtype must be float32 or float64 or int32 or int64 - x2 = fluid.layers.data(shape=[4], dtype='uint8', name='x2') + # The input must be Variable. + x1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64") + target_lod = [2, 2] + self.assertRaises(TypeError, fluid.layers.lod_reset, x1, target_lod) + + # Input(x) dtype must be float32 or float64 or int32 or int64 + for dtype in ["bool", "float16"]: + x2 = fluid.layers.data( + name='x2' + dtype, shape=[4], dtype=dtype) y2 = fluid.layers.data( - shape=[4], dtype='uint8', name='x2', lod_level=2) - self.assertRaises(TypeError, fluid.layers.lod_reset, [x2, y2]) + name='y2' + dtype, shape=[4], dtype='int32', lod_level=2) + self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2) - def test_type2(): - # dtype must be int32 or int64 - x3 = fluid.layers.data(shape=[4], dtype='float32', name='x3') + # Input(y) dtype must be int32 when lod_level=0 + for dtype in ["bool", "float16", "float32", "float64", "int64"]: + x3 = fluid.layers.data( + name='x3' + dtype, shape=[4], dtype='float32') y3 = fluid.layers.data( - shape=[4], dtype='float32', name='x3', lod_level=0) - self.assertRaises(TypeError, fluid.layers.lod_reset, [x3, y3]) + name='y3' + dtype, shape=[4], dtype=dtype, lod_level=0) + self.assertRaises(TypeError, fluid.layers.lod_reset, x3, y3) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 16971e435ca..5eaf140461b 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -16,9 +16,9 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers -import paddle.tensor as tensor import paddle.fluid.core as core from op_test import OpTest from paddle.fluid import compiler, Program, program_guard @@ -60,61 +60,64 @@ class TestWhereOp3(TestWhereOp): class TestWhereAPI(unittest.TestCase): - def test_api(self, use_cuda=False): - main_program = Program() - with fluid.program_guard(main_program): - x = fluid.layers.data(name='x', shape=[4], dtype='float32') - y = fluid.layers.data(name='y', shape=[4], dtype='float32') - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") - cond_i = np.array([False, False, True, True]).astype("bool") - result = tensor.where(x > 1, x=x, y=y) + def setUp(self): + self.init_data() - for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run(fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result]) - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) + def init_data(self): + self.shape = [10, 15] + self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool) + self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.out = np.where(self.cond, self.x, self.y) - def test_grad(self, use_cuda=False): - main_program = Program() - with fluid.program_guard(main_program): - x = fluid.layers.data(name='x', shape=[4], dtype='float32') - y = fluid.layers.data(name='y', shape=[4], dtype='float32') - for x_stop_gradient, y_stop_gradient in [[False, False], - [True, False], - [False, True]]: - x.stop_gradient = x_stop_gradient - y.stop_gradient = y_stop_gradient - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") - cond_i = np.array([False, False, True, True]).astype("bool") - result = tensor.where(x > 1, x=x, y=y) - x_mean = layers.mean(x) - append_backward(x_mean) - y_mean = layers.mean(y) - append_backward(y_mean) - - for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - out = exe.run( - fluid.default_main_program(), - feed={'x': x_i, - 'y': y_i}, - fetch_list=[result, x.grad_name, y.grad_name]) - x_grad = [0.25] * 4 - y_grad = [0.25] * 4 - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) - assert np.array_equal(out[1], x_grad) - assert np.array_equal(out[2], y_grad) + def ref_x_backward(self, dout): + return np.where(self.cond == True, dout, 0) + + def ref_y_backward(self, dout): + return np.where(self.cond == False, dout, 0) + + def test_api(self, use_cuda=False): + for x_stop_gradient in [False, True]: + for y_stop_gradient in [False, True]: + with fluid.program_guard(Program(), Program()): + cond = fluid.layers.data( + name='cond', shape=self.shape, dtype='bool') + x = fluid.layers.data( + name='x', shape=self.shape, dtype='float32') + y = fluid.layers.data( + name='y', shape=self.shape, dtype='float32') + x.stop_gradient = x_stop_gradient + y.stop_gradient = y_stop_gradient + result = paddle.where(cond, x, y) + append_backward(layers.mean(result)) + + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + break + place = fluid.CUDAPlace( + 0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + fetch_list = [result, result.grad_name] + if x_stop_gradient is False: + fetch_list.append(x.grad_name) + if y_stop_gradient is False: + fetch_list.append(y.grad_name) + out = exe.run( + fluid.default_main_program(), + feed={'cond': self.cond, + 'x': self.x, + 'y': self.y}, + fetch_list=fetch_list) + assert np.array_equal(out[0], self.out) + if x_stop_gradient is False: + assert np.array_equal(out[2], + self.ref_x_backward(out[1])) + if y.stop_gradient is False: + assert np.array_equal( + out[3], self.ref_y_backward(out[1])) + elif y.stop_gradient is False: + assert np.array_equal(out[2], + self.ref_y_backward(out[1])) def test_api_broadcast(self, use_cuda=False): main_program = Program() @@ -124,9 +127,7 @@ class TestWhereAPI(unittest.TestCase): x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32") - cond_i = np.array([[False, False, True, True], - [False, False, True, True]]).astype("bool") - result = tensor.where(x > 1, x=x, y=y) + result = paddle.where(x > 1, x=x, y=y) for use_cuda in [False, True]: if use_cuda and not fluid.core.is_compiled_with_cuda(): @@ -137,7 +138,7 @@ class TestWhereAPI(unittest.TestCase): feed={'x': x_i, 'y': y_i}, fetch_list=[result]) - assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) + assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i)) class TestWhereDygraphAPI(unittest.TestCase): @@ -149,7 +150,7 @@ class TestWhereDygraphAPI(unittest.TestCase): x = fluid.dygraph.to_variable(x_i) y = fluid.dygraph.to_variable(y_i) cond = fluid.dygraph.to_variable(cond_i) - out = tensor.where(cond, x, y) + out = paddle.where(cond, x, y) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) @@ -161,7 +162,7 @@ class TestWhereOpError(unittest.TestCase): cond_i = np.array([False, False, True, True]).astype("bool") def test_Variable(): - tensor.where(cond_i, x_i, y_i) + paddle.where(cond_i, x_i, y_i) self.assertRaises(TypeError, test_Variable) @@ -169,7 +170,7 @@ class TestWhereOpError(unittest.TestCase): x = fluid.layers.data(name='x', shape=[4], dtype='bool') y = fluid.layers.data(name='y', shape=[4], dtype='float16') cond = fluid.layers.data(name='cond', shape=[4], dtype='int32') - tensor.where(cond, x, y) + paddle.where(cond, x, y) self.assertRaises(TypeError, test_type) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index c415c9189a5..de02c412b9f 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -388,9 +388,9 @@ def where(condition, x, y, name=None): Examples: .. code-block:: python + import paddle import numpy as np import paddle.fluid as fluid - import paddle.tensor as paddle x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") @@ -417,8 +417,7 @@ def where(condition, x, y, name=None): return core.ops.where(condition, x, y) else: helper = LayerHelper("where", **locals()) - dtype = helper.input_dtype() - out = helper.create_variable_for_type_inference(dtype) + out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type='where', -- GitLab