Unverified commit 119816f9, authored by chentianyu03, committed by GitHub

[Yaml]Add concat grad yaml (#41365)

* add concat_grad kernel

* fix error

* remove comment code

* fix outs nullptr error

* change to phi header

* add concat_grad declare for standalone_executor_test

* add concat_grad yaml

* add concat api

* fix test concat op error

* fix test concat op error
Parent: 0bcfc474
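Roughly, what this commit wires up: paddle.concat in eager (dygraph) mode dispatches to the final-state phi kernel and gains a registered concat_grad backward, so gradients flow through it without the legacy tracer. A minimal usage sketch, hedged: it assumes a build where eager mode is active (for example inside the _test_eager_guard used by the tests below).

import paddle

x = paddle.ones([2, 3])
y = paddle.full([4, 3], 2.0)
x.stop_gradient = False
y.stop_gradient = False

out = paddle.concat([x, y], axis=0)  # shape [6, 3], forward via the concat kernel
out.sum().backward()                 # backward via concat_grad

print(x.grad.shape)  # [2, 3] -- each input receives its slice of out_grad
print(y.grad.shape)  # [4, 3]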
@@ -165,7 +165,7 @@ cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place)
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool)
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
-cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
+cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform backward_infermeta)
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
@@ -166,5 +167,70 @@ std::vector<Tensor> split_impl(const Tensor& x,
  return out;
}

std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
                                     const Tensor& out_grad,
                                     const Scalar& axis) {
  auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  Backend kernel_backend = kernel_key.backend();
  DataLayout kernel_layout = kernel_key.layout();
  DataType kernel_data_type = kernel_key.dtype();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "concat_grad", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "concat_grad API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "concat_grad API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

  // std::unique_ptr<std::vector<phi::DenseTensor>>
  auto dense_x = PrepareData(x, kernel.InputAt(0), {});
  auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});

  // Calculate the number of out tensors
  size_t out_number = x.size();
  std::vector<Tensor> x_grad;
  auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);

  std::vector<phi::MetaTensor> meta_x;
  meta_x.reserve(x.size());
  std::vector<phi::MetaTensor*> meta_x_ptrs;
  meta_x_ptrs.reserve(x.size());
  for (const auto& t : *dense_x) {
    meta_x.push_back(t);
    meta_x_ptrs.push_back(&meta_x.back());
  }

  std::vector<phi::MetaTensor> meta_x_grad;
  meta_x_grad.reserve(x.size());
  std::vector<phi::MetaTensor*> meta_x_grad_ptrs;
  meta_x_grad_ptrs.reserve(x.size());
  for (size_t i = 0; i < out_number; ++i) {
    meta_x_grad.push_back(*dense_x_grad[i]);
    meta_x_grad_ptrs.push_back(&meta_x_grad.back());
  }

  phi::UnchangedMultiInferMeta(meta_x_ptrs, meta_x_grad_ptrs);

  std::vector<const phi::DenseTensor*> dense_x_ptr;
  dense_x_ptr.reserve(x.size());
  for (const auto& t : *dense_x) {
    dense_x_ptr.push_back(&t);
  }

  using kernel_signature = void (*)(const platform::DeviceContext&,
                                    const std::vector<const phi::DenseTensor*>&,
                                    const phi::DenseTensor&,
                                    const phi::Scalar&,
                                    std::vector<phi::DenseTensor*>);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)(
      *dev_ctx, dense_x_ptr, *dense_out_grad, phi::Scalar(axis), dense_x_grad);

  return x_grad;
}
}  // namespace experimental
}  // namespace paddle
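For reference, the math performed by the concat_grad kernel invoked above is simply a split of out_grad along the concat axis: each x_grad[i] is the slice that corresponds to x[i]. A hedged NumPy sketch of that semantics (a reference of the expected behaviour, not the phi kernel itself):

import numpy as np

def concat_grad_ref(xs, out_grad, axis=0):
    # Gradient of concat w.r.t. each input: the matching slice of out_grad.
    sizes = [x.shape[axis] for x in xs]
    return np.split(out_grad, np.cumsum(sizes)[:-1], axis=axis)

xs = [np.ones((2, 3)), np.ones((4, 3))]
out_grad = np.arange(18, dtype=np.float32).reshape(6, 3)
gx = concat_grad_ref(xs, out_grad, axis=0)
assert gx[0].shape == (2, 3) and gx[1].shape == (4, 3)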
@@ -31,5 +31,9 @@ std::vector<Tensor> split_impl(const Tensor& x,
                               const IntArray& num_or_sections,
                               const Scalar& axis);

std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
                                     const Tensor& out_grad,
                                     const Scalar& axis);

}  // namespace experimental
}  // namespace paddle
@@ -1909,6 +1909,13 @@ void StackInferMeta(const std::vector<MetaTensor*>& x,
  out->share_lod(*x.at(0));
}

void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
                             std::vector<MetaTensor*> out) {
  for (size_t i = 0; i < x.size(); ++i) {
    out[i]->share_meta(*x[i]);
  }
}

void WarpctcInferMeta(const MetaTensor& logits,
                      const MetaTensor& label,
                      const paddle::optional<const MetaTensor&> logits_length,
...
@@ -289,6 +289,9 @@ void StackInferMeta(const std::vector<MetaTensor*>& x,
                    int axis,
                    MetaTensor* out);

void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
                             std::vector<MetaTensor*> out);

void WarpctcInferMeta(const MetaTensor& logits,
                      const MetaTensor& label,
                      const paddle::optional<const MetaTensor&> logits_length,
...
@@ -323,7 +323,15 @@ def concat(input, axis=0, name=None):
            # [14 15 16]]
    """
-    if _non_static_mode():
+    if in_dygraph_mode():
        if isinstance(axis, Variable):
            axis = axis.numpy()
            axis = axis.item(0)
        if not isinstance(input, Variable):
            input = [t for t in input if t.shape.count(0) == 0]
        return _C_ops.final_state_concat(input, axis)

    if _in_legacy_dygraph():
        if isinstance(axis, Variable):
            axis = axis.numpy()
            axis = axis.item(0)
...
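One behavioural detail in the eager branch above: tensors whose shape contains a 0 (empty tensors) are dropped from the input list before calling final_state_concat. A hedged illustration of just that filtering expression, on plain lists rather than real tensors:

shapes = [[2, 3], [0, 3], [4, 3]]
kept = [s for s in shapes if s.count(0) == 0]
print(kept)  # [[2, 3], [4, 3]] -- the empty [0, 3] input is skipped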
@@ -19,6 +19,7 @@ import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
from paddle.fluid.framework import _test_eager_guard
import paddle
@@ -49,7 +50,7 @@ class TestConcatOp(OpTest):
            place = core.CUDAPlace(0)
            self.check_output_with_place(place)
        else:
-            self.check_output()
+            self.check_output(check_eager=True)

    def test_check_grad(self):
        if self.dtype == np.uint16:
@@ -58,9 +59,9 @@ class TestConcatOp(OpTest):
            self.check_grad_with_place(place, ['x1'], 'Out')
            self.check_grad_with_place(place, ['x2'], 'Out')
        else:
-            self.check_grad(['x0'], 'Out')
-            self.check_grad(['x1'], 'Out')
-            self.check_grad(['x2'], 'Out')
+            self.check_grad(['x0'], 'Out', check_eager=True)
+            self.check_grad(['x1'], 'Out', check_eager=True)
+            self.check_grad(['x2'], 'Out', check_eager=True)
    def init_test_data(self):
        if self.dtype == np.uint16:
@@ -124,6 +125,7 @@ class TestConcatOp6(TestConcatOp):
    def setUp(self):
        self.op_type = "concat"
        self.dtype = self.get_dtype()
        self.python_api = paddle.concat
        self.init_test_data()
        self.lod = [[20, 80]]
        self.out_lod = [[20, 80, 20, 80, 20, 80]]
@@ -141,12 +143,12 @@ class TestConcatOp6(TestConcatOp):
        self.outputs = {'Out': (out, self.out_lod)}

    def test_check_output(self):
-        self.check_output(check_dygraph=False)
+        self.check_output(check_eager=True)

    def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', check_dygraph=False)
-        self.check_grad(['x1'], 'Out', check_dygraph=False)
-        self.check_grad(['x2'], 'Out', check_dygraph=False)
+        self.check_grad(['x0'], 'Out', check_eager=True)
+        self.check_grad(['x1'], 'Out', check_eager=True)
+        self.check_grad(['x2'], 'Out', check_eager=True)
    def init_test_data(self):
        self.x0 = np.random.random([100]).astype(self.dtype)
@@ -159,6 +161,7 @@ def create_test_AxisTensor(parent):
    class TestConcatAxisTensor(parent):
        def setUp(self):
            self.op_type = "concat"
            self.python_api = paddle.concat
            self.dtype = self.get_dtype()
            self.init_test_data()
@@ -334,6 +337,12 @@ class TestConcatAPI(unittest.TestCase):
        self.assertEqual((out1.numpy() == np_out1).all(), True)
        self.assertEqual((out2.numpy() == np_out2).all(), True)

    def test_eager(self):
        with _test_eager_guard():
            self.test_api()
            self.test_fluid_api()
            self.test_imperative()

    def test_errors(self):
        with program_guard(Program(), Program()):
            # The item in input must be Variable.
@@ -370,6 +379,7 @@ class TestConcatAPIWithLoDTensorArray(unittest.TestCase):
    def setUp(self):
        self.axis = 1
        self.python = paddle.concat
        self.iter_num = 3
        self.input_shape = [2, 3]
        self.x = np.random.random(self.input_shape).astype("float32")
...
@@ -320,6 +320,7 @@
    param : [x, axis]
  kernel :
    func : concat
  backward : concat_grad

- api : conj
  args : (Tensor x)
...
@@ -179,6 +179,12 @@
  kernel :
    func : cholesky_solve_grad

- backward_api : concat_grad
  forward : concat (Tensor[] x, Scalar axis) -> Tensor(out)
  args : (Tensor[] x, Tensor out_grad, Scalar axis = 0)
  output : Tensor[](x_grad)
  invoke : concat_grad_impl(x, out_grad, axis)

- backward_api : conv2d_transpose_grad
  forward : conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
  args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
...
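Note on the yaml wiring above: the new backward entry does not bind a phi kernel directly; its invoke field tells the API generator to forward to the hand-written concat_grad_impl added in api_custom_impl.cc, while the backward : concat_grad line in api.yaml links the forward concat op to this entry. A hedged smoke check of the generated forward entry, mirroring the tests above which switch on eager mode with _test_eager_guard:

import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    x = paddle.rand([2, 3])
    y = paddle.rand([3, 3])
    out = _C_ops.final_state_concat([x, y], 0)  # generated from the api.yaml entry
    print(out.shape)  # [5, 3]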