Unverified · Commit 5516f180 authored by Chen Weihang, committed by GitHub

[Phi] Add unbind yaml and final state api (#41277)

* add unbind yaml

* fix unittest
Parent edbb3986
@@ -475,6 +475,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
  return api_output;
}
std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
  auto kernel_key_set = ParseKernelKeyByInputArgs(input);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  Backend kernel_backend = kernel_key.backend();
  DataLayout kernel_layout = kernel_key.layout();
  DataType kernel_data_type = kernel_key.dtype();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "unbind", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "unbind API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

  auto dense_input = PrepareData(input, kernel.InputAt(0), {});

  // Calculate the number of out tensors
  auto input_shape = input.dims();
  if (axis < 0) {
    axis = input_shape.size() + axis;
  }
  auto out_num = input_shape[axis];

  std::vector<Tensor> out;
  auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
  std::vector<phi::MetaTensor> meta_outs;
  meta_outs.reserve(out_num);
  std::vector<phi::MetaTensor*> meta_out_ptrs;
  meta_out_ptrs.reserve(out_num);
  for (int64_t i = 0; i < out_num; ++i) {
    meta_outs.push_back(dense_outs[i]);
    meta_out_ptrs.push_back(&meta_outs.back());
  }
  phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);

  using kernel_signature = void (*)(const phi::DeviceContext&,
                                    const phi::DenseTensor&,
                                    int,
                                    std::vector<phi::DenseTensor*>&);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);

  return out;
}
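For orientation, here is a minimal sketch of the behavior unbind_impl exposes, written against the public Python API rather than the C++ internals (assumes a Paddle build that ships paddle.unbind; variable names are illustrative):

# Minimal sketch of the semantics unbind_impl provides: the number of
# outputs equals the size of the unbound axis, and negative axes are
# normalized exactly as in the C++ above (axis = ndim + axis when axis < 0).
import paddle

x = paddle.arange(6, dtype="float32").reshape([2, 3])
row0, row1 = paddle.unbind(x, axis=0)   # two tensors, each of shape [3]
cols = paddle.unbind(x, axis=-1)        # three tensors, each of shape [2]
assert len(cols) == x.shape[1]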
////////////////// Backward(grad) api impls //////////////////////
// TODO(chenweihang): the original sum grad op can support higher-level
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
@@ -73,6 +75,8 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
    bool multi_precision,
    float rescale_grad);
std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
////////////////// Backward(grad) api impls //////////////////////
std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
@@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x,
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs) {
+                     std::vector<MetaTensor*> outs) {
   auto in_dims = x.dims();
   std::vector<int> out_dim;
   axis = axis < 0 ? in_dims.size() + axis : axis;
@@ -2438,11 +2438,11 @@ void UnbindInferMeta(const MetaTensor& x,
   }
   auto out_dims = phi::make_ddim(out_dim);
-  for (size_t i = 0; i < outs->size(); ++i) {
-    (*outs)[i].set_dtype(x.dtype());
-    (*outs)[i].set_dims(out_dims);
-    (*outs)[i].set_layout(x.layout());
-    (*outs)[i].share_lod(x);
+  for (size_t i = 0; i < outs.size(); ++i) {
+    outs[i]->set_dtype(x.dtype());
+    outs[i]->set_dims(out_dims);
+    outs[i]->set_layout(x.layout());
+    outs[i]->share_lod(x);
   }
 }
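The shape rule this InferMeta encodes is small enough to restate. The following is an illustrative plain-Python paraphrase (unbind_out_shape is a hypothetical helper, not part of the codebase): each output drops the unbound axis and inherits dtype, layout, and LoD from the input.

# Illustrative restatement of the shape rule UnbindInferMeta encodes;
# every output keeps all input dims except the unbound axis.
def unbind_out_shape(in_shape, axis):
    axis = axis if axis >= 0 else len(in_shape) + axis
    return [d for i, d in enumerate(in_shape) if i != axis]

assert unbind_out_shape([2, 3, 5], axis=1) == [2, 5]   # 3 outputs of [2, 5]
assert unbind_out_shape([2, 3], axis=-1) == [2]        # negative axis handled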
@@ -365,7 +365,7 @@ void TrilTriuInferMeta(const MetaTensor& x,
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs);
+                     std::vector<MetaTensor*> outs);
 void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
@@ -17,9 +17,11 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.tensor as tensor
from paddle.fluid import compiler, Program, program_guard, core
from paddle.fluid.framework import _test_eager_guard
class TestUnbind(unittest.TestCase):
@@ -39,6 +41,25 @@ class TestUnbind(unittest.TestCase):
        assert np.array_equal(res_1, input_1[0, 0:100])
        assert np.array_equal(res_2, input_1[1, 0:100])
    def test_unbind_dygraph(self):
        with fluid.dygraph.guard():
            np_x = np.random.random([2, 3]).astype("float32")
            x = paddle.to_tensor(np_x)
            x.stop_gradient = False
            [res_1, res_2] = paddle.unbind(x, 0)
            self.assertTrue(np.array_equal(res_1, np_x[0, 0:100]))
            self.assertTrue(np.array_equal(res_2, np_x[1, 0:100]))

            out = paddle.add_n([res_1, res_2])
            np_grad = np.ones(x.shape, np.float32)
            out.backward()
            self.assertTrue(np.array_equal(x.grad.numpy(), np_grad))

    def test_unbind_dygraph_final_state(self):
        with _test_eager_guard():
            self.test_unbind_dygraph()
class TestLayersUnbind(unittest.TestCase):
    def test_layers_unbind(self):
@@ -157,6 +178,7 @@ class TestUnbindOp4(TestUnbindOp):
class TestUnbindBF16Op(OpTest):
    def setUp(self):
        self._set_op_type()
        self.python_api = paddle.unbind
        self.dtype = self.get_dtype()
        self.axis = 0
        self.num = 3
@@ -1469,6 +1469,9 @@ def unbind(input, axis=0):
            # x3.shape [3, 5]
    """
    if in_dygraph_mode():
        return _C_ops.final_state_unbind(input, axis)

    if not isinstance(axis, (int)):
        raise TypeError("The type of 'axis' must be int, but received %s." %
                        (type(axis)))
@@ -1477,7 +1480,7 @@
    input_shape = input.shape
    axis_ = axis if axis >= 0 else len(input_shape) + axis
    num = input_shape[axis_]
-    if paddle.in_dynamic_mode():
+    if _in_legacy_dygraph():
        return _C_ops.unbind(input, num, 'axis', axis)
    helper = LayerHelper("unbind", **locals())
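A usage sketch that exercises the remaining LayerHelper branch (static graph); this assumes a Paddle build with the paddle.static namespace and is only an illustration, not part of the change:

# Static-graph usage dispatches to the LayerHelper path above, since
# neither dygraph branch is taken once paddle.enable_static() is called.
import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    x0, x1 = paddle.unbind(x, axis=0)   # builds an unbind op in the program
exe = paddle.static.Executor(paddle.CPUPlace())
r0, r1 = exe.run(main, feed={"x": np.ones([2, 3], "float32")},
                 fetch_list=[x0, x1])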
@@ -1939,6 +1939,12 @@
  backend : place
  data_type : dtype

- api : unbind
  args : (Tensor input, int axis)
  output : Tensor[]
  invoke : unbind_impl(input, axis)
  backward : unbind_grad

# unfold
- api : unfold
  args : (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
@@ -1480,6 +1480,12 @@
  kernel :
    func : trunc_grad

- backward_api : unbind_grad
  forward : unbind (Tensor input, int axis) -> Tensor[](out)
  args : (Tensor[] out_grad, int axis)
  output : Tensor(input_grad)
  invoke : stack(out_grad, axis)
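The invoke : stack(out_grad, axis) line reuses the forward stack API as unbind's gradient. A small dygraph sketch of why that is shape-correct (assumes paddle.unbind and paddle.stack are available; the ones-gradients are illustrative):

# Stacking the per-slice gradients along the same axis rebuilds a tensor
# with exactly the input's shape, which is what unbind's gradient needs.
import paddle

x = paddle.ones([2, 3])
parts = paddle.unbind(x, axis=0)            # forward: [2, 3] -> 2 x [3]
grads = [paddle.ones_like(p) for p in parts]
input_grad = paddle.stack(grads, axis=0)    # backward: 2 x [3] -> [2, 3]
assert input_grad.shape == x.shape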
- backward_api : unfold_grad
  forward : unfold (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)