Unverified commit 5fca45ea, authored by wawltor and committed by GitHub

[0 Tensor support] support the 0d tensor for the cumsum (#49518)

* Add 0D tensor support for cumsum

* Check for the 0D tensor case in the XPU and CPU kernels

* Change 2022 to 2023 in the newly added files' copyright headers

* Fix the reverse logic
Parent 1a8be158
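A minimal dygraph sketch of the behavior this commit enables, mirroring the new `test_cumsum` cases added below (variable names are illustrative, not part of the patch):

import paddle

x = paddle.rand([])            # 0D (scalar) tensor
x.stop_gradient = False

out_flat = paddle.cumsum(x)          # flattened: returns a 1-D tensor of shape [1]
out_axis = paddle.cumsum(x, axis=0)  # for a 0D input, axis must be 0 or -1; the output stays 0D

out_axis.backward()
print(out_flat.shape)  # [1]
print(out_axis.shape)  # []
print(x.grad.shape)    # []  (the gradient is also a 0D tensor)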
@@ -33,6 +33,27 @@ class CumOp : public framework::OperatorWithKernel {
   }
 };
 
+class CumGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "cumsum");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input",
+                   "Out@GRAD",
+                   "cumsum");
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
+  }
+};
+
 class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -69,12 +90,13 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
  protected:
   void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("cumsum");
-    grad_op->SetInput("X", this->OutputGrad("Out"));
-    grad_op->SetOutput("Out", this->InputGrad("X"));
+    grad_op->SetType("cumsum_grad");
+    grad_op->SetInput("X", this->Input("X"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
     grad_op->SetAttr("reverse",
-                     !PADDLE_GET_CONST(bool, this->GetAttr("reverse")));
+                     PADDLE_GET_CONST(bool, this->GetAttr("reverse")));
   }
 };
@@ -153,6 +175,7 @@ using CPU = phi::CPUContext;
 DECLARE_INFER_SHAPE_FUNCTOR(cumsum,
                             CumsumInferShapeFunctor,
                             PD_INFER_META(phi::CumScalarAxisInferMeta));
 DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp,
                             LogcumsumexpInferShapeFunctor,
                             PD_INFER_META(phi::CumInferMeta));
@@ -169,6 +192,7 @@ REGISTER_OPERATOR(logcumsumexp,
                   ops::LogcumsumexpGradMaker<paddle::imperative::OpBase>,
                   LogcumsumexpInferShapeFunctor);
 REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp);
+REGISTER_OPERATOR(cumsum_grad, ops::CumGradOp);
 
 REGISTER_OP_VERSION(cumsum).AddCheckpoint(
     R"ROC(
......
@@ -316,9 +316,14 @@
 - backward_op : cumsum_grad
   forward : cumsum(Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
-  args : (Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse)
+  args : (Tensor x, Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse)
   output : Tensor(x_grad)
-  invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse)
+  infer_meta :
+    func : UnchangedInferMeta
+    param: [x]
+  kernel :
+    func : cumsum_grad
+    data_type: x
 
 - backward_op : deformable_conv_grad
   forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out)
......
@@ -424,10 +424,36 @@ void CumInferMeta(const MetaTensor& x,
     out->set_dims(phi::make_ddim({phi::product(x_dims)}));
     out->set_dtype(x.dtype());
   } else {
+    if (x_dims.size() > 0) {
+      PADDLE_ENFORCE_GE(
+          axis,
+          -x_dims.size(),
+          phi::errors::OutOfRange(
+              "axis is out of range (expected to be in range of [%ld, "
+              "%ld), but got %ld).",
+              -(x_dims.size()),
+              x_dims.size(),
+              axis));
+      PADDLE_ENFORCE_LT(
+          axis,
+          x_dims.size(),
+          phi::errors::OutOfRange(
+              "axis is out of range (expected to be in range of [%ld, "
+              "%ld), but got %ld).",
+              -(x_dims.size()),
+              x_dims.size(),
+              axis));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          (axis == 0 || axis == -1),
+          true,
+          errors::InvalidArgument("The axis must be -1 or 0 in 0D Tensor, "
+                                  "but the value given is %d.",
+                                  axis));
+    }
     out->set_dims(x_dims);
     out->set_dtype(x.dtype());
   }
   out->share_lod(x);
 }
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_grad_kernel.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename Context>
void CumsumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* x_grad) {
x_grad->Resize(x.dims());
CumsumKernel<T, Context>(
dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(cumsum_grad,
CPU,
ALL_LAYOUT,
phi::CumsumGradKernel,
float,
double,
int16_t,
int,
int64_t) {}
@@ -57,6 +57,14 @@ void ScanKernel(const Context& dev_ctx,
                 bool reverse,
                 Reducer reducer,
                 DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  if (x.numel() == 1) {
+    auto raw_dims = out->dims();
+    phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    out->Resize(raw_dims);
+    return;
+  }
+
   auto out_dims = out->dims();
 
   PADDLE_ENFORCE_EQ(
@@ -72,8 +80,6 @@ void ScanKernel(const Context& dev_ctx,
     axis += out_dims.size();
   }
 
-  dev_ctx.template Alloc<T>(out);
-
   int pre = 1;
   int post = 1;
   int mid = out_dims[axis];
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void CumsumGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& out_grad,
                      const Scalar& axis,
                      bool flatten,
                      bool exclusive,
                      bool reverse,
                      DenseTensor* x_grad);

}  // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_grad_kernel.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CumsumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* x_grad) {
x_grad->Resize(x.dims());
CumsumKernel<T, Context>(
dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cumsum_grad,
GPU,
ALL_LAYOUT,
phi::CumsumGradKernel,
float,
double,
int16_t,
int,
int64_t) {}
#else
PD_REGISTER_KERNEL(cumsum_grad,
GPU,
ALL_LAYOUT,
phi::CumsumGradKernel,
float,
double,
int16_t,
int,
int64_t,
phi::dtype::float16) {}
#endif
@@ -270,6 +270,16 @@ void ScanKernel(const Context& dev_ctx,
                 bool reverse,
                 Op op,
                 DenseTensor* out) {
+  T* out_data = dev_ctx.template Alloc<T>(out);
+
+  // For 0D Tensor
+  if (out->numel() == 1) {
+    auto raw_dims = out->dims();
+    phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    out->Resize(raw_dims);
+    return;
+  }
+
   auto out_dims = out->dims();
   auto size = x.numel();
@@ -286,7 +296,6 @@ void ScanKernel(const Context& dev_ctx,
     axis += out_dims.size();
   }
 
-  T* out_data = dev_ctx.template Alloc<T>(out);
   const T* in_data = x.data<T>();
 
   // Use thrust for parallel acceleration when the input size is equal to the
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CumsumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* x_grad) {
x_grad->Resize(x.dims());
CumsumKernel<T, Context>(
dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(
cumsum_grad, XPU, ALL_LAYOUT, phi::CumsumGradKernel, float, int, int64_t) {}
@@ -30,6 +30,15 @@ void CumsumKernel(const Context& dev_ctx,
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
 
+  if (x.numel() == 1) {
+    int r = xpu::copy<XPUType>(dev_ctx.x_context(),
+                               reinterpret_cast<const XPUType*>(x.data<T>()),
+                               reinterpret_cast<XPUType*>(out->data<T>()),
+                               x.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    return;
+  }
+
   // prepare for call xdnn api
   std::vector<int> x_shape = phi::vectorize<int>(x.dims());
   int axis_as_int = axis.to<int>();
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CumsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cumsum_grad",
{"X", "Out@GRAD"},
{"axis", "flatten", "exclusive", "reverse"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(cumsum_grad, phi::CumsumOpArgumentMapping);
@@ -16,15 +16,12 @@ import os
 import tempfile
 import unittest
 
-import gradient_checker
 import numpy as np
-from decorator_helper import prog_scope
 from op_test import OpTest
 
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-import paddle.fluid.layers as layers
 import paddle.inference as paddle_infer
@@ -230,7 +227,7 @@ class TestSumOpExclusive1(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
         self.attrs = {'axis': 2, "exclusive": True}
-        a = np.random.random((4, 5, 65)).astype("float64")
+        a = np.random.random((4, 5, 20)).astype("float64")
         self.inputs = {'X': a}
         self.outputs = {
             'Out': np.concatenate(
@@ -245,12 +242,15 @@ class TestSumOpExclusive1(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestSumOpExclusive2(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
         self.attrs = {'axis': 2, "exclusive": True}
-        a = np.random.random((1, 1, 888)).astype("float64")
+        a = np.random.random((1, 1, 100)).astype("float64")
         self.inputs = {'X': a}
         self.outputs = {
             'Out': np.concatenate(
@@ -265,12 +265,15 @@ class TestSumOpExclusive2(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestSumOpExclusive3(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
         self.attrs = {'axis': 2, "exclusive": True}
-        a = np.random.random((4, 5, 888)).astype("float32")
+        a = np.random.random((4, 5, 20)).astype("float64")
         self.inputs = {'X': a}
         self.outputs = {
             'Out': np.concatenate(
@@ -285,12 +288,15 @@ class TestSumOpExclusive3(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestSumOpExclusive4(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
         self.attrs = {'axis': 2, "exclusive": True}
-        a = np.random.random((1, 1, 3049)).astype("float64")
+        a = np.random.random((1, 1, 100)).astype("float64")
         self.inputs = {'X': a}
         self.outputs = {
             'Out': np.concatenate(
@@ -305,12 +311,15 @@ class TestSumOpExclusive4(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestSumOpExclusive5(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
         self.attrs = {'axis': 2, "exclusive": True}
-        a = np.random.random((4, 5, 3096)).astype("float64")
+        a = np.random.random((4, 5, 40)).astype("float64")
         self.inputs = {'X': a}
         self.outputs = {
             'Out': np.concatenate(
@@ -325,12 +334,15 @@ class TestSumOpExclusive5(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestSumOpExclusiveFP16(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
         self.attrs = {'axis': 2, "exclusive": True, "dtype": "float16"}
-        a = np.random.random((4, 5, 3096)).astype("float64")
+        a = np.random.random((4, 5, 20)).astype("float64")
         self.inputs = {'X': a}
         self.outputs = {
             'Out': np.concatenate(
@@ -345,6 +357,9 @@ class TestSumOpExclusiveFP16(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestSumOpReverseExclusive(OpTest):
     def setUp(self):
@@ -366,6 +381,9 @@ class TestSumOpReverseExclusive(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class BadInputTest(unittest.TestCase):
     def test_error(self):
@@ -407,7 +425,6 @@ class TestTensorAxis(unittest.TestCase):
         with paddle.static.program_guard(main_prog, starup_prog):
             # run static
             x = paddle.static.data(shape=np_x.shape, name='x', dtype=np_x.dtype)
-            print(x)
             linear = paddle.nn.Linear(np_x.shape[-1], np_x.shape[-1])
             linear_out = linear(x)
             relu_out = paddle.nn.functional.relu(linear_out)
@@ -444,67 +461,5 @@ class TestTensorAxis(unittest.TestCase):
             np.testing.assert_allclose(static_out[0], infer_out)
 
 
-class TestCumsumDoubleGradCheck(unittest.TestCase):
-    def cumsum_wrapper(self, x):
-        return paddle.cumsum(x[0], 0)
-
-    @prog_scope()
-    def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
-        eps = 0.005
-        dtype = np.float64
-        data = layers.data('data', [3, 4], False, dtype)
-        data.persistable = True
-        out = paddle.cumsum(data, 0)
-        data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
-
-        gradient_checker.double_grad_check(
-            [data], out, x_init=[data_arr], place=place, eps=eps
-        )
-        gradient_checker.double_grad_check_for_dygraph(
-            self.cumsum_wrapper, [data], out, x_init=[data_arr], place=place
-        )
-
-    def test_grad(self):
-        paddle.enable_static()
-        places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda():
-            places.append(fluid.CUDAPlace(0))
-        for p in places:
-            self.func(p)
-
-
-class TestCumsumTripleGradCheck(unittest.TestCase):
-    def cumsum_wrapper(self, x):
-        return paddle.cumsum(x[0], 0)
-
-    @prog_scope()
-    def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
-        eps = 0.005
-        dtype = np.float32
-        data = layers.data('data', [2, 3], False, dtype)
-        data.persistable = True
-        out = paddle.cumsum(data, 0)
-        data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
-
-        gradient_checker.triple_grad_check(
-            [data], out, x_init=[data_arr], place=place, eps=eps
-        )
-        gradient_checker.triple_grad_check_for_dygraph(
-            self.cumsum_wrapper, [data], out, x_init=[data_arr], place=place
-        )
-
-    def test_grad(self):
-        paddle.enable_static()
-        places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda():
-            places.append(fluid.CUDAPlace(0))
-        for p in places:
-            self.func(p)
-
-
 if __name__ == '__main__':
     unittest.main()
@@ -954,6 +954,34 @@ class TestSundryAPI(unittest.TestCase):
         np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
         np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
 
+    def test_cumsum(self):
+        x1 = paddle.rand([])
+        x1.stop_gradient = False
+
+        out1 = paddle.cumsum(x1)
+        out2 = paddle.cumsum(x1, axis=0)
+        out3 = paddle.cumsum(x1, axis=-1)
+
+        out1.retain_grads()
+        out2.retain_grads()
+        out3.retain_grads()
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertTrue(x1.grad.numpy() == 3)
+        self.assertEqual(out1.shape, [1])
+        self.assertEqual(out1.grad.shape, [1])
+        self.assertTrue(out1.grad.numpy() == 1)
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(out2.grad.shape, [])
+        self.assertTrue(out2.grad.numpy() == 1)
+        self.assertEqual(out3.shape, [])
+        self.assertEqual(out3.grad.shape, [])
+        self.assertTrue(out3.grad.numpy() == 1)
+
     def test_add_n(self):
         x1 = paddle.rand([])
         x1.stop_gradient = False
@@ -1674,6 +1702,45 @@ class TestSundryAPIStatic(unittest.TestCase):
         np.testing.assert_array_equal(out3_2, np.asarray(1))
 
     @prog_scope()
+    def test_cumsum(self):
+        x1 = paddle.rand([])
+        x1.stop_gradient = False
+
+        out1 = paddle.cumsum(x1)
+        out2 = paddle.cumsum(x1, axis=0)
+        out3 = paddle.cumsum(x1, axis=-1)
+
+        paddle.static.append_backward(out1.sum())
+        paddle.static.append_backward(out2.sum())
+        paddle.static.append_backward(out3.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                out1,
+                out2,
+                out3,
+                x1.grad_name,
+                out1.grad_name,
+                out2.grad_name,
+                out3.grad_name,
+            ],
+        )
+        self.assertEqual(res[0].shape, (1,))
+        self.assertEqual(res[1].shape, ())
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[3].shape, ())
+        self.assertEqual(res[3], 1)
+        self.assertEqual(res[4].shape, (1,))
+        self.assertEqual(res[4], 1)
+        self.assertEqual(res[5].shape, ())
+        self.assertEqual(res[5], 1)
+        self.assertEqual(res[6].shape, ())
+        self.assertEqual(res[6], 1)
+        self.assertEqual(out2.shape, ())
+        self.assertEqual(out3.shape, ())
+
     def test_add_n(self):
         x1 = paddle.rand([])
         x1.stop_gradient = False
......
@@ -592,6 +592,29 @@ class TestSundryAPI(unittest.TestCase):
         np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
         np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
 
+    def test_cumsum(self):
+        x1 = paddle.rand([])
+        x1.stop_gradient = False
+
+        out1 = paddle.cumsum(x1)
+        out2 = paddle.cumsum(x1, axis=0)
+        out3 = paddle.cumsum(x1, axis=-1)
+
+        out1.retain_grads()
+        out2.retain_grads()
+        out3.retain_grads()
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [1])
+        self.assertEqual(out1.grad.shape, [1])
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(out2.grad.shape, [])
+        self.assertEqual(out3.shape, [])
+        self.assertEqual(out3.grad.shape, [])
+
     def test_add_n(self):
         x1 = paddle.rand([])
         x1.stop_gradient = False
......