未验证 提交 e53d1837 编写于 作者: H hong 提交者: GitHub

Add expand equal all yaml (#41540)

* add expand, poisson

* add poison grad

* add expand equal_all poisson triangular solve yaml
上级 f48a37ef
......@@ -16,7 +16,11 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#define MAX_RANK_SUPPORTED 6
......@@ -29,70 +33,6 @@ class ExpandV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandV2");
auto x_dims = ctx->GetInputDim("X");
auto expand_shape = ctx->Attrs().Get<std::vector<int>>("shape");
if (expand_shape.size() == 0) {
expand_shape = std::vector<int>(x_dims.size(), -1);
}
PADDLE_ENFORCE_GE(
expand_shape.size(), static_cast<size_t>(x_dims.size()),
platform::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"expand_v2 op must be greater than or equal to the rank "
"(%d) of the input.",
expand_shape.size(), static_cast<size_t>(x_dims.size())));
PADDLE_ENFORCE_LE(expand_shape.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"must not be greater than %d.",
expand_shape.size(), MAX_RANK_SUPPORTED));
PADDLE_ENFORCE_GE(expand_shape.size(), 1,
platform::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"must be a positive integer.",
expand_shape.size()));
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), expand_shape.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = phi::vectorize<int>(x_dims);
auto diff = expand_shape.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (x_dims[i] == -1) {
out_shape[i] = -1;
} else if (expand_shape[i] == -1) {
if (static_cast<size_t>(x_dims.size()) > i) {
out_shape[i] = x_dims[i];
} else {
out_shape[i] = -1;
}
} else if (expand_shape[i] == -2) {
// We use -2 to represent the element in expand_shape is a var.
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
expand_shape[i], 0,
platform::errors::InvalidArgument(
"The %uth element of 'shape' for expand_v2 op must be "
"greater than 0, but the value given is %d.",
i, expand_shape[i]));
out_shape[i] = expand_shape[i];
}
}
ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -291,10 +231,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(expand_v2, ExpandInferShapeFunctor,
PD_INFER_META(phi::ExpandInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_v2, ops::ExpandV2Op, ops::ExpandV2OpMaker,
ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>);
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>,
ExpandInferShapeFunctor);
REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp,
ops::ExpandV2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2DoubleGradOpMaker<paddle::imperative::OpBase>,
......
......@@ -405,6 +405,78 @@ void EighInferMeta(const MetaTensor& x,
out_v->set_dims(input_dim);
}
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out) {
#define MAX_RANK_SUPPORTED 6
auto x_dims = x.dims();
auto expand_shape = shape.GetData();
if (expand_shape.size() == 0) {
expand_shape = std::vector<int64_t>(x_dims.size(), -1);
}
PADDLE_ENFORCE_GE(
expand_shape.size(),
static_cast<size_t>(x_dims.size()),
phi::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"expand_v2 op must be greater than or equal to the rank "
"(%d) of the input.",
expand_shape.size(),
static_cast<size_t>(x_dims.size())));
PADDLE_ENFORCE_LE(
expand_shape.size(),
MAX_RANK_SUPPORTED,
phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for "
"must not be greater than %d.",
expand_shape.size(),
MAX_RANK_SUPPORTED));
PADDLE_ENFORCE_GE(
expand_shape.size(),
1,
phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for "
"must be a positive integer.",
expand_shape.size()));
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), expand_shape.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = phi::vectorize<int>(x_dims);
auto diff = expand_shape.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (x_dims[i] == -1) {
out_shape[i] = -1;
} else if (expand_shape[i] == -1) {
if (static_cast<size_t>(x_dims.size()) > i) {
out_shape[i] = x_dims[i];
} else {
out_shape[i] = -1;
}
} else if (expand_shape[i] == -2) {
// We use -2 to represent the element in expand_shape is a var.
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
expand_shape[i],
0,
phi::errors::InvalidArgument(
"The %uth element of 'shape' for expand_v2 op must be "
"greater than 0, but the value given is %d.",
i,
expand_shape[i]));
out_shape[i] = expand_shape[i];
}
}
out->set_dims(make_ddim(out_shape));
out->set_dtype(x.dtype());
if (out_shape[0] == x_dims[0]) {
out->share_lod(x);
}
}
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
......@@ -85,6 +85,10 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor* out_w,
MetaTensor* out_v);
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
......@@ -20,7 +20,9 @@
namespace phi {
template <typename T, typename Context>
void PoissonGradKernel(const Context& ctx, DenseTensor* x_grad) {
void PoissonGradKernel(const Context& ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(ctx, x_grad, static_cast<T>(0));
......
......@@ -20,6 +20,8 @@
namespace phi {
template <typename T, typename Context>
void PoissonGradKernel(const Context& ctx, DenseTensor* x_grad);
void PoissonGradKernel(const Context& ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
......@@ -18,7 +18,8 @@ namespace phi {
KernelSignature PoissonGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("poisson_grad", {}, {}, {GradVarName("X")});
return KernelSignature(
"poisson_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
}
} // namespace phi
......
......@@ -28,12 +28,13 @@ def create_test_not_equal_class(op_type, typename, callback):
x = np.random.random(size=(10, 7)).astype(typename)
y = np.random.random(size=(10, 7)).astype(typename)
z = callback(x, y)
self.python_api = paddle.tensor.equal_all
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type
def test_output(self):
self.check_output()
self.check_output(check_eager=True)
cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal_all')
Cls.__name__ = cls_name
......@@ -46,12 +47,13 @@ def create_test_not_shape_equal_class(op_type, typename, callback):
x = np.random.random(size=(10, 7)).astype(typename)
y = np.random.random(size=(10)).astype(typename)
z = callback(x, y)
self.python_api = paddle.tensor.equal_all
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type
def test_output(self):
self.check_output()
self.check_output(check_eager=True)
cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_shape_equal_all')
Cls.__name__ = cls_name
......@@ -63,12 +65,13 @@ def create_test_equal_class(op_type, typename, callback):
def setUp(self):
x = y = np.random.random(size=(10, 7)).astype(typename)
z = callback(x, y)
self.python_api = paddle.tensor.equal_all
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type
def test_output(self):
self.check_output()
self.check_output(check_eager=True)
cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
Cls.__name__ = cls_name
......@@ -82,12 +85,13 @@ def create_test_dim1_class(op_type, typename, callback):
x = np.array([True, False, True]).astype(typename)
x = np.array([False, False, True]).astype(typename)
z = callback(x, y)
self.python_api = paddle.tensor.equal_all
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type
def test_output(self):
self.check_output()
self.check_output(check_eager=True)
cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
Cls.__name__ = cls_name
......
......@@ -40,10 +40,10 @@ class TestExpandV2OpRank1(OpTest):
self.expand_times = [1]
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestExpandV2OpRank2_DimExpanding(TestExpandV2OpRank1):
......
......@@ -18,6 +18,7 @@ import numpy as np
from op_test import OpTest
import math
import os
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static()
paddle.seed(100)
......@@ -96,11 +97,18 @@ class TestPoissonAPI(unittest.TestCase):
self.assertTrue(np.min(y_np) >= 0)
def test_dygraph(self):
paddle.disable_static()
x = paddle.randn([10, 10], dtype='float32')
y = paddle.poisson(x)
self.assertTrue(np.min(y.numpy()) >= 0)
paddle.enable_static()
with paddle.fluid.dygraph.base.guard():
x = paddle.randn([10, 10], dtype='float32')
y = paddle.poisson(x)
self.assertTrue(np.min(y.numpy()) >= 0)
with _test_eager_guard():
x = paddle.randn([10, 10], dtype='float32')
x.stop_gradient = False
y = paddle.poisson(x)
y.backward()
self.assertTrue(np.min(y.numpy()) >= 0)
self.assertTrue(np.array_equal(np.zeros_like(x), x.gradient()))
def test_fixed_random_number(self):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
......
......@@ -47,6 +47,7 @@ class TestTriangularSolveOp(OpTest):
def setUp(self):
self.op_type = "triangular_solve"
self.python_api = paddle.tensor.linalg.triangular_solve
self.config()
self.inputs = {
......@@ -62,10 +63,10 @@ class TestTriangularSolveOp(OpTest):
self.outputs = {'Out': self.output}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
# 2D(broadcast) + 3D, test 'transpose'
......
......@@ -2834,6 +2834,10 @@ def triangular_solve(x,
print(out)
# [7, -2, -5]
"""
if in_dygraph_mode():
return _C_ops.final_state_triangular_solve(x, y, upper, transpose,
unitriangular)
if paddle.in_dynamic_mode():
return _C_ops.triangular_solve(x, y, 'upper', upper, 'transpose',
transpose, 'unitriangular',
......
......@@ -301,6 +301,9 @@ def equal_all(x, y, name=None):
result2 = paddle.equal_all(x, z)
print(result2) # result2 = [False ]
"""
if in_dygraph_mode():
return _C_ops.final_state_equal_all(x, y)
if paddle.in_dynamic_mode():
return _C_ops.equal_all(x, y)
......
......@@ -2000,6 +2000,9 @@ def expand(x, shape, name=None):
print(out)
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
return _C_ops.final_state_expand(x, shape)
if paddle.in_dynamic_mode():
return _C_ops.expand_v2(x, 'shape', shape)
......
......@@ -603,6 +603,14 @@
kernel :
func : equal
- api : equal_all
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : CompareAllInferMeta
kernel :
func : equal_all
# erf
- api : erf
args : (Tensor x)
......@@ -633,6 +641,16 @@
func : exp
backward : exp_grad
# expand
- api : expand
args : (Tensor x, IntArray shape)
output : Tensor
infer_meta :
func : ExpandInferMeta
kernel :
func : expand
backward : expand_grad
# expand_as
- api : expand_as
args : (Tensor x, Tensor y, int[] target_shape)
......@@ -1513,7 +1531,7 @@
func : pixel_shuffle
backward : pixel_shuffle_grad
# poisson // no need grad
# poisson
- api : poisson
args : (Tensor x)
output : Tensor
......@@ -1521,6 +1539,7 @@
func : UnchangedInferMeta
kernel :
func : poisson
backward : poisson_grad
- api : pool2d
args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
......@@ -2066,7 +2085,7 @@
func : TriangularSolveInferMeta
kernel :
func : triangular_solve
# backward : triangular_solve_grad
backward : triangular_solve_grad
- api : tril_triu
args : (Tensor x, int diagonal, bool lower)
......
......@@ -492,6 +492,16 @@
func : expand_as_grad
no_need_buffer : x
- backward_api : expand_grad
forward : expand (Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray shape)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : expand_grad
- backward_api : expm1_grad
forward : expm1 (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
......@@ -1159,6 +1169,16 @@
kernel :
func : pixel_shuffle_grad
- backward_api : poisson_grad
forward : poisson (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : poisson_grad
- backward_api : pool2d_grad
forward : pool2d(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
......@@ -1685,6 +1705,16 @@
kernel :
func : transpose_grad
- backward_api : triangular_solve_grad
forward : triangular_solve (Tensor x, Tensor y, bool upper, bool tranpose, bool unitriangular) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool tranpose, bool unitriangular)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : triangular_solve_grad
- backward_api : tril_triu_grad
forward : tril_triu(Tensor x, int diagonal, bool lower) -> Tensor(out)
args : (Tensor out_grad, int diagonal, bool lower)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册