From 4e509f461782f2d2fae833b5f5e57566c352ab0f Mon Sep 17 00:00:00 2001 From: hlygit66666 <32728786+hlygit66666@users.noreply.github.com> Date: Fri, 10 Sep 2021 15:41:27 +0800 Subject: [PATCH] add cumprod op (#35185) * add test_cumprod_op * Revert "add test_cumprod_op" This reverts commit c96cf6dff5d09ae7d8cc72c1e8ae4369a153aa19. * recommit * add error message * test input(x) initialize * test use cpu * update test code * add test type * add test case * solve ci problem * add complex case test * add complex case test * fix review problem * fix conflict * fix some docs * change test case * change test case * fix review problems again * fix docs * fix inclusivescan bug --- paddle/fluid/framework/data_type.h | 6 + paddle/fluid/operators/cumprod_op.cc | 102 +++++ paddle/fluid/operators/cumprod_op.cu | 369 ++++++++++++++++++ paddle/fluid/operators/cumprod_op.h | 170 ++++++++ paddle/fluid/operators/math/inclusive_scan.h | 246 ++++++++++++ python/paddle/__init__.py | 2 + .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_cumprod_op.py | 196 ++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 60 +++ 10 files changed, 1154 insertions(+) create mode 100644 paddle/fluid/operators/cumprod_op.cc create mode 100644 paddle/fluid/operators/cumprod_op.cu create mode 100644 paddle/fluid/operators/cumprod_op.h create mode 100644 paddle/fluid/operators/math/inclusive_scan.h create mode 100644 python/paddle/fluid/tests/unittests/test_cumprod_op.py diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index a16f35dc11..72ee126e13 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -37,6 +37,12 @@ struct complex; namespace paddle { namespace framework { +template +struct IsComplex : public std::false_type {}; + +template +struct IsComplex> : public std::true_type {}; + template struct DataTypeTrait {}; diff --git a/paddle/fluid/operators/cumprod_op.cc b/paddle/fluid/operators/cumprod_op.cc new file mode 100644 index 0000000000..bff6673429 --- /dev/null +++ b/paddle/fluid/operators/cumprod_op.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/cumprod_op.h" + +namespace paddle { +namespace operators { + +class CumprodOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cumprod"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cumprod"); + + ctx->ShareDim("X", "Out"); + ctx->ShareLoD("X", "Out"); + } +}; + +class CumprodOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of cumprod op."); + AddOutput("Out", "(Tensor), The output tensor of cumprod op."); + AddAttr( + "dim", + "(int), The dim along which the input tensors will be cumproded"); + AddComment( + R"DOC(Cumprod operator. Return the cumprod results of the input elements along the dim. + For example, if input X is a tensor with rank 1 and N elements, the output will also be a tensor + with rank 1 and N elements, and elements y[i] = x[0] * x[1] * x[2] *...* x[i] (0<=i +class CumprodGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("cumprod_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Out", this->Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +class CumprodGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CumprodGrad"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CumprodGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "framework::GradVarName(\"Out\")", "CumprodGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "framework::GradVarName(\"X\")", "CumprodGrad"); + ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X")); + ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker, + ops::CumprodGradOpMaker, + ops::CumprodGradOpMaker); + +REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp); + +REGISTER_OP_CPU_KERNEL( + cumprod, ops::CumprodOpCPUKernel, ops::CumprodOpCPUKernel, + ops::CumprodOpCPUKernel, ops::CumprodOpCPUKernel, + ops::CumprodOpCPUKernel>, + ops::CumprodOpCPUKernel>); + +REGISTER_OP_CPU_KERNEL( + cumprod_grad, ops::CumprodGradOpCPUKernel, + ops::CumprodGradOpCPUKernel, ops::CumprodGradOpCPUKernel, + ops::CumprodGradOpCPUKernel, + ops::CumprodGradOpCPUKernel>, + ops::CumprodGradOpCPUKernel>); diff --git a/paddle/fluid/operators/cumprod_op.cu b/paddle/fluid/operators/cumprod_op.cu new file mode 100644 index 0000000000..82ed0bd444 --- /dev/null +++ b/paddle/fluid/operators/cumprod_op.cu @@ -0,0 +1,369 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/operators/cumprod_op.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/math/inclusive_scan.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +struct MultiplyFunctor { + HOSTDEVICE T operator()(T a, T b) const { return a * b; } +}; + +template +class CumprodOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Out"); + auto dim = ctx.Attr("dim"); + size_t outer_dim, mid_dim, inner_dim; + GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim); + + const auto *x_data = x->data(); + auto *y_data = y->mutable_data(ctx.GetPlace()); + const auto &dev_ctx = + ctx.template device_context(); + math::InclusiveScan>( + x_data, y_data, outer_dim, mid_dim, inner_dim, static_cast(1), + MultiplyFunctor(), /*reverse=*/false, dev_ctx); + } +}; + +template +struct IsZeroFunctor { + HOSTDEVICE bool operator()(T x) const { return x == static_cast(0); } +}; + +template +struct CumprodGradFunctorExceptFirstZero { + HOSTDEVICE CumprodGradFunctorExceptFirstZero( + const T *x, const T *y, const T *dy_mul_y_reversed_cumsum, + const uint8_t *zero_mask, size_t mid_dim, size_t inner_dim, T *dx, + int64_t *first_zero_idx, T *x_filled_one) + : x_(x), + y_(y), + dy_mul_y_reversed_cumsum_(dy_mul_y_reversed_cumsum), + zero_mask_(zero_mask), + mid_dim_(mid_dim), + inner_dim_(inner_dim), + dx_(dx), + first_zero_idx_(first_zero_idx), + x_filled_one_(x_filled_one) {} + + HOSTDEVICE void operator()(size_t idx) const { + auto inner_idx = idx % inner_dim_; + auto outer_idx = idx / (mid_dim_ * inner_dim_); + auto mid_idx = (idx - inner_idx) / inner_dim_ % mid_dim_; + auto mask = zero_mask_[idx]; + bool should_fill_one = true; + + if (mask == 0) { + dx_[idx] = dy_mul_y_reversed_cumsum_[idx] / x_[idx]; + if (mid_idx == mid_dim_ - 1) { + // record first zero position as -1, i.e., no zero + first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = -1; + } + } else if (mid_idx > 0) { // mask > 0 + if (zero_mask_[idx - inner_dim_] > 0) { // not first zero + dx_[idx] = 0; + should_fill_one = false; + } else { + // idx is the first zero position, it should be recorded + dx_[idx] = y_[idx - inner_dim_]; + first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = mid_idx; + } + } else { // the first zero position is index 0 + dx_[idx] = 1; + first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0; + } + + x_filled_one_[idx] = should_fill_one ? 
1 : x_[idx]; + } + + private: + const T *x_; + const T *y_; + const T *dy_mul_y_reversed_cumsum_; + const uint8_t *zero_mask_; + size_t mid_dim_; + size_t inner_dim_; + T *dx_; + int64_t *first_zero_idx_; + T *x_filled_one_; +}; + +template +struct FillFirstZeroPositionGradFunctor { + HOSTDEVICE FillFirstZeroPositionGradFunctor(const int64_t *first_zero_idx, + const T *grad_value, + size_t mid_dim, size_t inner_dim, + T *dx) + : first_zero_idx_(first_zero_idx), + grad_value_(grad_value), + mid_dim_(mid_dim), + inner_dim_(inner_dim), + dx_(dx) {} + + HOSTDEVICE void operator()(size_t idx) const { + auto outer_idx = idx / inner_dim_; + auto inner_idx = idx % inner_dim_; + auto mid_idx = first_zero_idx_[idx]; + if (mid_idx >= 0) { + auto full_idx = + outer_idx * mid_dim_ * inner_dim_ + mid_idx * inner_dim_ + inner_idx; + dx_[full_idx] *= grad_value_[full_idx]; + } + } + + private: + const int64_t *first_zero_idx_; + const T *grad_value_; + size_t mid_dim_; + size_t inner_dim_; + T *dx_; +}; + +/* +Reference to +https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp +input: x, y, dL/dy +output: dL/dx +dL/dx[i] = sum{0<=j k, dL/dx[i] = 0; +i < k, dL/dx[i] = 1/x[i]*sum{i<=j k + dx[i] = 0; + x_filled_one[i] = x[i]; + } + } + } +} +T = reversed_cumsum(dy[j]*cumprod(x_filled_one[j])); +if (zero_index != -1) { + dx[zero_index] *= T[zero_index]; +} +*/ + +template +class CumprodGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + const auto *y = ctx.Input("Out"); + const auto *dy = + ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto dim = ctx.Attr("dim"); + + size_t outer_dim, mid_dim, inner_dim; + GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim); + if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return; + + size_t numel = outer_dim * mid_dim * inner_dim; + + const auto *x_data = x->data(); + const auto *y_data = y->data(); + const auto *dy_data = dy->data(); + + auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + const auto &dev_ctx = + ctx.template device_context(); + auto *dx_data = dx->mutable_data(place); + + // deal with complex + const T *x_data_deal; + const T *y_data_deal; + memory::AllocationPtr x_conj; + memory::AllocationPtr y_conj; + if (framework::IsComplex::value) { + x_conj = memory::Alloc(place, numel * sizeof(T)); + auto *x_data_conj = reinterpret_cast(x_conj->ptr()); + y_conj = memory::Alloc(place, numel * sizeof(T)); + auto *y_data_conj = reinterpret_cast(y_conj->ptr()); + + platform::ForRange for_range_x(dev_ctx, + numel); + math::ConjFunctor functor_x(x_data, numel, x_data_conj); + for_range_x(functor_x); + + platform::ForRange for_range_y(dev_ctx, + numel); + math::ConjFunctor functor_y(y_data, numel, y_data_conj); + for_range_y(functor_y); + x_data_deal = x_data_conj; + y_data_deal = y_data_conj; + } else { + x_data_deal = x_data; + y_data_deal = y_data; + } + +// Step 1: find cummax-ed zero mask of x +#ifdef PADDLE_WITH_CUDA + const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream()); +#else + const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); +#endif + auto zero_mask_without_cummax = + memory::Alloc(place, numel * sizeof(uint8_t)); + auto *zero_mask_without_cummax_data = + reinterpret_cast(zero_mask_without_cummax->ptr()); + thrust::transform( + exec_policy, thrust::device_pointer_cast(x_data_deal), + 
thrust::device_pointer_cast(x_data_deal) + numel, + thrust::device_pointer_cast(zero_mask_without_cummax_data), + IsZeroFunctor()); + + auto zero_mask = memory::Alloc(place, numel * sizeof(uint8_t)); + auto *zero_mask_data = reinterpret_cast(zero_mask->ptr()); + math::InclusiveScan( + zero_mask_without_cummax_data, zero_mask_data, outer_dim, mid_dim, + inner_dim, static_cast(0), cub::Max(), /*reverse=*/false, + dev_ctx); + zero_mask_without_cummax = nullptr; + + // Step 2: calculate reversed cumsum(dy * y) + auto dy_mul_y = memory::Alloc(place, numel * sizeof(T)); + auto *dy_mul_y_data = reinterpret_cast(dy_mul_y->ptr()); + thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data), + thrust::device_pointer_cast(dy_data) + numel, + thrust::device_pointer_cast(y_data_deal), + thrust::device_pointer_cast(dy_mul_y_data), + MultiplyFunctor()); + + auto dy_mul_y_reversed_cumsum = memory::Alloc(place, numel * sizeof(T)); + auto *dy_mul_y_reversed_cumsum_data = + reinterpret_cast(dy_mul_y_reversed_cumsum->ptr()); + math::InclusiveScan( + dy_mul_y_data, dy_mul_y_reversed_cumsum_data, outer_dim, mid_dim, + inner_dim, static_cast(0), cub::Sum(), /*reverse=*/true, dev_ctx); + + // Step 3: calculate the gradient value except the first zero position. + // The gradient value of the first zero position is filled with out[idx-1], + // while the gradient value of the other positions are calculated out + // completely. This functor also: + // (1) find the first zero index, i.e., first_zero_idx_data. + // (2) fill x_filled_one, which satifies + // x_filled_one[i] = x[i], i > pos + // x_filled_one[i] = 1, i <= pos + auto first_zero_idx = + memory::Alloc(place, outer_dim * inner_dim * sizeof(int64_t)); + auto *first_zero_idx_data = + reinterpret_cast(first_zero_idx->ptr()); + auto *x_filled_one_data = dy_mul_y_data; // reuse former allocated memory + platform::ForRange for_range(dev_ctx, numel); + CumprodGradFunctorExceptFirstZero functor_except_first_zero( + x_data_deal, y_data_deal, dy_mul_y_reversed_cumsum_data, zero_mask_data, + mid_dim, inner_dim, dx_data, first_zero_idx_data, x_filled_one_data); + for_range(functor_except_first_zero); + + // Step 4: calculate cumprod of x_filled_one + auto *x_filled_one_cumprod_data = + dy_mul_y_reversed_cumsum_data; // reuse former allocated memory + math::InclusiveScan>( + x_filled_one_data, x_filled_one_cumprod_data, outer_dim, mid_dim, + inner_dim, static_cast(1), MultiplyFunctor(), /*reverse=*/false, + dev_ctx); + + // Step 5: calculate reversed cumsum(dy * x_filled_one_cumprod) + auto *dy_mul_x_filled_one_cumprod = + dy_mul_y_data; // reuse former allocated memory + thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data), + thrust::device_pointer_cast(dy_data) + numel, + thrust::device_pointer_cast(x_filled_one_cumprod_data), + thrust::device_pointer_cast(dy_mul_x_filled_one_cumprod), + MultiplyFunctor()); + auto *dy_mul_x_filled_one_cumprod_reversed_cumsum = + dy_mul_y_reversed_cumsum_data; // reuse former allocated memory + math::InclusiveScan( + dy_mul_x_filled_one_cumprod, + dy_mul_x_filled_one_cumprod_reversed_cumsum, outer_dim, mid_dim, + inner_dim, static_cast(0), cub::Sum(), + /*reverse=*/true, dev_ctx); + + // Step 6: fill zero pos gradient value + platform::ForRange + for_range_fill_zero_pos_grad(dev_ctx, outer_dim * inner_dim); + FillFirstZeroPositionGradFunctor fill_first_zero_pos_grad_functor( + first_zero_idx_data, dy_mul_x_filled_one_cumprod_reversed_cumsum, + mid_dim, inner_dim, dx_data); + 
for_range_fill_zero_pos_grad(fill_first_zero_pos_grad_functor); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + cumprod, ops::CumprodOpCUDAKernel, ops::CumprodOpCUDAKernel, + ops::CumprodOpCUDAKernel, ops::CumprodOpCUDAKernel, + ops::CumprodOpCUDAKernel>, + ops::CumprodOpCUDAKernel>); + +REGISTER_OP_CUDA_KERNEL( + cumprod_grad, ops::CumprodGradOpCUDAKernel, + ops::CumprodGradOpCUDAKernel, ops::CumprodGradOpCUDAKernel, + ops::CumprodGradOpCUDAKernel, + ops::CumprodGradOpCUDAKernel>, + ops::CumprodGradOpCUDAKernel>); diff --git a/paddle/fluid/operators/cumprod_op.h b/paddle/fluid/operators/cumprod_op.h new file mode 100644 index 0000000000..a964cfb3d7 --- /dev/null +++ b/paddle/fluid/operators/cumprod_op.h @@ -0,0 +1,170 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +static void GetCumprodDimInfo(const framework::DDim& dim, int cumprod_dim, + size_t* outer_dim, size_t* mid_dim, + size_t* inner_dim) { + PADDLE_ENFORCE_GE( + cumprod_dim, -dim.size(), + platform::errors::InvalidArgument( + "The input dim of CumprodOp should be larger than the opposite " + "rank of input x which is %d.But received dim=%d", + -dim.size(), cumprod_dim)); + PADDLE_ENFORCE_LT(cumprod_dim, dim.size(), + platform::errors::InvalidArgument( + "The input dim of CumprodOp should be smaller than the " + "rank of input x which is %d.But received dim=%d", + dim.size(), cumprod_dim)); + if (cumprod_dim < 0) cumprod_dim += dim.size(); + + *outer_dim = 1; + for (int i = 0; i < cumprod_dim; ++i) { + *outer_dim *= dim[i]; + } + *mid_dim = dim[cumprod_dim]; + *inner_dim = 1; + for (int i = cumprod_dim + 1; i < dim.size(); ++i) { + *inner_dim *= dim[i]; + } +} + +template +class CumprodOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + int dim = context.Attr("dim"); + + auto* x_data = x->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + framework::DDim shape = x->dims(); + + size_t outer_dim = 1; + size_t mid_dim = 1; + size_t inner_dim = 1; + GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim); + + for (size_t i = 0; i < outer_dim; i++) { + for (size_t j = 0; j < mid_dim; j++) { + for (size_t k = 0; k < inner_dim; k++) { + size_t pos = i * mid_dim * inner_dim + j * inner_dim + k; + if (j == 0) { + out_data[pos] = x_data[pos]; + } else { + out_data[pos] = out_data[pos - inner_dim] * x_data[pos]; + } + } + } + } + } +}; + +template +class CumprodGradOpCPUKernel : public 
framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const { + const Tensor* d_out = context.Input(framework::GradVarName("Out")); + const Tensor* x = context.Input("X"); + const Tensor* out = context.Input("Out"); + + int dim = context.Attr("dim"); + framework::DDim shape = x->dims(); + Tensor* d_x = context.Output(framework::GradVarName("X")); + + auto* d_out_data = d_out->data(); + auto* x_data = x->data(); + auto* out_data = out->data(); + auto* d_x_data = d_x->mutable_data(context.GetPlace()); + + auto place = BOOST_GET_CONST(platform::CPUPlace, context.GetPlace()); + const auto& dev_ctx = + context.template device_context(); + + size_t outer_dim = 1; + size_t mid_dim = 1; + size_t inner_dim = 1; + GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim); + size_t numel = outer_dim * mid_dim * inner_dim; + + // deal with complex + const T* x_data_deal; + const T* out_data_deal; + memory::AllocationPtr x_conj; + memory::AllocationPtr out_conj; + if (framework::IsComplex::value) { + x_conj = memory::Alloc(place, numel * sizeof(T)); + auto* x_data_conj = reinterpret_cast(x_conj->ptr()); + out_conj = memory::Alloc(place, numel * sizeof(T)); + auto* out_data_conj = reinterpret_cast(out_conj->ptr()); + + platform::ForRange for_range_x(dev_ctx, + numel); + math::ConjFunctor functor_x(x_data, numel, x_data_conj); + for_range_x(functor_x); + + platform::ForRange for_range_out(dev_ctx, + numel); + math::ConjFunctor functor_out(out_data, numel, out_data_conj); + for_range_out(functor_out); + + x_data_deal = x_data_conj; + out_data_deal = out_data_conj; + } else { + x_data_deal = x_data; + out_data_deal = out_data; + } + + for (size_t i = 0; i < outer_dim; i++) { + for (size_t k = 0; k < inner_dim; k++) { + for (size_t j = 0; j < mid_dim; j++) { + size_t index = i * mid_dim * inner_dim + j * inner_dim + k; + d_x_data[index] = 0; + for (size_t n = 0; n < mid_dim; n++) { + size_t pos = i * mid_dim * inner_dim + n * inner_dim + k; + T elem; + if (j == 0) { + elem = d_out_data[pos]; + } else { + elem = d_out_data[pos] * out_data_deal[index - inner_dim]; + } + if (pos > index) { + for (size_t m = index + inner_dim; m <= pos; m += inner_dim) { + elem *= x_data_deal[m]; + } + } else if (pos < index) { + elem = static_cast(0); + } + d_x_data[index] += elem; + } + } + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/inclusive_scan.h b/paddle/fluid/operators/math/inclusive_scan.h new file mode 100644 index 0000000000..71080bf424 --- /dev/null +++ b/paddle/fluid/operators/math/inclusive_scan.h @@ -0,0 +1,246 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +namespace math { + +template +static void CubInclusiveScan(InputIterator x_iter, OutputIterator y_iter, + size_t n, BinaryOp op, + const platform::CUDADeviceContext &dev_ctx) { + memory::AllocationPtr allocation; + void *temp_storage = nullptr; + size_t temp_storage_bytes = 0; + for (size_t i = 0; i < 2; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceScan::InclusiveScan( + temp_storage, temp_storage_bytes, x_iter, y_iter, op, + static_cast(n), // Maybe overflow? + dev_ctx.stream())); + if (i == 0 && temp_storage_bytes > 0) { + allocation = memory::Alloc(dev_ctx.GetPlace(), temp_storage_bytes); + temp_storage = allocation->ptr(); + } + } +} + +template +static auto MakeThrustReverseIterator(T *x) { + return thrust::reverse_iterator>( + thrust::device_pointer_cast(x)); +} + +template +struct InclusiveScanOuterOrMidDimFunctor { + HOSTDEVICE InclusiveScanOuterOrMidDimFunctor(const T *x, T *y, size_t mid_dim, + size_t inner_dim, T init, + BinaryOp op) + : x_(x), + y_(y), + mid_dim_(mid_dim), + inner_dim_(inner_dim), + init_(init), + op_(op) {} + + HOSTDEVICE void operator()(size_t idx) const { + auto outer_idx = idx / inner_dim_; + auto inner_idx = idx % inner_dim_; + if (kReverse) { + idx = outer_idx * mid_dim_ * inner_dim_ + (mid_dim_ - 1) * inner_dim_ + + inner_idx; + } else { + idx = outer_idx * mid_dim_ * inner_dim_ + inner_idx; + } + + auto x_ptr = x_ + idx; + auto y_ptr = y_ + idx; + T acc_value = init_; + for (size_t i = 0; i < mid_dim_; ++i) { + acc_value = op_(acc_value, *x_ptr); + *y_ptr = acc_value; + if (kReverse) { + x_ptr -= inner_dim_; + y_ptr -= inner_dim_; + } else { + x_ptr += inner_dim_; + y_ptr += inner_dim_; + } + } + } + + private: + const T *x_; + T *y_; + size_t mid_dim_; + size_t inner_dim_; + T init_; + BinaryOp op_; +}; + +// Reference to +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp + +template +static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y, + size_t num_rows, + size_t row_size, T init, + BinaryOp op) { + using RealT = math::Real; + constexpr auto kSharedBufferSize = + framework::IsComplex::value ? 
4 * kThreadNumX : 2 * kThreadNumX; + __shared__ RealT sbuf[kThreadNumY][kSharedBufferSize]; + T *row_buf = reinterpret_cast(sbuf[threadIdx.y]); + + size_t block_row = static_cast(blockIdx.x * kThreadNumY); + size_t block_row_stride = static_cast(gridDim.x * kThreadNumY); + for (; block_row < num_rows; block_row += block_row_stride) { + size_t row = block_row + threadIdx.y; + T block_total = init; + + const T *row_x = x + row * row_size; + T *row_y = y + row * row_size; + for (size_t block_col = 0; block_col < row_size; + block_col += 2 * kThreadNumX) { + size_t col1, col2; + if (kReverse) { + col1 = row_size - 1 - block_col - threadIdx.x; + col2 = col1 - kThreadNumX; + } else { + col1 = block_col + threadIdx.x; + col2 = col1 + kThreadNumX; + } + + if (row < num_rows) { + if (col1 < row_size) { + row_buf[threadIdx.x] = row_x[col1]; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[kThreadNumX + threadIdx.x] = row_x[col2]; + } else { + row_buf[kThreadNumX + threadIdx.x] = init; + } + + if (threadIdx.x == 0) { + row_buf[0] = op(row_buf[0], block_total); + } + } + __syncthreads(); + + for (size_t s = kThreadNumX, d = 1; s >= 1; s >>= 1, d <<= 1) { + if (row < num_rows && threadIdx.x < s) { + size_t offset = (2 * threadIdx.x + 1) * d - 1; + row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]); + } + __syncthreads(); + } + + for (size_t s = 2, d = kThreadNumX / 2; d >= 1; s <<= 1, d >>= 1) { + if (row < num_rows && threadIdx.x < s - 1) { + size_t offset = 2 * (threadIdx.x + 1) * d - 1; + row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]); + } + __syncthreads(); + } + + if (row < num_rows) { + if (col1 < row_size) row_y[col1] = row_buf[threadIdx.x]; + if (col2 < row_size) row_y[col2] = row_buf[kThreadNumX + threadIdx.x]; + } + block_total = row_buf[2 * kThreadNumX - 1]; + __syncthreads(); + } + } +} + +template +static void InclusiveScanInnerDim(const T *x, T *y, size_t outer_dim, + size_t inner_dim, T init, BinaryOp op, + bool reverse, + const platform::CUDADeviceContext &dev_ctx) { + constexpr size_t kThreadNumX = 16; + constexpr size_t kThreadNumY = 32; + + size_t grid_dim = (outer_dim + kThreadNumY - 1) / kThreadNumY; + grid_dim = std::min(grid_dim, dev_ctx.GetCUDAMaxGridDimSize().x); + dim3 thread_dims(kThreadNumX, kThreadNumY); + if (reverse) { + InclusiveScanInnerDimCUDAKernel< + T, BinaryOp, kThreadNumX, kThreadNumY, + /*kReverse=*/true><<>>( + x, y, outer_dim, inner_dim, init, op); + } else { + InclusiveScanInnerDimCUDAKernel< + T, BinaryOp, kThreadNumX, kThreadNumY, + /*kReverse=*/false><<>>( + x, y, outer_dim, inner_dim, init, op); + } +} + +template +void InclusiveScan(const T *x, T *y, size_t outer_dim, size_t mid_dim, + size_t inner_dim, T init, BinaryOp op, bool reverse, + const platform::CUDADeviceContext &dev_ctx) { + if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return; + + if (outer_dim == 1 && inner_dim == 1) { + if (reverse) { + auto x_reverse_iter = MakeThrustReverseIterator(x + mid_dim); + auto y_reverse_iter = MakeThrustReverseIterator(y + mid_dim); + CubInclusiveScan(x_reverse_iter, y_reverse_iter, mid_dim, op, dev_ctx); + } else { + CubInclusiveScan(x, y, mid_dim, op, dev_ctx); + } + } else if (inner_dim != 1) { + platform::ForRange for_range( + dev_ctx, outer_dim * inner_dim); + if (reverse) { + for_range( + InclusiveScanOuterOrMidDimFunctor( + x, y, mid_dim, inner_dim, init, op)); + } else { + for_range( + InclusiveScanOuterOrMidDimFunctor( + x, y, mid_dim, inner_dim, init, op)); + } + } else { + 
InclusiveScanInnerDim(x, y, outer_dim, mid_dim, init, op, + reverse, dev_ctx); + } +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 24a7a666fb..9d60a5b381 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -163,6 +163,7 @@ from .tensor.math import cos # noqa: F401 from .tensor.math import tan # noqa: F401 from .tensor.math import cosh # noqa: F401 from .tensor.math import cumsum # noqa: F401 +from .tensor.math import cumprod # noqa: F401 from .tensor.math import exp # noqa: F401 from .tensor.math import expm1 # noqa: F401 from .tensor.math import floor # noqa: F401 @@ -330,6 +331,7 @@ __all__ = [ # noqa 'empty_like', 'eye', 'cumsum', + 'cumprod', 'sign', 'is_empty', 'equal', diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index bb250e32c0..c2878efcad 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -947,6 +947,7 @@ set_tests_properties(test_mean_op PROPERTIES TIMEOUT 120) set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) +set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_cumprod_op.py b/python/paddle/fluid/tests/unittests/test_cumprod_op.py new file mode 100644 index 0000000000..31e7ee287f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cumprod_op.py @@ -0,0 +1,196 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest +import random +import paddle + +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import compiler, Program, program_guard + +np.random.seed(0) + + +# define cumprod grad function. 
+def cumprod_grad(x, y, dy, dx, shape, dim): + if dim < 0: + dim += len(shape) + mid_dim = shape[dim] + outer_dim = 1 + inner_dim = 1 + for i in range(0, dim): + outer_dim *= shape[i] + for i in range(dim + 1, len(shape)): + inner_dim *= shape[i] + for i in range(outer_dim): + for k in range(inner_dim): + for j in range(mid_dim): + index = i * mid_dim * inner_dim + j * inner_dim + k + for n in range(mid_dim): + pos = i * mid_dim * inner_dim + n * inner_dim + k + elem = 0 + if j == 0: + elem = dy[pos] + else: + elem = dy[pos] * y[index - inner_dim] + if pos > index: + for m in range(index + inner_dim, pos + inner_dim, + inner_dim): + elem *= x[m] + elif pos < index: + elem = 0 + dx[index] += elem + + +# test function. +class TestCumprod(OpTest): + def init_params(self): + self.shape = (2, 3, 4, 5) + self.zero_nums = [0, 10, 20, 30, int(np.prod(self.shape))] + + def init_dtype(self): + self.dtype = np.float64 + + def setUp(self): + paddle.enable_static() + self.init_params() + self.init_dtype() + self.op_type = "cumprod" + self.inputs = {'X': None} + self.outputs = {'Out': None} + self.attrs = {'dim': None} + + def prepare_inputs_outputs_attrs(self, dim, zero_num): + self.x = np.random.random(self.shape).astype(self.dtype) + 0.5 + if zero_num > 0: + zero_num = min(zero_num, self.x.size) + shape = self.x.shape + self.x = self.x.flatten() + indices = random.sample(range(self.x.size), zero_num) + for i in indices: + self.x[i] = 0 + self.x = np.reshape(self.x, self.shape) + self.out = np.cumprod(self.x, axis=dim) + self.inputs = {'X': self.x} + self.outputs = {'Out': self.out} + self.attrs = {'dim': dim} + + def init_grad_input_output(self, dim): + reshape_x = self.x.reshape(self.x.size) + self.grad_out = np.ones(self.x.size, self.dtype) + self.grad_x = np.zeros(self.x.size, self.dtype) + out_data = self.out.reshape(self.x.size) + if self.dtype == np.complex128 or self.dtype == np.complex64: + reshape_x = np.conj(reshape_x) + out_data = np.conj(out_data) + cumprod_grad(reshape_x, out_data, self.grad_out, self.grad_x, + self.shape, dim) + self.grad_x = self.grad_x.reshape(self.shape) + self.grad_out = self.grad_out.reshape(self.shape) + + # test forward. + def test_check_output(self): + for dim in range(-len(self.shape), len(self.shape)): + for zero_num in self.zero_nums: + self.prepare_inputs_outputs_attrs(dim, zero_num) + self.check_output() + + # test backward. + def test_check_grad(self): + for dim in range(-len(self.shape), len(self.shape)): + for zero_num in self.zero_nums: + self.prepare_inputs_outputs_attrs(dim, zero_num) + self.init_grad_input_output(dim) + if self.dtype == np.float64: + self.check_grad(['X'], 'Out') + else: + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + +# test float32 case. +class TestCumprod_float32(TestCumprod): + def init_dtype(self): + self.dtype = np.float32 + + +# test complex64 case. +class TestCumprod_complex64(TestCumprod): + def init_dtype(self): + self.dtype = np.complex64 + + +# test complex128 case. +class TestCumprod_complex128(TestCumprod): + def init_dtype(self): + self.dtype = np.complex128 + + +# test api. 
+class TestCumprodAPI(unittest.TestCase): + def init_dtype(self): + self.dtype = 'float64' + self.shape = [2, 3, 10, 10] + + def setUp(self): + paddle.enable_static() + self.init_dtype() + self.x = (np.random.rand(2, 3, 10, 10) + 0.5).astype(self.dtype) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + # test static graph api. + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.shape, dtype=self.dtype) + out = paddle.cumprod(x, -2) + exe = paddle.static.Executor(place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + out_ref = np.cumprod(self.x, -2) + + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + for place in self.place: + run(place) + + # test dynamic graph api. + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + out = paddle.cumprod(x, 1) + out_ref = np.cumprod(self.x, 1) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index a67b015f8f..b9e0c75a60 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -108,6 +108,7 @@ from .math import cos # noqa: F401 from .math import tan # noqa: F401 from .math import cosh # noqa: F401 from .math import cumsum # noqa: F401 +from .math import cumprod # noqa: F401 from .math import exp # noqa: F401 from .math import exp_ # noqa: F401 from .math import expm1 # noqa: F401 @@ -236,6 +237,7 @@ tensor_method_func = [ #noqa 'cos', 'cosh', 'cumsum', + 'cumprod', 'exp', 'exp_', 'floor', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 29f3425cb7..298ee031a9 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1920,6 +1920,66 @@ def cumsum(x, axis=None, dtype=None, name=None): _cum_sum_ = generate_layer_fn('cumsum') return _cum_sum_(**kwargs) +def cumprod(x, dim=None, dtype=None, name=None): + """ + Compute the cumulative product of the input tensor x along a given dimension dim. + + **Note**: + The first element of the result is the same as the first element of the input. + + Args: + x (Tensor): the input tensor whose cumulative product is to be computed. + dim (int): the dimension along which the cumulative product is computed. It needs to be in the range [-x.rank, x.rank), where x.rank is the number of dimensions of the input tensor x and -1 means the last dimension. + dtype (str, optional): The data type of the output tensor, which can be float32, float64, int32, int64, complex64 or complex128. If specified, the input tensor is cast to dtype before the operation is performed, which is useful for preventing data type overflows. The default value is None. + + Returns: + Tensor, the result of the cumprod operator. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.arange(12) + data = paddle.reshape(data, (3, 4)) + # [[ 0 1 2 3 ] + # [ 4 5 6 7 ] + # [ 8 9 10 11]] + + y = paddle.cumprod(data, dim=0) + # [[ 0 1 2 3] + # [ 0 5 12 21] + # [ 0 45 120 231]] + + y = paddle.cumprod(data, dim=-1) + # [[ 0 0 0 0] + # [ 4 20 120 840] + # [ 8 72 720 7920]] + + y = paddle.cumprod(data, dim=1, dtype='float64') + # [[ 0. 0. 0. 0.] + # [ 4. 20. 120. 840.] + # [ 8. 72. 720. 
7920.]] + + print(y.dtype) + # paddle.float64 + + """ + + if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): + x = layers.cast(x, dtype) + + if in_dygraph_mode(): + return _C_ops.cumprod(x, 'dim', dim) + + check_variable_and_dtype(x, "x", ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'], 'cumprod') + check_type(dim, 'dim', int, 'cumprod') + + helper = LayerHelper('cumprod', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='cumprod', inputs={'X': x}, outputs={'Out': out}, attrs={'dim': dim}) + return out + def isfinite(x, name=None): """
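For reference, the gradient that the cumprod_grad kernels in this patch compute is, for y = cumprod(x) along the scanned dimension,

    dL/dx[i] = sum_{j >= i} dL/dy[j] * prod_{k <= j, k != i} x[k]

The CUDA kernel evaluates this as reversed_cumsum(dy * y)[i] / x[i] for positions with no zero at or before i, and through the "fill the first zero with one" path otherwise; the CPU kernel and the test's cumprod_grad evaluate it directly with nested loops (for complex types, x and y are conjugated first). Below is a minimal real-valued NumPy sketch of the formula for a single 1-D slice; the helper name is illustrative only and is not part of the patch:

    import numpy as np

    def cumprod_grad_ref_1d(x, dy):
        # Brute-force reference gradient of y = np.cumprod(x) for 1-D x,
        # following dL/dx[i] = sum_{j >= i} dy[j] * prod_{k <= j, k != i} x[k].
        n = x.size
        dx = np.zeros_like(x)
        for i in range(n):
            for j in range(i, n):
                partial = dy[j]
                for k in range(j + 1):
                    if k != i:
                        partial = partial * x[k]
                dx[i] += partial
        return dx

    x = np.array([2.0, 0.0, 3.0, 4.0])    # a slice that contains a zero
    dy = np.ones_like(x)
    print(np.cumprod(x))                  # [2. 0. 0. 0.]
    print(cumprod_grad_ref_1d(x, dy))     # [ 1. 32.  0.  0.]

A short end-to-end check through the Python API added here, assuming a Paddle build that includes this operator (dygraph mode; how x.grad is exposed may differ across Paddle versions):

    import paddle

    x = paddle.to_tensor([2.0, 0.0, 3.0, 4.0], stop_gradient=False)
    y = paddle.cumprod(x, dim=0)
    y.sum().backward()
    print(x.grad)    # expected to match the NumPy reference above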