Unverified commit b556b0f1, authored by GaoWei8, committed by GitHub

[Cherry-Pick] [2.0-beta] add paddle.where interface and error enhancement (#23972)

Parent 2c8a9181
@@ -47,13 +47,13 @@ static inline framework::DDim ComputeAndCheckShape(
           is_runtime || (out_dims[j] > 0 && inputs_dims[i][j] > 0);
       if (check_shape) {
         // check all shapes at run time
-        PADDLE_ENFORCE_EQ(
-            inputs_dims[0][j], inputs_dims[i][j],
-            platform::errors::InvalidArgument(
-                "The shape of input[%d] must be equal to input[0]. "
-                "But received input[0]'s shape = "
-                "[%s], input[%d]'s shape = [%s].",
-                i, inputs_dims[0], i, inputs_dims[i]));
+        PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j],
+                          platform::errors::InvalidArgument(
+                              "The %d-th dimension of input[0] and input[%d] "
+                              "is expected to be equal. "
+                              "But received input[0]'s shape = "
+                              "[%s], input[%d]'s shape = [%s].",
+                              j, i, inputs_dims[0], i, inputs_dims[i]));
       }
     }
   }
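For context: the check above enforces the usual concat contract, namely that every input must match input[0] on all dimensions except the concatenation axis. A NumPy analogue, for illustration only (not the operator's code):

import numpy as np

a = np.zeros((2, 3), dtype=np.float32)
b = np.zeros((2, 5), dtype=np.float32)

# OK: all dims agree except the concat axis (axis 1).
np.concatenate([a, b], axis=1)

# Would raise ValueError: dimension 1 differs (3 vs 5); this is the same
# mismatch the PADDLE_ENFORCE_EQ above reports at run time.
# np.concatenate([a, b], axis=0)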
@@ -79,9 +79,9 @@ class ConcatKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::LoDTensor>("X");
     framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
-    PADDLE_ENFORCE_NOT_NULL(
-        ins[0], platform::errors::NotFound(
-                    " The first input of concat should not be null."));
+    PADDLE_ENFORCE_NOT_NULL(ins[0],
+                            platform::errors::NotFound(
+                                "The first input tensor is not initialized."));
     auto axis = ctx.Attr<int>("axis");
     bool need_resize_out_dims = false;
     if (ctx.HasInput("AxisTensor")) {
@@ -116,7 +116,9 @@ class ConcatKernel : public framework::OpKernel<T> {
               platform::errors::Unimplemented(
                   "The lod level of all input LoDTensors should be same. "
                   "Maybe different lod level of input LoDTensors can concat,"
-                  " it is not supported currently."));
+                  " it is not supported currently. The lod level of the %d-th "
+                  "input is %d and that of the first input is %d.",
+                  i, ins[i]->lod().size(), lod_size_0));
       } else {
         lod_size = 0;
         break;
@@ -181,9 +183,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
         }
       }
     }
-    PADDLE_ENFORCE_NOT_NULL(
-        ins[0], platform::errors::NotFound(
-                    "The first input of concat should not be null."));
+    PADDLE_ENFORCE_NOT_NULL(ins[0],
+                            platform::errors::NotFound(
+                                "The first input tensor is not initialized."));
     auto axis = ctx.Attr<int>("axis");
     if (ctx.HasInput("AxisTensor")) {
......
@@ -32,9 +32,9 @@ class LoDResetOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_GT(
         static_cast<int64_t>(level0.size()), 0,
         platform::errors::InvalidArgument(
-            "If Input(Y) not provided, the target lod should be "
-            "specified by attribute `target_lod`. But the size of "
-            "`target_lod` is 0."));
+            "If Input(Y) is not provided, the output's LoD should be "
+            "specified by attribute 'target_lod'. But the size of "
+            "'target_lod' is 0."));
   } else if (ctx->IsRuntime()) {
     ctx->ShareLoD("Y", "Out");
   }
......
@@ -41,10 +41,10 @@ class LoDResetKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         static_cast<int64_t>(last_level.back()), in->dims()[0],
         platform::errors::InvalidArgument(
-            "The last value of `Y`'s last level LoD should be equal "
-            "to the first dimension of `X`. But received the last value of "
-            "`Y`'s last level LoD is %d, the first dimension of `X` is "
-            "%d. ",
+            "The last value of Input(Y)'s last level LoD should be equal "
+            "to the first dimension of Input(X). But received the last "
+            "value of Input(Y)'s last level LoD is %d, the first dimension "
+            "of Input(X) is %d.",
             static_cast<int64_t>(last_level.back()), in->dims()[0]));
     out->set_lod(y_lod);
     return;  // early return, since lod already set
@@ -75,19 +75,16 @@ class LoDResetKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         static_cast<int64_t>(level0.back()), in->dims()[0],
         platform::errors::InvalidArgument(
-            "The last value of `Target LoD`'s last level LoD should be equal "
-            "to the first dimension of `X`. But received the last value of "
-            "`Target LoD`'s last level LoD is %d, the first dimension of `X` "
-            "is "
-            "%d. ",
-            static_cast<int64_t>(level0.back()), in->dims()[0]));
+            "The last value of 'Target LoD''s last level LoD should be equal "
+            "to the first dimension of Input(X). But received the 'Target "
+            "LoD' is %s, Input(X)'s shape is %s.",
+            framework::make_ddim(level0), in->dims()));
     for (size_t i = 0; i < level0.size() - 1; ++i) {
-      PADDLE_ENFORCE_GE(
-          level0[i + 1], level0[i],
-          platform::errors::InvalidArgument(
-              "Target LoD should be an ascending vector. But the %s element is "
-              "%s and the %s element of Target LoD is %s.",
-              i + 1, level0[i + 1], i, level0[i]));
+      PADDLE_ENFORCE_GE(level0[i + 1], level0[i],
+                        platform::errors::InvalidArgument(
+                            "'Target LoD' should be an ascending "
+                            "vector. But received the Target LoD is %s.",
+                            framework::make_ddim(level0)));
     }
     // cast level0 to size_t
......
@@ -30,15 +30,15 @@ __global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
 }
 
 template <typename T>
-__global__ void WhereGradCUDAKernel(const int N, const T* out, const bool* cond,
-                                    T* x, T* y) {
+__global__ void WhereGradCUDAKernel(const int N, const T* dout,
+                                    const bool* cond, T* dx, T* dy) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < N; idx += blockDim.x * gridDim.x) {
-    if (x != nullptr) {
-      x[idx] = out[idx] * (cond[idx] ? 1. : 0.);
+    if (dx != nullptr) {
+      dx[idx] = cond[idx] ? dout[idx] : 0.;
     }
-    if (y != nullptr) {
-      y[idx] = out[idx] * (cond[idx] ? 0. : 1.);
+    if (dy != nullptr) {
+      dy[idx] = cond[idx] ? 0. : dout[idx];
     }
   }
 }
......
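For context: the renamed kernel computes the standard selection gradient, routing dout to dx where cond is true and to dy where it is false. A minimal NumPy sketch of the same rule (illustrative only, not the Paddle API):

import numpy as np

def where_grad(dout, cond):
    # dx receives dout where cond is True; dy receives it where cond is False.
    dx = np.where(cond, dout, 0.0)
    dy = np.where(cond, 0.0, dout)
    return dx, dy

cond = np.array([False, False, True, True])
dout = np.full(4, 0.25, dtype=np.float32)
dx, dy = where_grad(dout, cond)
# dx == [0., 0., 0.25, 0.25]; dy == [0.25, 0.25, 0., 0.]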
@@ -194,7 +194,7 @@ from .tensor.search import argmax        #DEFINE_ALIAS
 # from .tensor.search import has_nan        #DEFINE_ALIAS
 # from .tensor.search import masked_select        #DEFINE_ALIAS
 # from .tensor.search import topk        #DEFINE_ALIAS
-# from .tensor.search import where        #DEFINE_ALIAS
+from .tensor.search import where        #DEFINE_ALIAS
 from .tensor.search import index_select        #DEFINE_ALIAS
 from .tensor.search import index_sample        #DEFINE_ALIAS
 from .tensor.search import nonzero        #DEFINE_ALIAS
......
@@ -6196,10 +6196,12 @@ def lod_reset(x, y=None, target_lod=None):
             out.dims = [6, 1]
 
     Args:
-        x (Variable): Input variable which could be a Tensor or LoDTensor.
-        y (Variable|None): If provided, output's LoD would be derived
-                           from :attr:`y`.
-        target_lod (list|tuple|None): One level LoD which should be considered
-                                      as target LoD when :attr:`y` not provided.
+        x (Variable): Input variable which could be a Tensor or LoDTensor.
+                      The data type should be int32, int64, float32 or float64.
+        y (Variable, optional): If provided, the output's LoD is derived from
+                                :attr:`y`. If y's lod level > 0, the data type
+                                can be any type; if y's lod level = 0, the data
+                                type should be int32.
+        target_lod (list|tuple, optional): One-level LoD which should be
+                                           considered as the target LoD when
+                                           :attr:`y` is not provided.
 
     Returns:
@@ -6221,11 +6223,9 @@ def lod_reset(x, y=None, target_lod=None):
     helper = LayerHelper("lod_reset", **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     if y is not None:
-        if y.lod_level > 0:
-            check_variable_and_dtype(
-                y, 'y', ['float32', 'float64', 'int32', 'int64'], 'lod_reset')
-        else:
-            check_variable_and_dtype(y, 'y', ['int32', 'int64'], 'lod_reset')
+        check_type(y, 'y', (Variable), 'lod_reset')
+        if y.lod_level == 0:
+            check_variable_and_dtype(y, 'y', ['int32'], 'lod_reset')
         helper.append_op(
             type="lod_reset", inputs={'X': x,
                                       'Y': y}, outputs={'Out': out})
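For context: under the tightened check, a Variable y with lod_level == 0 must be int32, while y with a higher lod level may have any dtype. A minimal usage sketch (assumes a fluid 1.x-style static graph; names are illustrative):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[6], dtype='float32')
# y carries target LoD offsets, so with lod_level=0 it must be int32.
y = fluid.layers.data(name='y', shape=[3], dtype='int32', lod_level=0)
out = fluid.layers.lod_reset(x=x, y=y)  # passes the new dtype check
# Declaring y as int64 with lod_level=0 would now raise TypeError.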
@@ -6261,9 +6261,11 @@ def lod_append(x, level):
             x.dims = [6, 1]
 
     Args:
-        x (Variable): Input variable which could be a tensor or LoDTensor.
-        level (list|tuple|Variable): The LoD level to be appended into LoD of x.
+        x (Variable): Input variable which could be a tensor or LoDTensor.
+                      The data type should be int32, int64, float32 or float64.
+        level (list|tuple|Variable): The LoD level to be appended into the LoD
+                                     of x. If level is a Variable and its lod
+                                     level > 0, the data type can be any type;
+                                     if its lod level = 0, the data type should
+                                     be int32.
 
     Returns:
         Variable: Output variable with new LoD level.
@@ -6283,6 +6285,9 @@ def lod_append(x, level):
     if (not isinstance(level, Iterable)) and (not isinstance(level, Variable)):
         raise ValueError("Input(level) must be list, tuple or Variable.")
 
+    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
+                             'lod_append')
+
     helper = LayerHelper("lod_append", **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -6291,6 +6296,8 @@ def lod_append(x, level):
     if isinstance(level, Variable):
         inputs['Y'] = level
+        if level.lod_level == 0:
+            check_variable_and_dtype(level, 'level', ['int32'], 'lod_append')
     else:
         attrs['target_lod'] = level
     helper.append_op(
......
@@ -3033,7 +3033,7 @@ class TestBook(LayerTest):
             z = layers.lod_reset(x=x, y=y)
             self.assertTrue(z.lod_level == 2)
             # case 2
-            lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int64')
+            lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int32')
             z = layers.lod_reset(x=x, y=lod_tensor_in)
             self.assertTrue(z.lod_level == 1)
             # case 3
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.op import Operator
from paddle.fluid.backward import append_backward


class TestLoDAppendAPI(unittest.TestCase):
    def test_api(self, use_cuda=False):
        main_program = Program()
        with fluid.program_guard(main_program):
            x = fluid.layers.data(name='x', shape=[6], dtype='float32')
            level = fluid.layers.data(
                name='level', shape=[3], dtype='int32', lod_level=0)
            result = fluid.layers.lod_append(x, level)

            x_i = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).astype("float32")
            level_i = np.array([0, 2, 6]).astype("int32")

            for use_cuda in [False, True]:
                if use_cuda and not fluid.core.is_compiled_with_cuda():
                    return
                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
                exe = fluid.Executor(place)
                [out] = exe.run(fluid.default_main_program(),
                                feed={'x': x_i,
                                      'level': level_i},
                                fetch_list=[result],
                                return_numpy=False)
                self.assertEqual(out.recursive_sequence_lengths(), [[2, 4]])


class TestLodAppendOpError(unittest.TestCase):
    def test_error(self):
        # The input(x) must be Variable.
        x1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
        level1 = [0, 2, 4]
        self.assertRaises(TypeError, fluid.layers.lod_append, x1, level1)

        # The input(level) must be Variable or list.
        x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32')
        self.assertRaises(ValueError, fluid.layers.lod_append, x2, 2)

        # Input(x) dtype must be float32 or float64 or int32 or int64.
        for dtype in ["bool", "float16"]:
            x3 = fluid.layers.data(name='x3_' + dtype, shape=[4], dtype=dtype)
            level3 = fluid.layers.data(
                name='level3' + dtype, shape=[4], dtype='int32', lod_level=2)
            self.assertRaises(TypeError, fluid.layers.lod_append, x3, level3)

        # Input(level) dtype must be int32 when lod_level=0.
        for dtype in ["bool", "float16", "float32", "float64", "int64"]:
            x4 = fluid.layers.data(
                name='x4' + dtype, shape=[4], dtype='float32')
            level4 = fluid.layers.data(
                name='level4_' + dtype, shape=[4], dtype=dtype, lod_level=0)
            self.assertRaises(TypeError, fluid.layers.lod_append, x4, level4)


if __name__ == "__main__":
    unittest.main()
@@ -16,6 +16,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+import paddle.fluid as fluid
 from op_test import OpTest
 from paddle.fluid import Program, program_guard
@@ -136,28 +137,26 @@ class TestLodAppendOpByAttr(OpTest):
 class TestLodResetOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
-
-            def test_Variable():
-                # The input must be Variable.
-                x1 = fluid.create_lod_tensor(
-                    np.ones([6]), [3, 3], fluid.CPUPlace())
-                y1 = fluid.create_lod_tensor(
-                    np.ones([6]), [2, 2, 2], fluid.CPUPlace())
-                self.assertRaises(TypeError, fluid.layers.lod_reset, [x1, y1])
-
-            def test_type():
-                # dtype must be float32 or float64 or int32 or int64
-                x2 = fluid.layers.data(shape=[4], dtype='uint8', name='x2')
-                y2 = fluid.layers.data(
-                    shape=[4], dtype='uint8', name='x2', lod_level=2)
-                self.assertRaises(TypeError, fluid.layers.lod_reset, [x2, y2])
-
-            def test_type2():
-                # dtype must be int32 or int64
-                x3 = fluid.layers.data(shape=[4], dtype='float32', name='x3')
-                y3 = fluid.layers.data(
-                    shape=[4], dtype='float32', name='x3', lod_level=0)
-                self.assertRaises(TypeError, fluid.layers.lod_reset, [x3, y3])
+            # The input must be Variable.
+            x1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
+            target_lod = [2, 2]
+            self.assertRaises(TypeError, fluid.layers.lod_reset, x1, target_lod)
+
+            # Input(x) dtype must be float32 or float64 or int32 or int64.
+            for dtype in ["bool", "float16"]:
+                x2 = fluid.layers.data(
+                    name='x2' + dtype, shape=[4], dtype=dtype)
+                y2 = fluid.layers.data(
+                    name='y2' + dtype, shape=[4], dtype='int32', lod_level=2)
+                self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2)
+
+            # Input(y) dtype must be int32 when lod_level=0.
+            for dtype in ["bool", "float16", "float32", "float64", "int64"]:
+                x3 = fluid.layers.data(
+                    name='x3' + dtype, shape=[4], dtype='float32')
+                y3 = fluid.layers.data(
+                    name='y3' + dtype, shape=[4], dtype=dtype, lod_level=0)
+                self.assertRaises(TypeError, fluid.layers.lod_reset, x3, y3)
 
 
 if __name__ == '__main__':
......
@@ -16,9 +16,9 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
-import paddle.tensor as tensor
 import paddle.fluid.core as core
 from op_test import OpTest
 from paddle.fluid import compiler, Program, program_guard
@@ -60,61 +60,64 @@ class TestWhereOp3(TestWhereOp):
 
 class TestWhereAPI(unittest.TestCase):
-    def test_api(self, use_cuda=False):
-        main_program = Program()
-        with fluid.program_guard(main_program):
-            x = fluid.layers.data(name='x', shape=[4], dtype='float32')
-            y = fluid.layers.data(name='y', shape=[4], dtype='float32')
-            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
-            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
-            cond_i = np.array([False, False, True, True]).astype("bool")
-            result = tensor.where(x > 1, x=x, y=y)
-
-            for use_cuda in [False, True]:
-                if use_cuda and not fluid.core.is_compiled_with_cuda():
-                    return
-                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-                exe = fluid.Executor(place)
-                out = exe.run(fluid.default_main_program(),
-                              feed={'x': x_i,
-                                    'y': y_i},
-                              fetch_list=[result])
-                assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
-
-    def test_grad(self, use_cuda=False):
-        main_program = Program()
-        with fluid.program_guard(main_program):
-            x = fluid.layers.data(name='x', shape=[4], dtype='float32')
-            y = fluid.layers.data(name='y', shape=[4], dtype='float32')
-            for x_stop_gradient, y_stop_gradient in [[False, False],
-                                                     [True, False],
-                                                     [False, True]]:
-                x.stop_gradient = x_stop_gradient
-                y.stop_gradient = y_stop_gradient
-                x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
-                y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
-                cond_i = np.array([False, False, True, True]).astype("bool")
-                result = tensor.where(x > 1, x=x, y=y)
-                x_mean = layers.mean(x)
-                append_backward(x_mean)
-                y_mean = layers.mean(y)
-                append_backward(y_mean)
-
-                for use_cuda in [False, True]:
-                    if use_cuda and not fluid.core.is_compiled_with_cuda():
-                        return
-                    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-                    exe = fluid.Executor(place)
-                    out = exe.run(
-                        fluid.default_main_program(),
-                        feed={'x': x_i,
-                              'y': y_i},
-                        fetch_list=[result, x.grad_name, y.grad_name])
-                    x_grad = [0.25] * 4
-                    y_grad = [0.25] * 4
-                    assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
-                    assert np.array_equal(out[1], x_grad)
-                    assert np.array_equal(out[2], y_grad)
+    def setUp(self):
+        self.init_data()
+
+    def init_data(self):
+        self.shape = [10, 15]
+        self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool)
+        self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32)
+        self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32)
+        self.out = np.where(self.cond, self.x, self.y)
+
+    def ref_x_backward(self, dout):
+        return np.where(self.cond == True, dout, 0)
+
+    def ref_y_backward(self, dout):
+        return np.where(self.cond == False, dout, 0)
+
+    def test_api(self, use_cuda=False):
+        for x_stop_gradient in [False, True]:
+            for y_stop_gradient in [False, True]:
+                with fluid.program_guard(Program(), Program()):
+                    cond = fluid.layers.data(
+                        name='cond', shape=self.shape, dtype='bool')
+                    x = fluid.layers.data(
+                        name='x', shape=self.shape, dtype='float32')
+                    y = fluid.layers.data(
+                        name='y', shape=self.shape, dtype='float32')
+                    x.stop_gradient = x_stop_gradient
+                    y.stop_gradient = y_stop_gradient
+                    result = paddle.where(cond, x, y)
+                    append_backward(layers.mean(result))
+
+                    for use_cuda in [False, True]:
+                        if use_cuda and not fluid.core.is_compiled_with_cuda():
+                            break
+                        place = fluid.CUDAPlace(
+                            0) if use_cuda else fluid.CPUPlace()
+                        exe = fluid.Executor(place)
+                        fetch_list = [result, result.grad_name]
+                        if x_stop_gradient is False:
+                            fetch_list.append(x.grad_name)
+                        if y_stop_gradient is False:
+                            fetch_list.append(y.grad_name)
+                        out = exe.run(
+                            fluid.default_main_program(),
+                            feed={'cond': self.cond,
+                                  'x': self.x,
+                                  'y': self.y},
+                            fetch_list=fetch_list)
+                        assert np.array_equal(out[0], self.out)
+                        if x_stop_gradient is False:
+                            assert np.array_equal(out[2],
+                                                  self.ref_x_backward(out[1]))
+                            if y.stop_gradient is False:
+                                assert np.array_equal(
+                                    out[3], self.ref_y_backward(out[1]))
+                        elif y.stop_gradient is False:
+                            assert np.array_equal(out[2],
+                                                  self.ref_y_backward(out[1]))
 
     def test_api_broadcast(self, use_cuda=False):
         main_program = Program()
@@ -124,9 +127,7 @@ class TestWhereAPI(unittest.TestCase):
             x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
             y_i = np.array([[1.0, 1.0, 1.0, 1.0],
                             [1.0, 1.0, 1.0, 1.0]]).astype("float32")
-            cond_i = np.array([[False, False, True, True],
-                               [False, False, True, True]]).astype("bool")
-            result = tensor.where(x > 1, x=x, y=y)
+            result = paddle.where(x > 1, x=x, y=y)
 
             for use_cuda in [False, True]:
                 if use_cuda and not fluid.core.is_compiled_with_cuda():
@@ -137,7 +138,7 @@ class TestWhereAPI(unittest.TestCase):
                           feed={'x': x_i,
                                 'y': y_i},
                           fetch_list=[result])
-            assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
+            assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i))
 
 
 class TestWhereDygraphAPI(unittest.TestCase):
@@ -149,7 +150,7 @@ class TestWhereDygraphAPI(unittest.TestCase):
             x = fluid.dygraph.to_variable(x_i)
             y = fluid.dygraph.to_variable(y_i)
             cond = fluid.dygraph.to_variable(cond_i)
-            out = tensor.where(cond, x, y)
+            out = paddle.where(cond, x, y)
             assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
@@ -161,7 +162,7 @@ class TestWhereOpError(unittest.TestCase):
             cond_i = np.array([False, False, True, True]).astype("bool")
 
             def test_Variable():
-                tensor.where(cond_i, x_i, y_i)
+                paddle.where(cond_i, x_i, y_i)
 
             self.assertRaises(TypeError, test_Variable)
 
@@ -169,7 +170,7 @@ class TestWhereOpError(unittest.TestCase):
             x = fluid.layers.data(name='x', shape=[4], dtype='bool')
             y = fluid.layers.data(name='y', shape=[4], dtype='float16')
             cond = fluid.layers.data(name='cond', shape=[4], dtype='int32')
-            tensor.where(cond, x, y)
+            paddle.where(cond, x, y)
 
             self.assertRaises(TypeError, test_type)
......
@@ -388,9 +388,9 @@ def where(condition, x, y, name=None):
     Examples:
         .. code-block:: python
 
+          import paddle
           import numpy as np
           import paddle.fluid as fluid
-          import paddle.tensor as paddle
 
           x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
           y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
@@ -417,8 +417,7 @@ def where(condition, x, y, name=None):
         return core.ops.where(condition, x, y)
     else:
         helper = LayerHelper("where", **locals())
-        dtype = helper.input_dtype()
-        out = helper.create_variable_for_type_inference(dtype)
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(
             type='where',
......
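For context: with the alias uncommented in __init__.py, where is reachable from the top-level paddle namespace. A minimal dygraph sketch mirroring TestWhereDygraphAPI above (values are illustrative):

import numpy as np
import paddle
import paddle.fluid as fluid

x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
cond_i = np.array([False, False, True, True]).astype("bool")

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(x_i)
    y = fluid.dygraph.to_variable(y_i)
    cond = fluid.dygraph.to_variable(cond_i)
    out = paddle.where(cond, x, y)  # picks x where cond is True, else y
    assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))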