未验证 提交 490db7f3 编写于 作者: G GaoWei8 提交者: GitHub

add paddle.tensor interface (#23801)

* add paddle.tensor
test=develop

* polish gpu where codes
test=develop

* polish test code
test=develop
上级 e9289e8c
...@@ -30,15 +30,15 @@ __global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x, ...@@ -30,15 +30,15 @@ __global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
} }
template <typename T> template <typename T>
__global__ void WhereGradCUDAKernel(const int N, const T* out, const bool* cond, __global__ void WhereGradCUDAKernel(const int N, const T* dout,
T* x, T* y) { const bool* cond, T* dx, T* dy) {
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) { for (; idx < N; idx += blockDim.x * gridDim.x) {
if (x != nullptr) { if (dx != nullptr) {
x[idx] = out[idx] * (cond[idx] ? 1. : 0.); dx[idx] = cond[idx] ? dout[idx] : 0.;
} }
if (y != nullptr) { if (dy != nullptr) {
y[idx] = out[idx] * (cond[idx] ? 0. : 1.); dy[idx] = cond[idx] ? 0. : dout[idx];
} }
} }
} }
......
...@@ -191,7 +191,7 @@ from .tensor.search import argmax #DEFINE_ALIAS ...@@ -191,7 +191,7 @@ from .tensor.search import argmax #DEFINE_ALIAS
# from .tensor.search import has_nan #DEFINE_ALIAS # from .tensor.search import has_nan #DEFINE_ALIAS
# from .tensor.search import masked_select #DEFINE_ALIAS # from .tensor.search import masked_select #DEFINE_ALIAS
# from .tensor.search import topk #DEFINE_ALIAS # from .tensor.search import topk #DEFINE_ALIAS
# from .tensor.search import where #DEFINE_ALIAS from .tensor.search import where #DEFINE_ALIAS
from .tensor.search import index_select #DEFINE_ALIAS from .tensor.search import index_select #DEFINE_ALIAS
from .tensor.search import index_sample #DEFINE_ALIAS from .tensor.search import index_sample #DEFINE_ALIAS
from .tensor.search import nonzero #DEFINE_ALIAS from .tensor.search import nonzero #DEFINE_ALIAS
......
...@@ -16,9 +16,9 @@ from __future__ import print_function ...@@ -16,9 +16,9 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.tensor as tensor
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
...@@ -60,61 +60,64 @@ class TestWhereOp3(TestWhereOp): ...@@ -60,61 +60,64 @@ class TestWhereOp3(TestWhereOp):
class TestWhereAPI(unittest.TestCase): class TestWhereAPI(unittest.TestCase):
def test_api(self, use_cuda=False): def setUp(self):
main_program = Program() self.init_data()
with fluid.program_guard(main_program):
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.data(name='y', shape=[4], dtype='float32')
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
cond_i = np.array([False, False, True, True]).astype("bool")
result = tensor.where(x > 1, x=x, y=y)
for use_cuda in [False, True]: def init_data(self):
if use_cuda and not fluid.core.is_compiled_with_cuda(): self.shape = [10, 15]
return self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32)
exe = fluid.Executor(place) self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32)
out = exe.run(fluid.default_main_program(), self.out = np.where(self.cond, self.x, self.y)
feed={'x': x_i,
'y': y_i},
fetch_list=[result])
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
def test_grad(self, use_cuda=False): def ref_x_backward(self, dout):
main_program = Program() return np.where(self.cond == True, dout, 0)
with fluid.program_guard(main_program):
x = fluid.layers.data(name='x', shape=[4], dtype='float32') def ref_y_backward(self, dout):
y = fluid.layers.data(name='y', shape=[4], dtype='float32') return np.where(self.cond == False, dout, 0)
for x_stop_gradient, y_stop_gradient in [[False, False],
[True, False], def test_api(self, use_cuda=False):
[False, True]]: for x_stop_gradient in [False, True]:
x.stop_gradient = x_stop_gradient for y_stop_gradient in [False, True]:
y.stop_gradient = y_stop_gradient with fluid.program_guard(Program(), Program()):
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") cond = fluid.layers.data(
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") name='cond', shape=self.shape, dtype='bool')
cond_i = np.array([False, False, True, True]).astype("bool") x = fluid.layers.data(
result = tensor.where(x > 1, x=x, y=y) name='x', shape=self.shape, dtype='float32')
x_mean = layers.mean(x) y = fluid.layers.data(
append_backward(x_mean) name='y', shape=self.shape, dtype='float32')
y_mean = layers.mean(y) x.stop_gradient = x_stop_gradient
append_backward(y_mean) y.stop_gradient = y_stop_gradient
result = paddle.where(cond, x, y)
for use_cuda in [False, True]: append_backward(layers.mean(result))
if use_cuda and not fluid.core.is_compiled_with_cuda():
return for use_cuda in [False, True]:
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() if use_cuda and not fluid.core.is_compiled_with_cuda():
exe = fluid.Executor(place) break
out = exe.run( place = fluid.CUDAPlace(
fluid.default_main_program(), 0) if use_cuda else fluid.CPUPlace()
feed={'x': x_i, exe = fluid.Executor(place)
'y': y_i}, fetch_list = [result, result.grad_name]
fetch_list=[result, x.grad_name, y.grad_name]) if x_stop_gradient is False:
x_grad = [0.25] * 4 fetch_list.append(x.grad_name)
y_grad = [0.25] * 4 if y_stop_gradient is False:
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) fetch_list.append(y.grad_name)
assert np.array_equal(out[1], x_grad) out = exe.run(
assert np.array_equal(out[2], y_grad) fluid.default_main_program(),
feed={'cond': self.cond,
'x': self.x,
'y': self.y},
fetch_list=fetch_list)
assert np.array_equal(out[0], self.out)
if x_stop_gradient is False:
assert np.array_equal(out[2],
self.ref_x_backward(out[1]))
if y.stop_gradient is False:
assert np.array_equal(
out[3], self.ref_y_backward(out[1]))
elif y.stop_gradient is False:
assert np.array_equal(out[2],
self.ref_y_backward(out[1]))
def test_api_broadcast(self, use_cuda=False): def test_api_broadcast(self, use_cuda=False):
main_program = Program() main_program = Program()
...@@ -124,9 +127,7 @@ class TestWhereAPI(unittest.TestCase): ...@@ -124,9 +127,7 @@ class TestWhereAPI(unittest.TestCase):
x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
y_i = np.array([[1.0, 1.0, 1.0, 1.0], y_i = np.array([[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0]]).astype("float32") [1.0, 1.0, 1.0, 1.0]]).astype("float32")
cond_i = np.array([[False, False, True, True], result = paddle.where(x > 1, x=x, y=y)
[False, False, True, True]]).astype("bool")
result = tensor.where(x > 1, x=x, y=y)
for use_cuda in [False, True]: for use_cuda in [False, True]:
if use_cuda and not fluid.core.is_compiled_with_cuda(): if use_cuda and not fluid.core.is_compiled_with_cuda():
...@@ -137,7 +138,7 @@ class TestWhereAPI(unittest.TestCase): ...@@ -137,7 +138,7 @@ class TestWhereAPI(unittest.TestCase):
feed={'x': x_i, feed={'x': x_i,
'y': y_i}, 'y': y_i},
fetch_list=[result]) fetch_list=[result])
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i)) assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i))
class TestWhereDygraphAPI(unittest.TestCase): class TestWhereDygraphAPI(unittest.TestCase):
...@@ -149,7 +150,7 @@ class TestWhereDygraphAPI(unittest.TestCase): ...@@ -149,7 +150,7 @@ class TestWhereDygraphAPI(unittest.TestCase):
x = fluid.dygraph.to_variable(x_i) x = fluid.dygraph.to_variable(x_i)
y = fluid.dygraph.to_variable(y_i) y = fluid.dygraph.to_variable(y_i)
cond = fluid.dygraph.to_variable(cond_i) cond = fluid.dygraph.to_variable(cond_i)
out = tensor.where(cond, x, y) out = paddle.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
...@@ -161,7 +162,7 @@ class TestWhereOpError(unittest.TestCase): ...@@ -161,7 +162,7 @@ class TestWhereOpError(unittest.TestCase):
cond_i = np.array([False, False, True, True]).astype("bool") cond_i = np.array([False, False, True, True]).astype("bool")
def test_Variable(): def test_Variable():
tensor.where(cond_i, x_i, y_i) paddle.where(cond_i, x_i, y_i)
self.assertRaises(TypeError, test_Variable) self.assertRaises(TypeError, test_Variable)
...@@ -169,7 +170,7 @@ class TestWhereOpError(unittest.TestCase): ...@@ -169,7 +170,7 @@ class TestWhereOpError(unittest.TestCase):
x = fluid.layers.data(name='x', shape=[4], dtype='bool') x = fluid.layers.data(name='x', shape=[4], dtype='bool')
y = fluid.layers.data(name='y', shape=[4], dtype='float16') y = fluid.layers.data(name='y', shape=[4], dtype='float16')
cond = fluid.layers.data(name='cond', shape=[4], dtype='int32') cond = fluid.layers.data(name='cond', shape=[4], dtype='int32')
tensor.where(cond, x, y) paddle.where(cond, x, y)
self.assertRaises(TypeError, test_type) self.assertRaises(TypeError, test_type)
......
...@@ -388,9 +388,9 @@ def where(condition, x, y, name=None): ...@@ -388,9 +388,9 @@ def where(condition, x, y, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.tensor as paddle
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
...@@ -417,8 +417,7 @@ def where(condition, x, y, name=None): ...@@ -417,8 +417,7 @@ def where(condition, x, y, name=None):
return core.ops.where(condition, x, y) return core.ops.where(condition, x, y)
else: else:
helper = LayerHelper("where", **locals()) helper = LayerHelper("where", **locals())
dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='where', type='where',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册