Unverified commit 5516f180, authored by Chen Weihang, committed by GitHub

[Phi] Add unbind yaml and final state api (#41277)

* add unbind yaml

* fix unittest
Parent: edbb3986
@@ -475,6 +475,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
   return api_output;
 }
 
+std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
+  auto kernel_key_set = ParseKernelKeyByInputArgs(input);
+  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+
+  Backend kernel_backend = kernel_key.backend();
+  DataLayout kernel_layout = kernel_key.layout();
+  DataType kernel_data_type = kernel_key.dtype();
+
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "unbind", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "unbind API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto dense_input = PrepareData(input, kernel.InputAt(0), {});
+
+  // Calculate the number of out tensors
+  auto input_shape = input.dims();
+  if (axis < 0) {
+    axis = input_shape.size() + axis;
+  }
+  auto out_num = input_shape[axis];
+
+  std::vector<Tensor> out;
+  auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
+  std::vector<phi::MetaTensor> meta_outs;
+  meta_outs.reserve(out_num);
+  std::vector<phi::MetaTensor*> meta_out_ptrs;
+  meta_out_ptrs.reserve(out_num);
+  for (int64_t i = 0; i < out_num; ++i) {
+    meta_outs.push_back(dense_outs[i]);
+    meta_out_ptrs.push_back(&meta_outs.back());
+  }
+  phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);
+
+  using kernel_signature = void (*)(const phi::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    int,
+                                    std::vector<phi::DenseTensor*>&);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);
+
+  return out;
+}
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 // TODO(chenweihang): the original sum grad op can support higher-level
......
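Note on the hunk above: unbind_impl is a hand-written forward implementation for the new final-state API. It resolves the kernel key from the input tensor, normalizes a negative axis, allocates shape[axis] output tensors, fills their meta via UnbindInferMeta, and then invokes the unbind kernel directly. A minimal Python sketch of the behavior this exposes through paddle.unbind (illustration only, not part of the diff):

    import numpy as np
    import paddle

    # Splitting a [2, 3] tensor along axis 0 yields shape[axis] == 2 outputs of shape [3].
    x = paddle.to_tensor(np.arange(6).reshape([2, 3]).astype("float32"))
    x0, x1 = paddle.unbind(x, axis=0)
    assert x0.shape == [3] and x1.shape == [3]
    # A negative axis is normalized as len(shape) + axis, so axis=-1 gives 3 outputs of shape [2].
    cols = paddle.unbind(x, axis=-1)
    assert len(cols) == 3 and cols[0].shape == [2]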
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include <vector>
+
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
@@ -73,6 +75,8 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
     bool multi_precision,
     float rescale_grad);
 
+std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
......
@@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x,
 
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs) {
+                     std::vector<MetaTensor*> outs) {
   auto in_dims = x.dims();
   std::vector<int> out_dim;
   axis = axis < 0 ? in_dims.size() + axis : axis;
@@ -2438,11 +2438,11 @@ void UnbindInferMeta(const MetaTensor& x,
   }
   auto out_dims = phi::make_ddim(out_dim);
-  for (size_t i = 0; i < outs->size(); ++i) {
-    (*outs)[i].set_dtype(x.dtype());
-    (*outs)[i].set_dims(out_dims);
-    (*outs)[i].set_layout(x.layout());
-    (*outs)[i].share_lod(x);
+  for (size_t i = 0; i < outs.size(); ++i) {
+    outs[i]->set_dtype(x.dtype());
+    outs[i]->set_dims(out_dims);
+    outs[i]->set_layout(x.layout());
+    outs[i]->share_lod(x);
   }
 }
......
@@ -365,7 +365,7 @@ void TrilTriuInferMeta(const MetaTensor& x,
 
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs);
+                     std::vector<MetaTensor*> outs);
 
 void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
......
@@ -17,9 +17,11 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
 from op_test import OpTest, convert_float_to_uint16
+import paddle
 import paddle.fluid as fluid
 import paddle.tensor as tensor
 from paddle.fluid import compiler, Program, program_guard, core
+from paddle.fluid.framework import _test_eager_guard
 
 
@@ -39,6 +41,25 @@ class TestUnbind(unittest.TestCase):
         assert np.array_equal(res_1, input_1[0, 0:100])
         assert np.array_equal(res_2, input_1[1, 0:100])
 
+    def test_unbind_dygraph(self):
+        with fluid.dygraph.guard():
+            np_x = np.random.random([2, 3]).astype("float32")
+            x = paddle.to_tensor(np_x)
+            x.stop_gradient = False
+            [res_1, res_2] = paddle.unbind(x, 0)
+            self.assertTrue(np.array_equal(res_1, np_x[0, 0:100]))
+            self.assertTrue(np.array_equal(res_2, np_x[1, 0:100]))
+
+            out = paddle.add_n([res_1, res_2])
+
+            np_grad = np.ones(x.shape, np.float32)
+            out.backward()
+            self.assertTrue(np.array_equal(x.grad.numpy(), np_grad))
+
+    def test_unbind_dygraph_final_state(self):
+        with _test_eager_guard():
+            self.test_unbind_dygraph()
+
 
 class TestLayersUnbind(unittest.TestCase):
     def test_layers_unbind(self):
@@ -157,6 +178,7 @@ class TestUnbindOp4(TestUnbindOp):
 
 class TestUnbindBF16Op(OpTest):
     def setUp(self):
         self._set_op_type()
+        self.python_api = paddle.unbind
         self.dtype = self.get_dtype()
         self.axis = 0
         self.num = 3
......
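The new test_unbind_dygraph_final_state simply re-runs test_unbind_dygraph under _test_eager_guard, so the same assertions cover both the legacy dygraph kernel and the new final-state path; self.python_api = paddle.unbind likewise lets OpTest check the BF16 op against the eager API. A small standalone sketch of that re-run pattern (hypothetical helper name; the guard is the one imported in the diff):

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard

    def check_unbind():
        x = paddle.to_tensor(np.random.rand(2, 3).astype("float32"))
        parts = paddle.unbind(x, axis=0)
        assert len(parts) == 2 and parts[0].shape == [3]

    check_unbind()              # legacy dygraph path
    with _test_eager_guard():   # final-state (eager) path exercised by this PR
        check_unbind()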
@@ -1469,6 +1469,9 @@ def unbind(input, axis=0):
             # x3.shape [3, 5]
 
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_unbind(input, axis)
+
     if not isinstance(axis, (int)):
         raise TypeError("The type of 'axis' must be int, but received %s." %
                         (type(axis)))
@@ -1477,7 +1480,7 @@ def unbind(input, axis=0):
     input_shape = input.shape
     axis_ = axis if axis >= 0 else len(input_shape) + axis
     num = input_shape[axis_]
-    if paddle.in_dynamic_mode():
+    if _in_legacy_dygraph():
         return _C_ops.unbind(input, num, 'axis', axis)
     helper = LayerHelper("unbind", **locals())
......
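After this change paddle.unbind dispatches in three steps: the new final-state eager branch calls _C_ops.final_state_unbind(input, axis), the legacy dygraph branch keeps calling _C_ops.unbind(input, num, 'axis', axis), and everything else falls through to the static-graph LayerHelper path, which is untouched. A quick static-graph check of that fall-through (standard paddle.static APIs, not part of the diff):

    import numpy as np
    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
        y0, y1 = paddle.unbind(x, axis=0)
    exe = paddle.static.Executor(paddle.CPUPlace())
    r0, r1 = exe.run(main, feed={"x": np.ones([2, 3], "float32")}, fetch_list=[y0, y1])
    assert r0.shape == (3,) and r1.shape == (3,)
    paddle.disable_static()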
@@ -1939,6 +1939,12 @@
     backend : place
     data_type : dtype
 
+- api : unbind
+  args : (Tensor input, int axis)
+  output : Tensor[]
+  invoke : unbind_impl(input, axis)
+  backward : unbind_grad
+
 # unfold
 - api : unfold
   args : (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
......
@@ -1480,6 +1480,12 @@
   kernel :
     func : trunc_grad
 
+- backward_api : unbind_grad
+  forward : unbind (Tensor input, int axis) -> Tensor[](out)
+  args : (Tensor[] out_grad, int axis)
+  output : Tensor(input_grad)
+  invoke : stack(out_grad, axis)
+
 - backward_api : unfold_grad
   forward : unfold (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
......
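The backward entry encodes the gradient rule directly: the gradient of unbind along axis is the output gradients stacked back along that same axis, so unbind_grad can be expressed as invoke : stack(out_grad, axis) with no dedicated kernel. A small dygraph sketch of that identity (illustration only; summing the unbound slices makes every output gradient a tensor of ones):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.rand(2, 3).astype("float32"), stop_gradient=False)
    a, b = paddle.unbind(x, axis=0)
    paddle.add_n([a, b]).sum().backward()
    # stack(ones[3], ones[3], axis=0) -> ones[2, 3], which is exactly x.grad
    assert np.array_equal(x.grad.numpy(), np.ones([2, 3], dtype="float32"))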