Unverified commit 4e509f46 authored by hlygit66666, committed by GitHub

add cumprod op (#35185)

* add test_cumprod_op

* Revert "add test_cumprod_op"

This reverts commit c96cf6dff5d09ae7d8cc72c1e8ae4369a153aa19.

* recommit

* add error message

* test input(x) initialize

* test use cpu

* update test code

* add test type

* add test case

* solve ci problem

* add complex case test

* add complex case test

* fix review problem

* fix conflict

* fix some docs

* change test case

* change test case

* fix review problems again

* fix docs

* fix inclusivescan bug
Parent 5bdca05b
@@ -37,6 +37,12 @@ struct complex;
namespace paddle {
namespace framework {
template <typename T>
struct IsComplex : public std::false_type {};
template <typename T>
struct IsComplex<platform::complex<T>> : public std::true_type {};
template <typename T>
struct DataTypeTrait {};
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cumprod_op.h"
namespace paddle {
namespace operators {
class CumprodOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cumprod");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cumprod");
ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
}
};
class CumprodOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of cumprod op.");
AddOutput("Out", "(Tensor), The output tensor of cumprod op.");
AddAttr<int>(
"dim",
"(int), The dim along which the input tensors will be cumproded");
AddComment(
R"DOC(Cumprod operator. Return the cumprod results of the input elements along the dim.
For example, if input X is a tensor with rank 1 and N elements, the output will also be a tensor
with rank 1 and N elements, and elements y[i] = x[0] * x[1] * x[2] *...* x[i] (0<=i<N))DOC");
}
};
template <typename T>
class CumprodGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("cumprod_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class CumprodGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CumprodGrad");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CumprodGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"framework::GradVarName(\"Out\")", "CumprodGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"framework::GradVarName(\"X\")", "CumprodGrad");
ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X"));
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
ops::CumprodGradOpMaker<paddle::framework::OpDesc>,
ops::CumprodGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);
REGISTER_OP_CPU_KERNEL(
cumprod, ops::CumprodOpCPUKernel<float>, ops::CumprodOpCPUKernel<double>,
ops::CumprodOpCPUKernel<int>, ops::CumprodOpCPUKernel<int64_t>,
ops::CumprodOpCPUKernel<paddle::platform::complex<float>>,
ops::CumprodOpCPUKernel<paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
cumprod_grad, ops::CumprodGradOpCPUKernel<float>,
ops::CumprodGradOpCPUKernel<double>, ops::CumprodGradOpCPUKernel<int>,
ops::CumprodGradOpCPUKernel<int64_t>,
ops::CumprodGradOpCPUKernel<paddle::platform::complex<float>>,
ops::CumprodGradOpCPUKernel<paddle::platform::complex<double>>);
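The DOC string in CumprodOpMaker above defines the forward recurrence y[i] = x[0] * x[1] * ... * x[i] along the chosen dim. A minimal NumPy illustration of that semantics (a sketch for orientation only, not part of the patch):

import numpy as np

# Cumulative product along axis 1: y[:, i] = x[:, 0] * x[:, 1] * ... * x[:, i]
x = np.array([[1., 2., 3.], [4., 5., 6.]])
print(np.cumprod(x, axis=1))
# [[  1.   2.   6.]
#  [  4.  20. 120.]]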
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/transform.h>
#include "paddle/fluid/operators/cumprod_op.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
struct MultiplyFunctor {
HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
class CumprodOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<framework::Tensor>("X");
auto *y = ctx.Output<framework::Tensor>("Out");
auto dim = ctx.Attr<int>("dim");
size_t outer_dim, mid_dim, inner_dim;
GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim);
const auto *x_data = x->data<T>();
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::InclusiveScan<T, MultiplyFunctor<T>>(
x_data, y_data, outer_dim, mid_dim, inner_dim, static_cast<T>(1),
MultiplyFunctor<T>(), /*reverse=*/false, dev_ctx);
}
};
template <typename T>
struct IsZeroFunctor {
HOSTDEVICE bool operator()(T x) const { return x == static_cast<T>(0); }
};
template <typename T>
struct CumprodGradFunctorExceptFirstZero {
HOSTDEVICE CumprodGradFunctorExceptFirstZero(
const T *x, const T *y, const T *dy_mul_y_reversed_cumsum,
const uint8_t *zero_mask, size_t mid_dim, size_t inner_dim, T *dx,
int64_t *first_zero_idx, T *x_filled_one)
: x_(x),
y_(y),
dy_mul_y_reversed_cumsum_(dy_mul_y_reversed_cumsum),
zero_mask_(zero_mask),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
dx_(dx),
first_zero_idx_(first_zero_idx),
x_filled_one_(x_filled_one) {}
HOSTDEVICE void operator()(size_t idx) const {
auto inner_idx = idx % inner_dim_;
auto outer_idx = idx / (mid_dim_ * inner_dim_);
auto mid_idx = (idx - inner_idx) / inner_dim_ % mid_dim_;
auto mask = zero_mask_[idx];
bool should_fill_one = true;
if (mask == 0) {
dx_[idx] = dy_mul_y_reversed_cumsum_[idx] / x_[idx];
if (mid_idx == mid_dim_ - 1) {
// record first zero position as -1, i.e., no zero
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = -1;
}
} else if (mid_idx > 0) { // mask > 0
if (zero_mask_[idx - inner_dim_] > 0) { // not first zero
dx_[idx] = 0;
should_fill_one = false;
} else {
// idx is the first zero position, it should be recorded
dx_[idx] = y_[idx - inner_dim_];
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = mid_idx;
}
} else { // the first zero position is index 0
dx_[idx] = 1;
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
}
x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
}
private:
const T *x_;
const T *y_;
const T *dy_mul_y_reversed_cumsum_;
const uint8_t *zero_mask_;
size_t mid_dim_;
size_t inner_dim_;
T *dx_;
int64_t *first_zero_idx_;
T *x_filled_one_;
};
template <typename T>
struct FillFirstZeroPositionGradFunctor {
HOSTDEVICE FillFirstZeroPositionGradFunctor(const int64_t *first_zero_idx,
const T *grad_value,
size_t mid_dim, size_t inner_dim,
T *dx)
: first_zero_idx_(first_zero_idx),
grad_value_(grad_value),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
dx_(dx) {}
HOSTDEVICE void operator()(size_t idx) const {
auto outer_idx = idx / inner_dim_;
auto inner_idx = idx % inner_dim_;
auto mid_idx = first_zero_idx_[idx];
if (mid_idx >= 0) {
auto full_idx =
outer_idx * mid_dim_ * inner_dim_ + mid_idx * inner_dim_ + inner_idx;
dx_[full_idx] *= grad_value_[full_idx];
}
}
private:
const int64_t *first_zero_idx_;
const T *grad_value_;
size_t mid_dim_;
size_t inner_dim_;
T *dx_;
};
/*
Reference to
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp
input: x, y, dL/dy
output: dL/dx
dL/dx[i] = sum{0<=j<n} (dL/dy[j])*(dy[j]/dx[i])                          (1)
         = sum{0<=j<n} (dL/dy[j])*(d(x[0]*x[1]*...*x[j])/dx[i])
if x[i] != 0, dL/dx[i] = sum{i<=j<n} (dL/dy[j])*(y[j]/x[i])              (2)
if x[i] == 0, formula (2) cannot be applied directly.
Suppose k is the index of the first zero element; then:
i > k: dL/dx[i] = 0
i < k: dL/dx[i] = 1/x[i] * sum{i<=j<n} (dL/dy[j]*y[j])
i = k: dL/dx[i] = y[i-1] * sum{i<=j<n} (dL/dy[j])*(x[i+1]*...*x[j])
First, we outline the main approach.
We need to determine the relationship between i (the current index) and
k (the index of the first zero element).
To mark this relationship, we introduce zero_mask, and we also record the
index of the first zero element.
zero_mask = cummax(x[i] == 0);  // marks whether some x[j] == 0 for j <= i.
zero_index = -1;                // stores the first zero element's index.
e.g. x          = [1, 4, 5, 0, 2, 3, 0];
     zero_mask  = [0, 0, 0, 1, 1, 1, 1];
     zero_index = 3;
When i < k, we need sum{i<=j<n}(dy[j]*y[j]), which a reversed cumsum
computes directly:
R = reversed_cumsum(dy[j]*y[j]);  // R[i] stores sum{i<=j<n}(dy[j]*y[j]).
When i = k, we additionally need prod{k<w<=j}(x[w]).
To compute it, we introduce x_filled_one, which fills positions 0 ~ k with 1.
e.g. x            = [1, 4, 5, 0, 2, 3, 0];
     x_filled_one = [1, 1, 1, 1, 2, 3, 0];
Thus, cumprod(x_filled_one)[j] equals prod{k<w<=j}(x[w]).
The detailed implementation is:
for (int i = 0; i < numel; i++) {
if (zero_mask[i] == 0) { //case i < k
dx[i] = R[i] / x[i];
x_filled_one[i] = 1;
} else {
if (i == 0) { //case i = k
dx[i] = 1;
zero_index = i;
x_filled_one[i] = 1;
} else {
if (zero_mask[i-1] == 0) { //case i = k
dx[i] = y[i-1];
zero_index = i;
x_filled_one[i] = 1;
} else { //case i > k
dx[i] = 0;
x_filled_one[i] = x[i];
}
}
}
}
T = reversed_cumsum(dy[j]*cumprod(x_filled_one[j]));
if (zero_index != -1) {
dx[zero_index] *= T[zero_index];
}
*/
template <typename T>
class CumprodGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<framework::Tensor>("X");
const auto *y = ctx.Input<framework::Tensor>("Out");
const auto *dy =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto dim = ctx.Attr<int>("dim");
size_t outer_dim, mid_dim, inner_dim;
GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim);
if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
size_t numel = outer_dim * mid_dim * inner_dim;
const auto *x_data = x->data<T>();
const auto *y_data = y->data<T>();
const auto *dy_data = dy->data<T>();
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto *dx_data = dx->mutable_data<T>(place);
// deal with complex
const T *x_data_deal;
const T *y_data_deal;
memory::AllocationPtr x_conj;
memory::AllocationPtr y_conj;
if (framework::IsComplex<T>::value) {
x_conj = memory::Alloc(place, numel * sizeof(T));
auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr());
y_conj = memory::Alloc(place, numel * sizeof(T));
auto *y_data_conj = reinterpret_cast<T *>(y_conj->ptr());
platform::ForRange<platform::CUDADeviceContext> for_range_x(dev_ctx,
numel);
math::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
for_range_x(functor_x);
platform::ForRange<platform::CUDADeviceContext> for_range_y(dev_ctx,
numel);
math::ConjFunctor<T> functor_y(y_data, numel, y_data_conj);
for_range_y(functor_y);
x_data_deal = x_data_conj;
y_data_deal = y_data_conj;
} else {
x_data_deal = x_data;
y_data_deal = y_data;
}
// Step 1: find cummax-ed zero mask of x
#ifdef PADDLE_WITH_CUDA
const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
auto zero_mask_without_cummax =
memory::Alloc(place, numel * sizeof(uint8_t));
auto *zero_mask_without_cummax_data =
reinterpret_cast<uint8_t *>(zero_mask_without_cummax->ptr());
thrust::transform(
exec_policy, thrust::device_pointer_cast(x_data_deal),
thrust::device_pointer_cast(x_data_deal) + numel,
thrust::device_pointer_cast(zero_mask_without_cummax_data),
IsZeroFunctor<T>());
auto zero_mask = memory::Alloc(place, numel * sizeof(uint8_t));
auto *zero_mask_data = reinterpret_cast<uint8_t *>(zero_mask->ptr());
math::InclusiveScan<uint8_t, cub::Max>(
zero_mask_without_cummax_data, zero_mask_data, outer_dim, mid_dim,
inner_dim, static_cast<uint8_t>(0), cub::Max(), /*reverse=*/false,
dev_ctx);
zero_mask_without_cummax = nullptr;
// Step 2: calculate reversed cumsum(dy * y)
auto dy_mul_y = memory::Alloc(place, numel * sizeof(T));
auto *dy_mul_y_data = reinterpret_cast<T *>(dy_mul_y->ptr());
thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data),
thrust::device_pointer_cast(dy_data) + numel,
thrust::device_pointer_cast(y_data_deal),
thrust::device_pointer_cast(dy_mul_y_data),
MultiplyFunctor<T>());
auto dy_mul_y_reversed_cumsum = memory::Alloc(place, numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
math::InclusiveScan<T, cub::Sum>(
dy_mul_y_data, dy_mul_y_reversed_cumsum_data, outer_dim, mid_dim,
inner_dim, static_cast<T>(0), cub::Sum(), /*reverse=*/true, dev_ctx);
// Step 3: calculate the gradient value except the first zero position.
// The gradient value of the first zero position is filled with out[idx-1],
// while the gradient values of the other positions are computed completely.
// This functor also:
//   (1) finds the first zero index, i.e., first_zero_idx_data.
//   (2) fills x_filled_one, which satisfies
//         x_filled_one[i] = x[i], i > pos
//         x_filled_one[i] = 1,    i <= pos
auto first_zero_idx =
memory::Alloc(place, outer_dim * inner_dim * sizeof(int64_t));
auto *first_zero_idx_data =
reinterpret_cast<int64_t *>(first_zero_idx->ptr());
auto *x_filled_one_data = dy_mul_y_data; // reuse former allocated memory
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, numel);
CumprodGradFunctorExceptFirstZero<T> functor_except_first_zero(
x_data_deal, y_data_deal, dy_mul_y_reversed_cumsum_data, zero_mask_data,
mid_dim, inner_dim, dx_data, first_zero_idx_data, x_filled_one_data);
for_range(functor_except_first_zero);
// Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
math::InclusiveScan<T, MultiplyFunctor<T>>(
x_filled_one_data, x_filled_one_cumprod_data, outer_dim, mid_dim,
inner_dim, static_cast<T>(1), MultiplyFunctor<T>(), /*reverse=*/false,
dev_ctx);
// Step 5: calculate reversed cumsum(dy * x_filled_one_cumprod)
auto *dy_mul_x_filled_one_cumprod =
dy_mul_y_data; // reuse former allocated memory
thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data),
thrust::device_pointer_cast(dy_data) + numel,
thrust::device_pointer_cast(x_filled_one_cumprod_data),
thrust::device_pointer_cast(dy_mul_x_filled_one_cumprod),
MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
math::InclusiveScan<T, cub::Sum>(
dy_mul_x_filled_one_cumprod,
dy_mul_x_filled_one_cumprod_reversed_cumsum, outer_dim, mid_dim,
inner_dim, static_cast<T>(0), cub::Sum(),
/*reverse=*/true, dev_ctx);
// Step 6: fill zero pos gradient value
platform::ForRange<platform::CUDADeviceContext>
for_range_fill_zero_pos_grad(dev_ctx, outer_dim * inner_dim);
FillFirstZeroPositionGradFunctor<T> fill_first_zero_pos_grad_functor(
first_zero_idx_data, dy_mul_x_filled_one_cumprod_reversed_cumsum,
mid_dim, inner_dim, dx_data);
for_range_fill_zero_pos_grad(fill_first_zero_pos_grad_functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
cumprod, ops::CumprodOpCUDAKernel<float>, ops::CumprodOpCUDAKernel<double>,
ops::CumprodOpCUDAKernel<int>, ops::CumprodOpCUDAKernel<int64_t>,
ops::CumprodOpCUDAKernel<paddle::platform::complex<float>>,
ops::CumprodOpCUDAKernel<paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
cumprod_grad, ops::CumprodGradOpCUDAKernel<float>,
ops::CumprodGradOpCUDAKernel<double>, ops::CumprodGradOpCUDAKernel<int>,
ops::CumprodGradOpCUDAKernel<int64_t>,
ops::CumprodGradOpCUDAKernel<paddle::platform::complex<float>>,
ops::CumprodGradOpCUDAKernel<paddle::platform::complex<double>>);
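The derivation comment in the CUDA gradient kernel above can be condensed into a small NumPy reference for the 1-D, real-valued case. This is a hedged sketch for intuition only; the helper name cumprod_grad_1d is illustrative and is not part of the patch:

import numpy as np

def cumprod_grad_1d(x, dy):
    # Gradient of y = cumprod(x) following the zero-aware scheme above.
    x = np.asarray(x, dtype=np.float64)
    dy = np.asarray(dy, dtype=np.float64)
    y = np.cumprod(x)
    zero_mask = np.maximum.accumulate(x == 0)       # cummax(x == 0)
    r = np.cumsum((dy * y)[::-1])[::-1]             # reversed cumsum of dy * y
    dx = np.zeros_like(x)
    before_first_zero = ~zero_mask                  # indices i < k (or no zero at all)
    dx[before_first_zero] = r[before_first_zero] / x[before_first_zero]
    zeros = np.flatnonzero(x == 0)
    if zeros.size:                                  # i = k case; i > k stays 0
        k = zeros[0]
        x_filled_one = x.copy()
        x_filled_one[:k + 1] = 1
        t = np.cumsum((dy * np.cumprod(x_filled_one))[::-1])[::-1]
        dx[k] = (y[k - 1] if k > 0 else 1.0) * t[k]
    return dx

# Spot check with the example used in the comment: x = [1, 4, 5, 0, 2, 3, 0].
print(cumprod_grad_1d([1., 4., 5., 0., 2., 3., 0.], np.ones(7)))
# [ 25.   6.   4. 180.   0.   0.   0.]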
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static void GetCumprodDimInfo(const framework::DDim& dim, int cumprod_dim,
size_t* outer_dim, size_t* mid_dim,
size_t* inner_dim) {
PADDLE_ENFORCE_GE(
    cumprod_dim, -dim.size(),
    platform::errors::InvalidArgument(
        "The input dim of CumprodOp should be no smaller than the negative "
        "rank of input x, which is %d. But received dim=%d.",
        -dim.size(), cumprod_dim));
PADDLE_ENFORCE_LT(cumprod_dim, dim.size(),
                  platform::errors::InvalidArgument(
                      "The input dim of CumprodOp should be smaller than the "
                      "rank of input x, which is %d. But received dim=%d.",
                      dim.size(), cumprod_dim));
if (cumprod_dim < 0) cumprod_dim += dim.size();
*outer_dim = 1;
for (int i = 0; i < cumprod_dim; ++i) {
*outer_dim *= dim[i];
}
*mid_dim = dim[cumprod_dim];
*inner_dim = 1;
for (int i = cumprod_dim + 1; i < dim.size(); ++i) {
*inner_dim *= dim[i];
}
}
template <typename T>
class CumprodOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
int dim = context.Attr<int>("dim");
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
framework::DDim shape = x->dims();
size_t outer_dim = 1;
size_t mid_dim = 1;
size_t inner_dim = 1;
GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
for (size_t i = 0; i < outer_dim; i++) {
for (size_t j = 0; j < mid_dim; j++) {
for (size_t k = 0; k < inner_dim; k++) {
size_t pos = i * mid_dim * inner_dim + j * inner_dim + k;
if (j == 0) {
out_data[pos] = x_data[pos];
} else {
out_data[pos] = out_data[pos - inner_dim] * x_data[pos];
}
}
}
}
}
};
template <typename T>
class CumprodGradOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const {
const Tensor* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
const Tensor* x = context.Input<Tensor>("X");
const Tensor* out = context.Input<Tensor>("Out");
int dim = context.Attr<int>("dim");
framework::DDim shape = x->dims();
Tensor* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* d_out_data = d_out->data<T>();
auto* x_data = x->data<T>();
auto* out_data = out->data<T>();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
auto place = BOOST_GET_CONST(platform::CPUPlace, context.GetPlace());
const auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
size_t outer_dim = 1;
size_t mid_dim = 1;
size_t inner_dim = 1;
GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
size_t numel = outer_dim * mid_dim * inner_dim;
// deal with complex
const T* x_data_deal;
const T* out_data_deal;
memory::AllocationPtr x_conj;
memory::AllocationPtr out_conj;
if (framework::IsComplex<T>::value) {
x_conj = memory::Alloc(place, numel * sizeof(T));
auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr());
out_conj = memory::Alloc(place, numel * sizeof(T));
auto* out_data_conj = reinterpret_cast<T*>(out_conj->ptr());
platform::ForRange<platform::CPUDeviceContext> for_range_x(dev_ctx,
numel);
math::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
for_range_x(functor_x);
platform::ForRange<platform::CPUDeviceContext> for_range_out(dev_ctx,
numel);
math::ConjFunctor<T> functor_out(out_data, numel, out_data_conj);
for_range_out(functor_out);
x_data_deal = x_data_conj;
out_data_deal = out_data_conj;
} else {
x_data_deal = x_data;
out_data_deal = out_data;
}
for (size_t i = 0; i < outer_dim; i++) {
for (size_t k = 0; k < inner_dim; k++) {
for (size_t j = 0; j < mid_dim; j++) {
size_t index = i * mid_dim * inner_dim + j * inner_dim + k;
d_x_data[index] = 0;
for (size_t n = 0; n < mid_dim; n++) {
size_t pos = i * mid_dim * inner_dim + n * inner_dim + k;
T elem;
if (j == 0) {
elem = d_out_data[pos];
} else {
elem = d_out_data[pos] * out_data_deal[index - inner_dim];
}
if (pos > index) {
for (size_t m = index + inner_dim; m <= pos; m += inner_dim) {
elem *= x_data_deal[m];
}
} else if (pos < index) {
elem = static_cast<T>(0);
}
d_x_data[index] += elem;
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
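GetCumprodDimInfo above flattens a tensor of arbitrary rank into an (outer, mid, inner) triple so that both kernels only have to reason about a 3-D layout, with mid being the scanned dimension. A hedged Python equivalent (get_cumprod_dim_info is an illustrative name, not part of the patch):

def get_cumprod_dim_info(shape, dim):
    # Mirror of GetCumprodDimInfo: collapse dims before `dim` into `outer`,
    # dims after it into `inner`, and keep shape[dim] as `mid`.
    rank = len(shape)
    assert -rank <= dim < rank, "dim must lie in [-rank, rank)"
    if dim < 0:
        dim += rank
    outer = 1
    for s in shape[:dim]:
        outer *= s
    inner = 1
    for s in shape[dim + 1:]:
        inner *= s
    return outer, shape[dim], inner

print(get_cumprod_dim_info((2, 3, 4, 5), dim=-2))  # (6, 4, 5)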
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <thrust/device_ptr.h>
#include <thrust/iterator/reverse_iterator.h>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
namespace math {
template <typename InputIterator, typename OutputIterator, typename BinaryOp>
static void CubInclusiveScan(InputIterator x_iter, OutputIterator y_iter,
size_t n, BinaryOp op,
const platform::CUDADeviceContext &dev_ctx) {
memory::AllocationPtr allocation;
void *temp_storage = nullptr;
size_t temp_storage_bytes = 0;
for (size_t i = 0; i < 2; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceScan::InclusiveScan(
temp_storage, temp_storage_bytes, x_iter, y_iter, op,
static_cast<int>(n), // Maybe overflow?
dev_ctx.stream()));
if (i == 0 && temp_storage_bytes > 0) {
allocation = memory::Alloc(dev_ctx.GetPlace(), temp_storage_bytes);
temp_storage = allocation->ptr();
}
}
}
template <typename T>
static auto MakeThrustReverseIterator(T *x) {
return thrust::reverse_iterator<thrust::device_ptr<T>>(
thrust::device_pointer_cast(x));
}
template <typename T, typename BinaryOp, bool kReverse>
struct InclusiveScanOuterOrMidDimFunctor {
HOSTDEVICE InclusiveScanOuterOrMidDimFunctor(const T *x, T *y, size_t mid_dim,
size_t inner_dim, T init,
BinaryOp op)
: x_(x),
y_(y),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
init_(init),
op_(op) {}
HOSTDEVICE void operator()(size_t idx) const {
auto outer_idx = idx / inner_dim_;
auto inner_idx = idx % inner_dim_;
if (kReverse) {
idx = outer_idx * mid_dim_ * inner_dim_ + (mid_dim_ - 1) * inner_dim_ +
inner_idx;
} else {
idx = outer_idx * mid_dim_ * inner_dim_ + inner_idx;
}
auto x_ptr = x_ + idx;
auto y_ptr = y_ + idx;
T acc_value = init_;
for (size_t i = 0; i < mid_dim_; ++i) {
acc_value = op_(acc_value, *x_ptr);
*y_ptr = acc_value;
if (kReverse) {
x_ptr -= inner_dim_;
y_ptr -= inner_dim_;
} else {
x_ptr += inner_dim_;
y_ptr += inner_dim_;
}
}
}
private:
const T *x_;
T *y_;
size_t mid_dim_;
size_t inner_dim_;
T init_;
BinaryOp op_;
};
// Reference to
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp
template <typename T, typename BinaryOp, size_t kThreadNumX, size_t kThreadNumY,
bool kReverse>
static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y,
size_t num_rows,
size_t row_size, T init,
BinaryOp op) {
using RealT = math::Real<T>;
constexpr auto kSharedBufferSize =
framework::IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
__shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];
T *row_buf = reinterpret_cast<T *>(sbuf[threadIdx.y]);
size_t block_row = static_cast<size_t>(blockIdx.x * kThreadNumY);
size_t block_row_stride = static_cast<size_t>(gridDim.x * kThreadNumY);
for (; block_row < num_rows; block_row += block_row_stride) {
size_t row = block_row + threadIdx.y;
T block_total = init;
const T *row_x = x + row * row_size;
T *row_y = y + row * row_size;
for (size_t block_col = 0; block_col < row_size;
block_col += 2 * kThreadNumX) {
size_t col1, col2;
if (kReverse) {
col1 = row_size - 1 - block_col - threadIdx.x;
col2 = col1 - kThreadNumX;
} else {
col1 = block_col + threadIdx.x;
col2 = col1 + kThreadNumX;
}
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_x[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[kThreadNumX + threadIdx.x] = row_x[col2];
} else {
row_buf[kThreadNumX + threadIdx.x] = init;
}
if (threadIdx.x == 0) {
row_buf[0] = op(row_buf[0], block_total);
}
}
__syncthreads();
for (size_t s = kThreadNumX, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
size_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
for (size_t s = 2, d = kThreadNumX / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
size_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
if (row < num_rows) {
if (col1 < row_size) row_y[col1] = row_buf[threadIdx.x];
if (col2 < row_size) row_y[col2] = row_buf[kThreadNumX + threadIdx.x];
}
block_total = row_buf[2 * kThreadNumX - 1];
__syncthreads();
}
}
}
template <typename T, typename BinaryOp>
static void InclusiveScanInnerDim(const T *x, T *y, size_t outer_dim,
size_t inner_dim, T init, BinaryOp op,
bool reverse,
const platform::CUDADeviceContext &dev_ctx) {
constexpr size_t kThreadNumX = 16;
constexpr size_t kThreadNumY = 32;
size_t grid_dim = (outer_dim + kThreadNumY - 1) / kThreadNumY;
grid_dim = std::min<size_t>(grid_dim, dev_ctx.GetCUDAMaxGridDimSize().x);
dim3 thread_dims(kThreadNumX, kThreadNumY);
if (reverse) {
InclusiveScanInnerDimCUDAKernel<
T, BinaryOp, kThreadNumX, kThreadNumY,
/*kReverse=*/true><<<grid_dim, thread_dims, 0, dev_ctx.stream()>>>(
x, y, outer_dim, inner_dim, init, op);
} else {
InclusiveScanInnerDimCUDAKernel<
T, BinaryOp, kThreadNumX, kThreadNumY,
/*kReverse=*/false><<<grid_dim, thread_dims, 0, dev_ctx.stream()>>>(
x, y, outer_dim, inner_dim, init, op);
}
}
template <typename T, typename BinaryOp>
void InclusiveScan(const T *x, T *y, size_t outer_dim, size_t mid_dim,
size_t inner_dim, T init, BinaryOp op, bool reverse,
const platform::CUDADeviceContext &dev_ctx) {
if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
if (outer_dim == 1 && inner_dim == 1) {
if (reverse) {
auto x_reverse_iter = MakeThrustReverseIterator(x + mid_dim);
auto y_reverse_iter = MakeThrustReverseIterator(y + mid_dim);
CubInclusiveScan(x_reverse_iter, y_reverse_iter, mid_dim, op, dev_ctx);
} else {
CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
}
} else if (inner_dim != 1) {
platform::ForRange<platform::CUDADeviceContext> for_range(
dev_ctx, outer_dim * inner_dim);
if (reverse) {
for_range(
InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/true>(
x, y, mid_dim, inner_dim, init, op));
} else {
for_range(
InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/false>(
x, y, mid_dim, inner_dim, init, op));
}
} else {
InclusiveScanInnerDim<T, BinaryOp>(x, y, outer_dim, mid_dim, init, op,
reverse, dev_ctx);
}
}
} // namespace math
} // namespace operators
} // namespace paddle
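All three code paths in InclusiveScan above (the cub scan for a single row, the per-(outer, inner) sequential functor, and the tiled shared-memory kernel) compute the same thing: an inclusive scan along the middle axis of an (outer, mid, inner) layout, optionally reversed. A minimal NumPy statement of that contract (inclusive_scan_ref is an illustrative name, not part of the patch):

import numpy as np

def inclusive_scan_ref(x, outer, mid, inner, op=np.multiply, reverse=False):
    # Inclusive scan with binary ufunc `op` along the middle axis.
    a = np.asarray(x).reshape(outer, mid, inner)
    if reverse:
        a = a[:, ::-1, :]
    y = op.accumulate(a, axis=1)
    if reverse:
        y = y[:, ::-1, :]
    return y.reshape(-1)

x = np.arange(1, 7, dtype=np.float64)                 # outer=1, mid=3, inner=2
print(inclusive_scan_ref(x, 1, 3, 2))                 # [ 1.  2.  3.  8. 15. 48.]
print(inclusive_scan_ref(x, 1, 3, 2, op=np.add, reverse=True))  # reversed cumsum per column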
@@ -163,6 +163,7 @@ from .tensor.math import cos # noqa: F401
from .tensor.math import tan # noqa: F401
from .tensor.math import cosh # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import cumprod # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
from .tensor.math import floor # noqa: F401
@@ -330,6 +331,7 @@ __all__ = [ # noqa
'empty_like',
'eye',
'cumsum',
'cumprod',
'sign',
'is_empty',
'equal',
@@ -947,6 +947,7 @@ set_tests_properties(test_mean_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120)
set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler, Program, program_guard
np.random.seed(0)
# define cumprod grad function.
def cumprod_grad(x, y, dy, dx, shape, dim):
if dim < 0:
dim += len(shape)
mid_dim = shape[dim]
outer_dim = 1
inner_dim = 1
for i in range(0, dim):
outer_dim *= shape[i]
for i in range(dim + 1, len(shape)):
inner_dim *= shape[i]
for i in range(outer_dim):
for k in range(inner_dim):
for j in range(mid_dim):
index = i * mid_dim * inner_dim + j * inner_dim + k
for n in range(mid_dim):
pos = i * mid_dim * inner_dim + n * inner_dim + k
elem = 0
if j == 0:
elem = dy[pos]
else:
elem = dy[pos] * y[index - inner_dim]
if pos > index:
for m in range(index + inner_dim, pos + inner_dim,
inner_dim):
elem *= x[m]
elif pos < index:
elem = 0
dx[index] += elem
# test function.
class TestCumprod(OpTest):
def init_params(self):
self.shape = (2, 3, 4, 5)
self.zero_nums = [0, 10, 20, 30, int(np.prod(self.shape))]
def init_dtype(self):
self.dtype = np.float64
def setUp(self):
paddle.enable_static()
self.init_params()
self.init_dtype()
self.op_type = "cumprod"
self.inputs = {'X': None}
self.outputs = {'Out': None}
self.attrs = {'dim': None}
def prepare_inputs_outputs_attrs(self, dim, zero_num):
self.x = np.random.random(self.shape).astype(self.dtype) + 0.5
if zero_num > 0:
zero_num = min(zero_num, self.x.size)
shape = self.x.shape
self.x = self.x.flatten()
indices = random.sample(range(self.x.size), zero_num)
for i in indices:
self.x[i] = 0
self.x = np.reshape(self.x, self.shape)
self.out = np.cumprod(self.x, axis=dim)
self.inputs = {'X': self.x}
self.outputs = {'Out': self.out}
self.attrs = {'dim': dim}
def init_grad_input_output(self, dim):
reshape_x = self.x.reshape(self.x.size)
self.grad_out = np.ones(self.x.size, self.dtype)
self.grad_x = np.zeros(self.x.size, self.dtype)
out_data = self.out.reshape(self.x.size)
if self.dtype == np.complex128 or self.dtype == np.complex64:
reshape_x = np.conj(reshape_x)
out_data = np.conj(out_data)
cumprod_grad(reshape_x, out_data, self.grad_out, self.grad_x,
self.shape, dim)
self.grad_x = self.grad_x.reshape(self.shape)
self.grad_out = self.grad_out.reshape(self.shape)
# test forward.
def test_check_output(self):
for dim in range(-len(self.shape), len(self.shape)):
for zero_num in self.zero_nums:
self.prepare_inputs_outputs_attrs(dim, zero_num)
self.check_output()
# test backward.
def test_check_grad(self):
for dim in range(-len(self.shape), len(self.shape)):
for zero_num in self.zero_nums:
self.prepare_inputs_outputs_attrs(dim, zero_num)
self.init_grad_input_output(dim)
if self.dtype == np.float64:
self.check_grad(['X'], 'Out')
else:
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
# test float32 case.
class TestCumprod_float32(TestCumprod):
def init_dtype(self):
self.dtype = np.float32
# test complex64 case.
class TestCumprod_complex64(TestCumprod):
def init_dtype(self):
self.dtype = np.complex64
# test complex128 case.
class TestCumprod_complex128(TestCumprod):
def init_dtype(self):
self.dtype = np.complex128
# test api.
class TestCumprodAPI(unittest.TestCase):
def init_dtype(self):
self.dtype = 'float64'
self.shape = [2, 3, 10, 10]
def setUp(self):
paddle.enable_static()
self.init_dtype()
self.x = (np.random.rand(2, 3, 10, 10) + 0.5).astype(self.dtype)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
# test static graph api.
def test_static_api(self):
paddle.enable_static()
def run(place):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.shape, dtype=self.dtype)
out = paddle.cumprod(x, -2)
exe = paddle.static.Executor(place)
res = exe.run(feed={'X': self.x}, fetch_list=[out])
out_ref = np.cumprod(self.x, -2)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
for place in self.place:
run(place)
# test dynamic graph api.
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.cumprod(x, 1)
out_ref = np.cumprod(self.x, 1)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)
paddle.enable_static()
for place in self.place:
run(place)
if __name__ == "__main__":
unittest.main()
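The cumprod_grad helper in the test above mirrors the O(mid_dim^2) loop of CumprodGradOpCPUKernel. A hedged, stand-alone spot check against a finite-difference estimate (illustrative only, assuming cumprod_grad from the test above is in scope; it is not part of the test suite):

import numpy as np

x = np.array([1., 2., 0., 3.])
y = np.cumprod(x)
dy = np.ones_like(x)
dx = np.zeros_like(x)
cumprod_grad(x, y, dy, dx, x.shape, 0)   # analytic gradient of sum(cumprod(x))

# Finite-difference check on x[1].
eps = 1e-6
x_eps = x.copy()
x_eps[1] += eps
fd = (np.cumprod(x_eps).sum() - y.sum()) / eps
assert np.isclose(dx[1], fd, atol=1e-4)  # both evaluate to 1.0 here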
@@ -108,6 +108,7 @@ from .math import cos # noqa: F401
from .math import tan # noqa: F401
from .math import cosh # noqa: F401
from .math import cumsum # noqa: F401
from .math import cumprod # noqa: F401
from .math import exp # noqa: F401
from .math import exp_ # noqa: F401
from .math import expm1 # noqa: F401
@@ -236,6 +237,7 @@ tensor_method_func = [ #noqa
'cos',
'cosh',
'cumsum',
'cumprod',
'exp',
'exp_',
'floor',
@@ -1920,6 +1920,66 @@ def cumsum(x, axis=None, dtype=None, name=None):
_cum_sum_ = generate_layer_fn('cumsum')
return _cum_sum_(**kwargs)
def cumprod(x, dim=None, dtype=None, name=None):
"""
Compute the cumulative product of the input tensor x along a given dimension dim.
**Note**:
The first element of the result is the same as the first element of the input.
Args:
x (Tensor): the input tensor whose cumulative product is computed along the dimension dim.
dim (int): the dimension along which the cumulative product is computed. It needs to be in the range [-x.rank, x.rank), where x.rank means the number of dimensions of the input tensor x and -1 means the last dimension.
dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64, complex64, complex128. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, the result of cumprod operator.
Examples:
.. code-block:: python
import paddle
data = paddle.arange(12)
data = paddle.reshape(data, (3, 4))
# [[ 0 1 2 3 ]
# [ 4 5 6 7 ]
# [ 8 9 10 11]]
y = paddle.cumprod(data, dim=0)
# [[ 0 1 2 3]
# [ 0 5 12 21]
# [ 0 45 120 231]]
y = paddle.cumprod(data, dim=-1)
# [[ 0 0 0 0]
# [ 4 20 120 840]
# [ 8 72 720 7920]]
y = paddle.cumprod(data, dim=1, dtype='float64')
# [[ 0. 0. 0. 0.]
# [ 4. 20. 120. 840.]
# [ 8. 72. 720. 7920.]]
print(y.dtype)
# paddle.float64
"""
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = layers.cast(x, dtype)
if in_dygraph_mode():
return _C_ops.cumprod(x, 'dim', dim)
check_variable_and_dtype(x, "x", ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'], 'cumprod')
check_type(dim, 'dim', int, 'cumprod')
helper = LayerHelper('cumprod', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='cumprod', inputs={'X': x}, outputs={'Out': out}, attrs={'dim': dim})
return out
def isfinite(x, name=None):
"""