未验证 提交 ff0ab756 编写于 作者: G GaoWei8 提交者: GitHub

polish tensor.where codes and english document (#23687)

上级 52979565
...@@ -102,7 +102,7 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -102,7 +102,7 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), The second input tensor of where op. When the " "(Tensor), The second input tensor of where op. When the "
"corresponding position of condition is false, the output takes " "corresponding position of condition is false, the output takes "
"the element of Y."); "the element of Y.");
AddOutput("Out", "(Tensor), The output tensor of mul op."); AddOutput("Out", "(Tensor), The output tensor of where op.");
AddComment(R"DOC( AddComment(R"DOC(
Where Operator. Where Operator.
Return a tensor of elements selected from either $X$ or $Y$, depending on condition. Return a tensor of elements selected from either $X$ or $Y$, depending on condition.
......
...@@ -48,9 +48,6 @@ class WhereKernel<platform::CUDADeviceContext, T> ...@@ -48,9 +48,6 @@ class WhereKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::PermissionDenied("It must use CUDAPlace."));
auto* condition = context.Input<framework::Tensor>("Condition"); auto* condition = context.Input<framework::Tensor>("Condition");
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y"); auto* Y = context.Input<framework::Tensor>("Y");
...@@ -78,10 +75,6 @@ class WhereGradKernel<platform::CUDADeviceContext, T> ...@@ -78,10 +75,6 @@ class WhereGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::PermissionDenied("It must use CUDAPlace."));
auto* condition = context.Input<framework::Tensor>("Condition"); auto* condition = context.Input<framework::Tensor>("Condition");
const bool* cond_data = condition->data<bool>(); const bool* cond_data = condition->data<bool>();
auto numel = condition->numel(); auto numel = condition->numel();
......
...@@ -68,45 +68,53 @@ class TestWhereAPI(unittest.TestCase): ...@@ -68,45 +68,53 @@ class TestWhereAPI(unittest.TestCase):
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
cond_i = np.array([False, False, True, True]).astype("bool") cond_i = np.array([False, False, True, True]).astype("bool")
result = tensor.where(x > 1, X=x, Y=y) result = tensor.where(x > 1, x=x, y=y)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() for use_cuda in [False, True]:
exe = fluid.Executor(place) if use_cuda and not fluid.core.is_compiled_with_cuda():
out = exe.run(fluid.default_main_program(), return
feed={'x': x_i, place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
'y': y_i}, exe = fluid.Executor(place)
fetch_list=[result]) out = exe.run(fluid.default_main_program(),
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) feed={'x': x_i,
'y': y_i},
fetch_list=[result])
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
def test_grad(self, use_cuda=False): def test_grad(self, use_cuda=False):
main_program = Program() main_program = Program()
for x_stop_gradient, y_stop_gradient in [[False, False], [True, False], with fluid.program_guard(main_program):
[False, True]]: x = fluid.layers.data(name='x', shape=[4], dtype='float32')
with fluid.program_guard(main_program): y = fluid.layers.data(name='y', shape=[4], dtype='float32')
x = fluid.layers.data(name='x', shape=[4], dtype='float32') for x_stop_gradient, y_stop_gradient in [[False, False],
y = fluid.layers.data(name='y', shape=[4], dtype='float32') [True, False],
[False, True]]:
x.stop_gradient = x_stop_gradient x.stop_gradient = x_stop_gradient
y.stop_gradient = y_stop_gradient y.stop_gradient = y_stop_gradient
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
cond_i = np.array([False, False, True, True]).astype("bool") cond_i = np.array([False, False, True, True]).astype("bool")
result = tensor.where(x > 1, X=x, Y=y) result = tensor.where(x > 1, x=x, y=y)
x_mean = layers.mean(x) x_mean = layers.mean(x)
append_backward(x_mean) append_backward(x_mean)
y_mean = layers.mean(y) y_mean = layers.mean(y)
append_backward(y_mean) append_backward(y_mean)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() for use_cuda in [False, True]:
exe = fluid.Executor(place) if use_cuda and not fluid.core.is_compiled_with_cuda():
out = exe.run(fluid.default_main_program(), return
feed={'x': x_i, place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
'y': y_i}, exe = fluid.Executor(place)
fetch_list=[result, x.grad_name, y.grad_name]) out = exe.run(
x_grad = [0.25] * 4 fluid.default_main_program(),
y_grad = [0.25] * 4 feed={'x': x_i,
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) 'y': y_i},
assert np.array_equal(out[1], x_grad) fetch_list=[result, x.grad_name, y.grad_name])
assert np.array_equal(out[2], y_grad) x_grad = [0.25] * 4
y_grad = [0.25] * 4
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
assert np.array_equal(out[1], x_grad)
assert np.array_equal(out[2], y_grad)
def test_api_broadcast(self, use_cuda=False): def test_api_broadcast(self, use_cuda=False):
main_program = Program() main_program = Program()
...@@ -114,25 +122,22 @@ class TestWhereAPI(unittest.TestCase): ...@@ -114,25 +122,22 @@ class TestWhereAPI(unittest.TestCase):
x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32') x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32')
y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32') y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32')
x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
y_i = np.array( y_i = np.array([[1.0, 1.0, 1.0, 1.0],
[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32") [1.0, 1.0, 1.0, 1.0]]).astype("float32")
cond_i = np.array([[False, False, True, True], cond_i = np.array([[False, False, True, True],
[False, False, True, True]]).astype("bool") [False, False, True, True]]).astype("bool")
result = tensor.where(x > 1, X=x, Y=y) result = tensor.where(x > 1, x=x, y=y)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() for use_cuda in [False, True]:
exe = fluid.Executor(place) if use_cuda and not fluid.core.is_compiled_with_cuda():
out = exe.run(fluid.default_main_program(), return
feed={'x': x_i, place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
'y': y_i}, exe = fluid.Executor(place)
fetch_list=[result]) out = exe.run(fluid.default_main_program(),
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) feed={'x': x_i,
'y': y_i},
def test_fw_bw(self): fetch_list=[result])
if core.is_compiled_with_cuda(): assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
self.test_api(use_cuda=True)
self.test_api_broadcast(use_cuda=True)
self.test_grad(use_cuda=True)
class TestWhereDygraphAPI(unittest.TestCase): class TestWhereDygraphAPI(unittest.TestCase):
......
...@@ -13,17 +13,8 @@ ...@@ -13,17 +13,8 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import warnings
import six
import os
import inspect
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..fluid.initializer import Normal, Constant, NumpyArrayInitializer
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
from ..fluid import dygraph_utils
from ..fluid.param_attr import ParamAttr
from ..fluid import unique_name
from ..fluid import core, layers from ..fluid import core, layers
# TODO: define searching & indexing functions of a tensor # TODO: define searching & indexing functions of a tensor
...@@ -224,43 +215,61 @@ def sort(input, axis=-1, descending=False, out=None, name=None): ...@@ -224,43 +215,61 @@ def sort(input, axis=-1, descending=False, out=None, name=None):
return out, ids return out, ids
def where(Condition, X, Y): def where(condition, x, y, name=None):
""" """
Return a tensor of elements selected from either $X$ or $Y$, depending on $Condition$. Return a tensor of elements selected from either $x$ or $y$, depending on $condition$.
.. math::
out_i =
\\begin{cases}
x_i, \quad \\text{if} \\ condition_i \\ is \\ True \\\\
y_i, \quad \\text{if} \\ condition_i \\ is \\ False \\\\
\\end{cases}
Args: Args:
Condition(Variable): A bool tensor with rank at least 1, the data type is bool. condition(Variable): The condition to choose x or y.
X(Variable): X is a Tensor Variable. x(Variable): x is a Tensor Variable with data type float32, float64, int32, int64.
Y(Variable): Y is a Tensor Variable. y(Variable): y is a Tensor Variable with data type float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns: Returns:
out : The tensor. Variable: A Tensor with the same data type as x.
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np import numpy as np
import paddle as paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.tensor as paddle
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
x = fluid.dygraph.to_variable(x_i) x = fluid.dygraph.to_variable(x_i)
y = fluid.dygraph.to_variable(y_i) y = fluid.dygraph.to_variable(y_i)
out = paddle.where(x>1, x, y) out = paddle.where(x>1, x, y)
print(out.numpy())
#out: [1.0, 1.0, 3.2, 1.2] print(out.numpy())
#out: [1.0, 1.0, 3.2, 1.2]
""" """
if not in_dygraph_mode(): if not in_dygraph_mode():
check_variable_and_dtype(Condition, 'Condition', ['bool'], 'where') check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
check_variable_and_dtype( check_variable_and_dtype(
X, 'X', ['float32', 'float64', 'int32', 'int64'], 'where') x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where')
check_variable_and_dtype( check_variable_and_dtype(
Y, 'Y', ['float32', 'float64', 'int32', 'int64'], 'where') y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where')
X_shape = list(X.shape) x_shape = list(x.shape)
Y_shape = list(Y.shape) y_shape = list(y.shape)
if X_shape == Y_shape: if x_shape == y_shape:
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.where(Condition, X, Y) return core.ops.where(condition, x, y)
else: else:
helper = LayerHelper("where", **locals()) helper = LayerHelper("where", **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
...@@ -268,16 +277,16 @@ def where(Condition, X, Y): ...@@ -268,16 +277,16 @@ def where(Condition, X, Y):
helper.append_op( helper.append_op(
type='where', type='where',
inputs={'Condition': Condition, inputs={'Condition': condition,
'X': X, 'X': x,
'Y': Y}, 'Y': y},
outputs={'Out': [out]}) outputs={'Out': [out]})
return out return out
else: else:
cond_int = layers.cast(Condition, X.dtype) cond_int = layers.cast(condition, x.dtype)
cond_not_int = layers.cast(layers.logical_not(Condition), X.dtype) cond_not_int = layers.cast(layers.logical_not(condition), x.dtype)
out1 = layers.elementwise_mul(X, cond_int) out1 = layers.elementwise_mul(x, cond_int)
out2 = layers.elementwise_mul(Y, cond_not_int) out2 = layers.elementwise_mul(y, cond_not_int)
out = layers.elementwise_add(out1, out2) out = layers.elementwise_add(out1, out2)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册