未验证 提交 84b63a26 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add add_n(sum) infermeta and yaml (#41362)

* add add_n infermeta

* forward run success

* add add_n grad yaml
上级 0f165f0b
......@@ -31,6 +31,52 @@ limitations under the License. */
namespace paddle {
namespace experimental {
// TODO(chenweihang): the original sum grad op can support higher-level
// differentiation,
// but if we use this impl, it will not support. We need to be able to reuse
// the autograd API here, which is not yet implemented
// TODO(chenweihang): we should support call generated api in custom api impl
std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad) {
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "add_n_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {});
size_t out_number = x.size();
std::vector<Tensor> x_grad;
auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::Scalar&,
float,
bool,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
for (auto* dense_x_grad_t : dense_x_grad) {
phi::MetaTensor meta_out(dense_x_grad_t);
phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out);
(*kernel_fn)(
*dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t);
}
return x_grad;
}
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set =
......
......@@ -22,6 +22,9 @@ limitations under the License. */
namespace paddle {
namespace experimental {
std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad);
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);
std::vector<Tensor> split_impl(const Tensor& x,
......
......@@ -279,6 +279,78 @@ void AdamwInferMeta(const MetaTensor& param,
master_param_outs);
}
void AddNInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out,
MetaConfig config) {
auto N = x.size();
PADDLE_ENFORCE_GT(
N,
0,
phi::errors::InvalidArgument(
"The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d.",
N));
if (N == 1) {
VLOG(3) << "Warning: SumOp have only one input, may waste memory";
}
phi::DDim in_dim({0});
for (size_t i = 0; i < x.size(); ++i) {
auto x_dim = x[i]->dims();
if (phi::product(x_dim) == 0) {
continue;
}
if (phi::product(in_dim) == 0) {
in_dim = x_dim;
} else {
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(in_dim,
x_dim,
phi::errors::InvalidArgument(
"The input tensor X of SumOp must"
" have same shape. But received X[0]'s shape = "
"[%s], X[%d]'s shape = [%s].",
in_dim,
i,
x_dim));
} else {
PADDLE_ENFORCE_EQ(
in_dim.size(),
x_dim.size(),
phi::errors::InvalidArgument(
"The input tensor X of SumOp must have same "
"dimensions. But received X[0]'s dimensions = %d, X[0]'s "
"shape = "
"[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
in_dim.size(),
in_dim,
i,
x_dim.size(),
i,
x_dim));
// if in_dim or x_dim has -1, not check equal
for (int j = 0; j < x_dim.size(); ++j) {
if (x_dim[j] == -1 || in_dim[j] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
in_dim[j],
x_dim[j],
phi::errors::InvalidArgument(
"The input tensor X of SumOp must have same shape "
"if not -1."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim,
i,
x_dim));
}
}
}
}
out->set_dims(in_dim);
out->share_lod(*x[0]);
}
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
......
......@@ -117,6 +117,10 @@ void AdamwInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
void AddNInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
......
......@@ -25,6 +25,7 @@ from paddle.fluid.op import Operator
from paddle.fluid.tests.unittests.op_test import (
OpTest, convert_float_to_uint16, convert_uint16_to_float)
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard
class TestSumOp(OpTest):
......@@ -347,6 +348,27 @@ class API_Test_Add_n(unittest.TestCase):
self.assertEqual((sum_value.numpy() == expected_result).all(), True)
def test_dygraph_final_state_api(self):
with fluid.dygraph.guard():
with _test_eager_guard():
input0 = paddle.ones(shape=[2, 3], dtype='float32')
input1 = paddle.ones(shape=[2, 3], dtype='float32')
input0.stop_gradient = False
input1.stop_gradient = False
expected_result = np.empty((2, 3))
expected_result.fill(2)
sum_value = paddle.add_n([input0, input1])
self.assertEqual((sum_value.numpy() == expected_result).all(),
True)
expected_grad_result = np.empty((2, 3))
expected_grad_result.fill(1)
sum_value.backward()
self.assertEqual(
(input0.grad.numpy() == expected_grad_result).all(), True)
self.assertEqual(
(input1.grad.numpy() == expected_grad_result).all(), True)
class TestRaiseSumError(unittest.TestCase):
def test_errors(self):
......
......@@ -1068,7 +1068,11 @@ def add_n(inputs, name=None):
# [[8., 10., 12.],
# [14., 16., 18.]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
if isinstance(inputs, Variable):
inputs = [inputs]
return _C_ops.final_state_add_n(inputs)
if _in_legacy_dygraph():
if isinstance(inputs, Variable):
inputs = [inputs]
return _C_ops.sum(inputs, 'use_mkldnn', False)
......
......@@ -63,6 +63,15 @@
backward : add_grad
# no_need_buffer : x, y
- api : add_n
args : (Tensor[] x)
output : Tensor
infer_meta :
func : AddNInferMeta
kernel :
func : add_n
backward : add_n_grad
- api : addmm
args : (Tensor input, Tensor x, Tensor y, float alpha, float beta)
output : Tensor
......
......@@ -41,6 +41,13 @@
func : add_grad
no_need_buffer : x, y
- backward_api : add_n_grad
forward : add_n (Tensor[] x) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad)
output : Tensor[](x_grad)
invoke : add_n_grad_impl(x, out_grad)
no_need_buffer : x
- backward_api : addmm_grad
forward : scatter (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out)
args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha, float beta)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册