From 63203c4abc2d1472c361fe7962f3885f7896559e Mon Sep 17 00:00:00 2001 From: Jack Zhou <136876878@qq.com> Date: Thu, 17 Sep 2020 15:47:11 +0800 Subject: [PATCH] enhance reduce op which can reduce tensor with arbitrary rank enhance reduce op which can reduce tensor with arbitrary rank --- paddle/fluid/operators/math/math_function.cc | 51 ++++ paddle/fluid/operators/math/math_function.cu | 99 +++++++- paddle/fluid/operators/math/math_function.h | 8 + paddle/fluid/operators/reduce_ops/reduce_op.h | 197 ++++++++++++---- paddle/fluid/operators/transpose_op.h | 7 +- paddle/fluid/platform/device_context.cc | 17 ++ paddle/fluid/platform/device_context.h | 7 + .../fluid/tests/unittests/test_reduce_op.py | 219 +++++++++++++++++- .../tests/unittests/test_transpose_op.py | 12 + 9 files changed, 558 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index f44b33fcf2f..b8af5a21ca5 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -22,10 +22,12 @@ limitations under the License. */ #include #endif +#include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/float16.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace operators { @@ -63,6 +65,55 @@ DEFINE_CPU_TRANS(4); DEFINE_CPU_TRANS(5); DEFINE_CPU_TRANS(6); +template +struct TransposeNormal { + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = framework::stride(in.dims()); + auto out_stride = framework::stride(out->dims()); + const T* in_ptr = in.data(); + T* out_ptr = out->data(); + + auto transpose_helper = [&](int64_t beg, int64_t end) { + for (int64_t out_idx = beg; out_idx < end; ++out_idx) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + // calculate the input index + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride[i]; + tmp_idx -= coordinate * out_stride[i]; + in_idx += coordinate * in_stride[axis[i]]; + } + out_ptr[out_idx] = in_ptr[in_idx]; + } + }; + double cost_per_iteration = + rank * (Eigen::TensorOpCost::DivCost() + + 2 * Eigen::TensorOpCost::MulCost() + + 2 * Eigen::TensorOpCost::AddCost()); + Eigen::TensorOpCost cost(sizeof(T), sizeof(T), cost_per_iteration); + auto* cpu_device = context.eigen_pool_device(); + cpu_device->parallelFor(out->numel(), cost, std::move(transpose_helper)); + } +}; + +// define transpose normal +#define DEFINE_CPU_TRANS_NORMAL(TYPE) \ + template struct TransposeNormal + +DEFINE_CPU_TRANS_NORMAL(platform::float16); +DEFINE_CPU_TRANS_NORMAL(platform::bfloat16); +DEFINE_CPU_TRANS_NORMAL(float); +DEFINE_CPU_TRANS_NORMAL(double); +DEFINE_CPU_TRANS_NORMAL(int); +DEFINE_CPU_TRANS_NORMAL(int64_t); +DEFINE_CPU_TRANS_NORMAL(bool); +DEFINE_CPU_TRANS_NORMAL(int16_t); +DEFINE_CPU_TRANS_NORMAL(uint8_t); +DEFINE_CPU_TRANS_NORMAL(int8_t); + struct TensorSetConstantCPU { TensorSetConstantCPU(framework::Tensor* tensor, float value) : tensor_(tensor), value_(value) {} diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 1c519d226eb..4d7c1a49286 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -11,8 +11,11 @@ distributed under the License is distributed on an "AS IS" 
BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function_impl.h" @@ -23,6 +26,7 @@ namespace operators { namespace math { using float16 = paddle::platform::float16; +using bfloat16 = paddle::platform::bfloat16; template struct SetConstant; template struct SetConstant; @@ -31,12 +35,13 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; -#define DEFINE_GPU_TRANS(RANK) \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ template struct Transpose; DEFINE_GPU_TRANS(1); @@ -46,6 +51,88 @@ DEFINE_GPU_TRANS(4); DEFINE_GPU_TRANS(5); DEFINE_GPU_TRANS(6); +#define REINTERPRET(T, DST_PTR, SRC_PTR) \ + T* DST_PTR = reinterpret_cast(SRC_PTR) + +template +__global__ void TransposeNormalKernel(const T* in_ptr, T* out_ptr, + int64_t element, + const int64_t* in_stride_ptr, + const int64_t* out_stride_ptr, + const int64_t* axis_ptr, int rank) { + CUDA_KERNEL_LOOP(out_idx, element) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride_ptr[i]; + tmp_idx -= coordinate * out_stride_ptr[i]; + in_idx += coordinate * in_stride_ptr[axis_ptr[i]]; + } + out_ptr[out_idx] = in_ptr[in_idx]; + } +} + +template +struct TransposeNormal { + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = framework::stride(in.dims()); + auto out_stride = framework::stride(out->dims()); + auto* in_ptr = in.data(); + auto* out_ptr = out->data(); + + // copy in_stride, out_stride, axis to gpu device + const platform::CUDAPlace& cuda_place = + BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); + platform::CPUPlace cpu_place = platform::CPUPlace(); + size_t size = 3 * rank * sizeof(int64_t); + auto cpu_buf_holder = memory::AllocShared(cpu_place, size); + auto cuda_buf_holder = memory::AllocShared(cuda_place, size); + REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr()); + REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr()); + for (int i = 0; i < rank; ++i) { + cpu_buf[i] = in_stride[i]; + cpu_buf[rank + i] = out_stride[i]; + cpu_buf[2 * rank + i] = axis[i]; + } + memory::Copy(cuda_place, cuda_buf, cpu_place, cpu_buf, size, + context.stream()); + REINTERPRET(const int64_t, in_stride_ptr, cuda_buf); + REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank); + REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank); + + const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock(); + const int MAX_GRID_DIM = + context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM; + int64_t elements = in.numel(); + int block_size = (elements >= MAX_BLOCK_DIM) + ? 
MAX_BLOCK_DIM + : (1 << static_cast(std::log2(elements))); + int grid_size = elements / block_size; + grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size; + TransposeNormalKernel<<>>( + in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, + rank); + } +}; + +// define transpose normal +#define DEFINE_GPU_TRANS_NORMAL(TYPE) \ + template struct TransposeNormal + +DEFINE_GPU_TRANS_NORMAL(float16); +DEFINE_GPU_TRANS_NORMAL(bfloat16); +DEFINE_GPU_TRANS_NORMAL(float); +DEFINE_GPU_TRANS_NORMAL(double); +DEFINE_GPU_TRANS_NORMAL(int); +DEFINE_GPU_TRANS_NORMAL(int64_t); +DEFINE_GPU_TRANS_NORMAL(bool); +DEFINE_GPU_TRANS_NORMAL(int16_t); +DEFINE_GPU_TRANS_NORMAL(uint8_t); +DEFINE_GPU_TRANS_NORMAL(int8_t); + struct TensorSetConstantGPU { TensorSetConstantGPU(const platform::DeviceContext& context, framework::Tensor* tensor, float value) diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index 333552a0c1a..6af0278d825 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -26,6 +26,14 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { + +template +struct TransposeNormal { + // for dims >= 7 situation + void operator()(const DeviceContext& context, const framework::Tensor& in, + framework::Tensor* out, const std::vector& axis); +}; + template struct Transpose { void operator()(const DeviceContext& context, const framework::Tensor& in, diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 67a19cb83c3..25f9453571a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -18,9 +18,10 @@ limitations under the License. 
*/ #include #include #include - #include "paddle/fluid/framework/data_type_transform.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/cast_op.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" namespace paddle { @@ -34,6 +35,110 @@ namespace operators { } using Tensor = framework::Tensor; +using DDim = framework::DDim; + +inline void GetShuffledDim(const DDim& src_dims, DDim* dst_dims, + const std::vector& reduced_dims, + std::vector* perm_axis) { + // check if it's a reduced dim + std::vector src_dims_check(src_dims.size(), false); + size_t src_size = src_dims.size(); + size_t reduce_size = reduced_dims.size(); + for (size_t i = 0; i < reduce_size; ++i) { + dst_dims->at(src_size - reduce_size + i) = src_dims[reduced_dims[i]]; + (*perm_axis)[src_size - reduce_size + i] = reduced_dims[i]; + src_dims_check[reduced_dims[i]] = true; + } + + size_t offset = 0; + for (size_t i = 0; i < src_dims_check.size(); ++i) { + bool is_reduced = src_dims_check[i]; + if (!is_reduced) { + (*perm_axis)[offset] = i; + dst_dims->at(offset++) = src_dims[i]; + } + } +} + +template +void GetShuffledInput(const framework::ExecutionContext& context, + const Tensor* input, Tensor* shuffled_input, + const std::vector& dims) { + DDim shuffled_dims(input->dims()); + std::vector perm_axis(input->dims().size()); + GetShuffledDim(input->dims(), &shuffled_dims, dims, &perm_axis); + + shuffled_input->Resize(shuffled_dims); + shuffled_input->mutable_data(context.GetPlace()); + + math::TransposeNormal trans; + trans(context.template device_context(), *input, + shuffled_input, perm_axis); +} + +inline void GetOriginDimFromShuffled(const DDim& src_dim, + const std::vector& dims, + std::vector* origin_dim) { + DDim shuffled_dims(src_dim); + size_t n = src_dim.size(); + std::vector perm_axis(n); + GetShuffledDim(src_dim, &shuffled_dims, dims, &perm_axis); + for (size_t i = 0; i < n; ++i) { + (*origin_dim)[perm_axis[i]] = i; + } +} + +template +void HandleLargeDim(const framework::ExecutionContext& context, + const Tensor* input, Tensor* output, + const std::vector& dims, bool keep_dim) { + // shuffle the reduced dim to the end + Tensor shuffled_input; + GetShuffledInput(context, input, &shuffled_input, dims); + + // transpose to 2D tensor whose shape is {unreduced, reduced}. 
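// Note on this step: GetShuffledDim builds a permutation that keeps the
// non-reduced axes first (in their original order) and moves the reduced axes
// to the tail. For example, src_dims = {2, 3, 4, 5} with reduced_dims = {1, 3}
// yields perm_axis = {0, 2, 1, 3} and shuffled dims {2, 4, 3, 5}; the
// transposed data can then be viewed as an {unreduced = 8, reduced = 15}
// matrix and reduced along axis 1, which works for any input rank because
// only the two flattened extents matter from this point on.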
+ const int64_t unreduced = output->numel(); + const int64_t reduced = shuffled_input.numel() / unreduced; + shuffled_input.Resize({unreduced, reduced}); + DDim output_dim = output->dims(); + output->Resize({unreduced}); + ReduceFunctor( + context.template device_context(), shuffled_input, output, + {1}, keep_dim); + output->Resize(output_dim); +} + +template +void HandleLargeDimGrad(const framework::ExecutionContext& context, + const framework::Tensor* x, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + const std::vector& dims) { + const int64_t unreduced = out->numel(); + const int64_t reduced = x->numel() / unreduced; + DDim out_dim(out->dims()); + DDim x_dim(x->dims()); + // transpose and reshape X + Tensor shuffled_x; + GetShuffledInput(context, x, &shuffled_x, dims); + DDim shuffled_dim = shuffled_x.dims(); + shuffled_x.Resize({unreduced, reduced}); + // reshape dX {unreduced, reduced} + dx->Resize({unreduced, reduced}); + ReduceGradFunctor( + context.template device_context(), shuffled_x, *out, *dout, + dx, {1}); + // transpose dX + std::vector origin_axis(x_dim.size()); + GetOriginDimFromShuffled(x_dim, dims, &origin_axis); + Tensor dx_tmp; + framework::TensorCopy(*dx, context.GetPlace(), &dx_tmp); + dx_tmp.Resize(shuffled_dim); + dx->Resize(x_dim); + math::TransposeNormal trans; + trans(context.template device_context(), dx_tmp, dx, + origin_axis); +} template struct ReduceKernelFunctor { @@ -69,22 +174,27 @@ struct ReduceKernelFunctor { } else { int ndim = input->dims().size(); int rdim = dims.size(); - HANDLE_DIM(6, 5); - HANDLE_DIM(6, 4); - HANDLE_DIM(6, 3); - HANDLE_DIM(6, 2); - HANDLE_DIM(6, 1); - HANDLE_DIM(5, 4); - HANDLE_DIM(5, 3); - HANDLE_DIM(5, 2); - HANDLE_DIM(5, 1); - HANDLE_DIM(4, 3); - HANDLE_DIM(4, 2); - HANDLE_DIM(4, 1); - HANDLE_DIM(3, 2); - HANDLE_DIM(3, 1); - HANDLE_DIM(2, 1); - HANDLE_DIM(1, 1); + if (ndim > 6) { + HandleLargeDim(context, input, output, + dims, keep_dim); + } else { + HANDLE_DIM(6, 5); + HANDLE_DIM(6, 4); + HANDLE_DIM(6, 3); + HANDLE_DIM(6, 2); + HANDLE_DIM(6, 1); + HANDLE_DIM(5, 4); + HANDLE_DIM(5, 3); + HANDLE_DIM(5, 2); + HANDLE_DIM(5, 1); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + HANDLE_DIM(1, 1); + } } } }; @@ -137,7 +247,6 @@ class ReduceKernel : public framework::OpKernel { } } }; - template class BoolReduceKernel : public framework::OpKernel { public: @@ -175,22 +284,27 @@ class BoolReduceKernel : public framework::OpKernel { int ndim = input->dims().size(); int rdim = dims.size(); // comments for accelerating compiling temporarily. 
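// HANDLE_DIM instantiates an Eigen reduction whose rank is a compile-time
// template argument, so the macro table below can only cover ranks 1 through
// 6. That is why this kernel, like ReduceKernelFunctor above, now routes
// inputs with ndim > 6 through HandleLargeDim instead of growing the table.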
- // HANDLE_DIM(6, 5); - // HANDLE_DIM(6, 4); - // HANDLE_DIM(6, 3); - // HANDLE_DIM(6, 2); - // HANDLE_DIM(6, 1); - // HANDLE_DIM(5, 4); - // HANDLE_DIM(5, 3); - // HANDLE_DIM(5, 2); - // HANDLE_DIM(5, 1); - HANDLE_DIM(4, 3); - HANDLE_DIM(4, 2); - HANDLE_DIM(4, 1); - HANDLE_DIM(3, 2); - HANDLE_DIM(3, 1); - HANDLE_DIM(2, 1); - HANDLE_DIM(1, 1); + if (ndim > 6) { + HandleLargeDim(context, input, output, + dims, keep_dim); + } else { + HANDLE_DIM(6, 5); + HANDLE_DIM(6, 4); + HANDLE_DIM(6, 3); + HANDLE_DIM(6, 2); + HANDLE_DIM(6, 1); + HANDLE_DIM(5, 4); + HANDLE_DIM(5, 3); + HANDLE_DIM(5, 2); + HANDLE_DIM(5, 1); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + HANDLE_DIM(1, 1); + } } } }; @@ -279,6 +393,10 @@ class ReduceGradKernel : public framework::OpKernel { context.template device_context(), *input0, *input1, *input2, output, dims); break; + default: + HandleLargeDimGrad(context, input0, input1, + input2, output, dims); + break; } } } @@ -313,12 +431,6 @@ class ReduceOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp"); auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 6, - platform::errors::InvalidArgument( - "The input tensor X's dimensions of ReduceOp " - "should be less equal than 6. But received X's " - "dimensions = %d, X's shape = [%s].", - x_rank, x_dims)); auto dims = ctx->Attrs().Get>("dim"); PADDLE_ENFORCE_GT(dims.size(), 0, platform::errors::InvalidArgument( @@ -402,11 +514,6 @@ class ReduceGradOp : public framework::OperatorWithKernel { "Out@GRAD", "ReduceOp"); auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 6, - platform::errors::InvalidArgument( - "Tensors with rank at most 6 are supported by " - "ReduceOp. Received tensor with rank %d.", - x_rank)); auto dims = ctx->Attrs().Get>("dim"); for (size_t i = 0; i < dims.size(); ++i) { PADDLE_ENFORCE_LT(dims[i], x_rank, diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/fluid/operators/transpose_op.h index d7f5c3dd457..e4e5dfdba9f 100644 --- a/paddle/fluid/operators/transpose_op.h +++ b/paddle/fluid/operators/transpose_op.h @@ -53,10 +53,9 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx, trans6(dev_ctx, in, out, axis); break; default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Tensors with rank at most 6 are supported" - ", but received input tensor's rank is %d,", - dim)); + // for dim >= 7 situation + math::TransposeNormal trans_normal; + trans_normal(dev_ctx, in, out, axis); } } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 29982c13c8c..34305c404b4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -12,6 +12,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device_context.h" #include #include +#include //NOLINT #include #include @@ -23,6 +24,7 @@ limitations under the License. 
*/ #endif #include "glog/logging.h" +#include "unsupported/Eigen/CXX11/ThreadPool" namespace paddle { namespace memory { @@ -131,16 +133,31 @@ DeviceContextPool::DeviceContextPool( CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); + InitPoolDevice(); } CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) { eigen_device_.reset(new Eigen::DefaultDevice()); + InitPoolDevice(); +} + +void CPUDeviceContext::InitPoolDevice() { + using EigenEnv = Eigen::StlThreadEnvironment; + using EigenThreadPool = Eigen::ThreadPoolTempl; + int num_threads = std::thread::hardware_concurrency(); + eigen_threadpool_.reset(new EigenThreadPool(num_threads)); + eigen_pool_device_.reset( + new Eigen::ThreadPoolDevice(eigen_threadpool_.get(), num_threads)); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { return eigen_device_.get(); } +Eigen::ThreadPoolDevice* CPUDeviceContext::eigen_pool_device() const { + return eigen_pool_device_.get(); +} + Place CPUDeviceContext::GetPlace() const { return place_; } #ifdef PADDLE_WITH_XPU diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 8bfdfc8a1c6..28d94627f95 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -41,6 +41,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/stream/cuda_stream.h" #endif +#define EIGEN_USE_THREADS #include "unsupported/Eigen/CXX11/Tensor" #ifdef PADDLE_WITH_XPU @@ -65,11 +66,17 @@ class CPUDeviceContext : public DeviceContext { Eigen::DefaultDevice* eigen_device() const; + Eigen::ThreadPoolDevice* eigen_pool_device() const; + Place GetPlace() const override; + inline void InitPoolDevice(); + private: CPUPlace place_; std::unique_ptr eigen_device_; + std::unique_ptr eigen_pool_device_; + std::unique_ptr eigen_threadpool_; }; template diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index b0b85f633a2..80b201d0842 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -67,6 +67,22 @@ class TestSumOp6D(OpTest): self.check_grad(['X'], 'Out') +class TestSumOp8D(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + 'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float64") + } + self.attrs = {'dim': (0, 3)} + self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") @@ -103,6 +119,40 @@ class TestMinOp(OpTest): self.check_output() +class TestMin6DOp(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.inputs = { + 'X': np.random.random((2, 4, 3, 5, 6, 10)).astype("float64") + } + self.attrs = {'dim': [2, 4]} + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + + +class TestMin8DOp(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.inputs = { + 'X': np.random.random((2, 4, 3, 5, 6, 3, 2, 4)).astype("float64") + } + self.attrs = 
{'dim': [2, 3, 4]} + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + + class TestProdOp(OpTest): def setUp(self): self.op_type = "reduce_prod" @@ -116,6 +166,42 @@ class TestProdOp(OpTest): self.check_grad(['X'], 'Out') +class TestProd6DOp(OpTest): + def setUp(self): + self.op_type = "reduce_prod" + self.inputs = { + 'X': np.random.random((5, 6, 2, 3, 4, 2)).astype("float64") + } + self.attrs = {'dim': [2, 3, 4]} + self.outputs = { + 'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestProd8DOp(OpTest): + def setUp(self): + self.op_type = "reduce_prod" + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64") + } + self.attrs = {'dim': [2, 3, 4]} + self.outputs = { + 'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + class TestAllOp(OpTest): def setUp(self): self.op_type = "reduce_all" @@ -127,12 +213,40 @@ class TestAllOp(OpTest): self.check_output() +class TestAll8DOp(OpTest): + def setUp(self): + self.op_type = "reduce_all" + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.attrs = {'reduce_all': True, 'dim': (2, 3, 4)} + self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} + + def test_check_output(self): + self.check_output() + + class TestAllOpWithDim(OpTest): def setUp(self): self.op_type = "reduce_all" self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} - self.attrs = {'dim': [1]} - self.outputs = {'Out': self.inputs['X'].all(axis=1)} + self.attrs = {'dim': (1, )} + self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} + + def test_check_output(self): + self.check_output() + + +class TestAll8DOpWithDim(OpTest): + def setUp(self): + self.op_type = "reduce_all" + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.attrs = {'dim': (1, 3, 4)} + self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): self.check_output() @@ -152,6 +266,23 @@ class TestAllOpWithKeepDim(OpTest): self.check_output() +class TestAll8DOpWithKeepDim(OpTest): + def setUp(self): + self.op_type = "reduce_all" + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.attrs = {'dim': (5, ), 'keep_dim': True} + self.outputs = { + 'Out': np.expand_dims( + self.inputs['X'].all(axis=self.attrs['dim']), axis=5) + } + + def test_check_output(self): + self.check_output() + + class TestAllOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): @@ -175,6 +306,20 @@ class TestAnyOp(OpTest): self.check_output() +class TestAny8DOp(OpTest): + def setUp(self): + self.op_type = "reduce_any" + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.attrs = {'reduce_all': True, 'dim': (3, 5, 4)} + self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} + + def test_check_output(self): + self.check_output() + + class TestAnyOpWithDim(OpTest): def setUp(self): self.op_type = "reduce_any" @@ -186,14 +331,45 @@ class TestAnyOpWithDim(OpTest): self.check_output() +class TestAny8DOpWithDim(OpTest): + def setUp(self): + 
self.op_type = "reduce_any" + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.attrs = {'dim': (3, 6)} + self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} + + def test_check_output(self): + self.check_output() + + class TestAnyOpWithKeepDim(OpTest): def setUp(self): self.op_type = "reduce_any" self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} - self.attrs = {'dim': [1], 'keep_dim': True} + self.attrs = {'dim': (1, ), 'keep_dim': True} + self.outputs = { + 'Out': np.expand_dims( + self.inputs['X'].any(axis=self.attrs['dim']), axis=1) + } + + def test_check_output(self): + self.check_output() + + +class TestAny8DOpWithKeepDim(OpTest): + def setUp(self): + self.op_type = "reduce_any" + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.attrs = {'dim': (1, ), 'keep_dim': True} self.outputs = { 'Out': np.expand_dims( - self.inputs['X'].any(axis=1), axis=1) + self.inputs['X'].any(axis=self.attrs['dim']), axis=1) } def test_check_output(self): @@ -283,6 +459,18 @@ class Test3DReduce3(Test1DReduce): } +class Test8DReduce0(Test1DReduce): + def setUp(self): + self.op_type = "reduce_sum" + self.attrs = {'dim': (4, 2, 3)} + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64") + } + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) + } + + class TestKeepDimReduce(Test1DReduce): def setUp(self): self.op_type = "reduce_sum" @@ -294,6 +482,19 @@ class TestKeepDimReduce(Test1DReduce): } +class TestKeepDim8DReduce(Test1DReduce): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64") + } + self.attrs = {'dim': (3, 4, 5), 'keep_dim': True} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), + keepdims=self.attrs['keep_dim']) + } + + class TestReduceAll(Test1DReduce): def setUp(self): self.op_type = "reduce_sum" @@ -302,6 +503,16 @@ class TestReduceAll(Test1DReduce): self.outputs = {'Out': self.inputs['X'].sum()} +class TestReduceAll(Test1DReduce): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64") + } + self.attrs = {'reduce_all': True, 'dim': (3, 4, 5)} + self.outputs = {'Out': self.inputs['X'].sum(axis=self.attrs['dim'])} + + @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index d5d1fdc5b20..56333211469 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -99,6 +99,18 @@ class TestCase7(TestTransposeOp): self.axis = (0, 1, 3, 2) +class TestCase8(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 2, 3, 2, 4, 3, 3) + self.axis = (0, 1, 3, 2, 4, 5, 6, 7) + + +class TestCase9(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 2, 3, 2, 4, 3, 3) + self.axis = (6, 1, 3, 5, 0, 2, 4, 7) + + class TestTransposeOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): -- GitLab
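A standalone sketch may help when reading the patch above. The program below reproduces, with plain std::vector instead of framework::Tensor, the two pieces the change combines: the stride-based index remapping that TransposeNormal performs per output element (the same loop appears in the CPU functor and in TransposeNormalKernel), and the HandleLargeDim strategy of permuting the reduced axes to the tail and summing rows of a {unreduced, reduced} two-dimensional view. The helper names Strides, TransposeNormalRef, and ReduceSumRef are local to this example and do not exist in Paddle; it is single-threaded reference arithmetic, not the implementation.

// reduce_sketch.cc -- reference arithmetic for TransposeNormal / HandleLargeDim.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides for a shape, mirroring framework::stride().
static std::vector<int64_t> Strides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> s(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    s[i] = s[i + 1] * dims[i + 1];
  }
  return s;
}

// Rank-agnostic transpose: decode each flat output index with the output
// strides, then re-encode it with the input strides permuted by `axis`.
static std::vector<float> TransposeNormalRef(const std::vector<float>& in,
                                             const std::vector<int64_t>& in_dims,
                                             const std::vector<int>& axis) {
  const int rank = static_cast<int>(axis.size());
  std::vector<int64_t> out_dims(rank);
  for (int i = 0; i < rank; ++i) out_dims[i] = in_dims[axis[i]];
  const auto in_stride = Strides(in_dims);
  const auto out_stride = Strides(out_dims);
  std::vector<float> out(in.size());
  for (int64_t out_idx = 0; out_idx < static_cast<int64_t>(in.size()); ++out_idx) {
    int64_t in_idx = 0, tmp = out_idx;
    for (int i = 0; i < rank; ++i) {
      const int64_t coord = tmp / out_stride[i];
      tmp -= coord * out_stride[i];
      in_idx += coord * in_stride[axis[i]];
    }
    out[out_idx] = in[in_idx];
  }
  return out;
}

// reduce_sum over arbitrary axes via the HandleLargeDim recipe: kept axes
// first, reduced axes last, then sum each row of the {unreduced, reduced} view.
static std::vector<float> ReduceSumRef(const std::vector<float>& in,
                                       const std::vector<int64_t>& in_dims,
                                       const std::vector<int>& reduce_dims) {
  const int rank = static_cast<int>(in_dims.size());
  std::vector<bool> is_reduced(rank, false);
  for (int d : reduce_dims) is_reduced[d] = true;
  std::vector<int> perm;
  for (int i = 0; i < rank; ++i) {
    if (!is_reduced[i]) perm.push_back(i);
  }
  for (int d : reduce_dims) perm.push_back(d);
  const std::vector<float> shuffled = TransposeNormalRef(in, in_dims, perm);
  int64_t reduced = 1;
  for (int d : reduce_dims) reduced *= in_dims[d];
  const int64_t unreduced = static_cast<int64_t>(in.size()) / reduced;
  std::vector<float> out(unreduced, 0.0f);
  for (int64_t r = 0; r < unreduced; ++r) {
    for (int64_t c = 0; c < reduced; ++c) out[r] += shuffled[r * reduced + c];
  }
  return out;
}

int main() {
  // Transpose check: a 2x3 tensor permuted with axis {1, 0}.
  const std::vector<float> x = {0, 1, 2, 3, 4, 5};   // [[0,1,2],[3,4,5]]
  assert((TransposeNormalRef(x, {2, 3}, {1, 0}) ==
          std::vector<float>{0, 3, 1, 4, 2, 5}));    // [[0,3],[1,4],[2,5]]

  // Reduce check: sum a 2x3x4 tensor over axes {0, 2}; the result has shape {3}.
  std::vector<float> y(24);
  for (int i = 0; i < 24; ++i) y[i] = static_cast<float>(i);
  assert((ReduceSumRef(y, {2, 3, 4}, {0, 2}) ==
          std::vector<float>{60, 92, 124}));
  std::cout << "reference checks passed\n";
  return 0;
}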
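One detail of the CUDA path that is easy to miss: the launch configuration clamps block_size to the device's maximum threads per block (or, for small tensors, to the largest power of two not exceeding the element count) and caps grid_size, relying on the grid-stride loop inside CUDA_KERNEL_LOOP to cover any elements beyond grid_size * block_size. The snippet below only reproduces that arithmetic; the two limits are hard-coded stand-ins for what context.GetMaxThreadsPerBlock() and GetMaxPhysicalThreadCount() return at run time.

// launch_config_sketch.cc -- the block/grid arithmetic used by the GPU
// TransposeNormal, with assumed device limits.
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int kMaxBlockDim = 512;   // stand-in for context.GetMaxThreadsPerBlock()
  const int kMaxGridDim = 65535;  // stand-in for max physical threads / block
  const int64_t elements = 1000;  // in.numel()

  // Small tensors get the largest power of two <= elements; otherwise the cap.
  const int block_size =
      (elements >= kMaxBlockDim)
          ? kMaxBlockDim
          : (1 << static_cast<int>(std::log2(elements)));
  int grid_size = static_cast<int>(elements / block_size);
  grid_size = (grid_size >= kMaxGridDim) ? kMaxGridDim : grid_size;

  // 1000 elements -> 512 threads x 1 block; the grid-stride loop inside
  // CUDA_KERNEL_LOOP picks up the remaining 488 elements on a second pass.
  std::cout << "block=" << block_size << " grid=" << grid_size << "\n";
  return 0;
}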