From 64223620e92b1aac5b84af1f2bafad68d0384116 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 15 Mar 2022 14:47:14 +0800 Subject: [PATCH] [phi] Transfer lgamma, kldiv_loss, isclose, cumprod kernels into phi and pass the tests of these four kernels (#39770) * tranfer and pass the lgamma unittest * merge and pass the test * transfer kldiv_loss and kldiv_loss_grad; pass the unitest * trafer the isclose and cumprod kernel * change PT_REGISTER -> PD_REGISTER * fix by code review * fix by code review * fix * remove enforce include dependence from scalar * fix * fix by code review * fix by code review --- paddle/fluid/operators/cumprod_op.cc | 16 +- paddle/fluid/operators/cumprod_op.cu | 369 ------------------ paddle/fluid/operators/cumprod_op.h | 170 -------- paddle/fluid/operators/isclose_op.cc | 43 -- paddle/fluid/operators/isclose_op.cu | 85 ---- paddle/fluid/operators/isclose_op.h | 93 ----- paddle/fluid/operators/kldiv_loss_op.cc | 8 - paddle/fluid/operators/kldiv_loss_op.cu | 22 -- paddle/fluid/operators/kldiv_loss_op.h | 119 ------ paddle/fluid/operators/kldiv_loss_op_npu.cc | 3 +- paddle/fluid/operators/lgamma_op.cc | 30 +- paddle/fluid/operators/lgamma_op.cu | 59 --- paddle/fluid/operators/lgamma_op.h | 100 ----- paddle/fluid/operators/math/inclusive_scan.h | 18 +- paddle/phi/api/lib/utils/CMakeLists.txt | 2 +- paddle/phi/common/CMakeLists.txt | 1 + paddle/phi/common/scalar.cc | 35 ++ paddle/phi/common/scalar.h | 11 +- paddle/phi/kernels/cpu/cumprod_grad_kernel.cc | 113 ++++++ paddle/phi/kernels/cpu/cumprod_kernel.cc | 65 +++ paddle/phi/kernels/cpu/isclose_kernel.cc | 21 + .../phi/kernels/cpu/kldiv_loss_grad_kernel.cc | 22 ++ paddle/phi/kernels/cpu/kldiv_loss_kernel.cc | 23 ++ paddle/phi/kernels/cpu/lgamma_grad_kernel.cc | 20 + paddle/phi/kernels/cpu/lgamma_kernel.cc | 49 +++ paddle/phi/kernels/cumprod_grad_kernel.h | 28 ++ paddle/phi/kernels/cumprod_kernel.h | 26 ++ paddle/phi/kernels/funcs/cumprod.h | 52 +++ .../phi/kernels/funcs/elementwise_functor.h | 5 + paddle/phi/kernels/gpu/cumprod_grad_kernel.cu | 320 +++++++++++++++ paddle/phi/kernels/gpu/cumprod_kernel.cu | 60 +++ paddle/phi/kernels/gpu/isclose_kernel.cu | 22 ++ .../phi/kernels/gpu/kldiv_loss_grad_kernel.cu | 22 ++ paddle/phi/kernels/gpu/kldiv_loss_kernel.cu | 21 + paddle/phi/kernels/gpu/lgamma_grad_kernel.cu | 21 + paddle/phi/kernels/gpu/lgamma_kernel.cu | 41 ++ paddle/phi/kernels/impl/isclose_kernel_impl.h | 176 +++++++++ .../impl/kldiv_loss_grad_kernel_impl.h | 70 ++++ .../phi/kernels/impl/kldiv_loss_kernel_impl.h | 69 ++++ .../kernels/impl/lgamma_grad_kernel_impl.h | 47 +++ paddle/phi/kernels/isclose_kernel.h | 30 ++ paddle/phi/kernels/kldiv_loss_grad_kernel.h | 29 ++ paddle/phi/kernels/kldiv_loss_kernel.h | 29 ++ paddle/phi/kernels/lgamma_grad_kernel.h | 27 ++ paddle/phi/kernels/lgamma_kernel.h | 26 ++ paddle/phi/ops/compat/cumprod_sig.cc | 29 ++ paddle/phi/ops/compat/isclose_sig.cc | 50 +++ paddle/phi/ops/compat/kldiv_loss_sig.cc | 30 ++ paddle/phi/ops/compat/lgamma_sig.cc | 25 ++ 49 files changed, 1632 insertions(+), 1120 deletions(-) delete mode 100644 paddle/fluid/operators/cumprod_op.cu delete mode 100644 paddle/fluid/operators/cumprod_op.h delete mode 100644 paddle/fluid/operators/isclose_op.cu delete mode 100644 paddle/fluid/operators/isclose_op.h delete mode 100644 paddle/fluid/operators/kldiv_loss_op.cu delete mode 100644 paddle/fluid/operators/kldiv_loss_op.h delete mode 100644 paddle/fluid/operators/lgamma_op.cu delete mode 100644 paddle/fluid/operators/lgamma_op.h create mode 100644 paddle/phi/common/scalar.cc create mode 100644 paddle/phi/kernels/cpu/cumprod_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/cumprod_kernel.cc create mode 100644 paddle/phi/kernels/cpu/isclose_kernel.cc create mode 100644 paddle/phi/kernels/cpu/kldiv_loss_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/kldiv_loss_kernel.cc create mode 100644 paddle/phi/kernels/cpu/lgamma_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/lgamma_kernel.cc create mode 100644 paddle/phi/kernels/cumprod_grad_kernel.h create mode 100644 paddle/phi/kernels/cumprod_kernel.h create mode 100644 paddle/phi/kernels/funcs/cumprod.h create mode 100644 paddle/phi/kernels/gpu/cumprod_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/cumprod_kernel.cu create mode 100644 paddle/phi/kernels/gpu/isclose_kernel.cu create mode 100644 paddle/phi/kernels/gpu/kldiv_loss_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/kldiv_loss_kernel.cu create mode 100644 paddle/phi/kernels/gpu/lgamma_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/lgamma_kernel.cu create mode 100644 paddle/phi/kernels/impl/isclose_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/isclose_kernel.h create mode 100644 paddle/phi/kernels/kldiv_loss_grad_kernel.h create mode 100644 paddle/phi/kernels/kldiv_loss_kernel.h create mode 100644 paddle/phi/kernels/lgamma_grad_kernel.h create mode 100644 paddle/phi/kernels/lgamma_kernel.h create mode 100644 paddle/phi/ops/compat/cumprod_sig.cc create mode 100644 paddle/phi/ops/compat/isclose_sig.cc create mode 100644 paddle/phi/ops/compat/kldiv_loss_sig.cc create mode 100644 paddle/phi/ops/compat/lgamma_sig.cc diff --git a/paddle/fluid/operators/cumprod_op.cc b/paddle/fluid/operators/cumprod_op.cc index bff6673429d..90910bbbb20 100644 --- a/paddle/fluid/operators/cumprod_op.cc +++ b/paddle/fluid/operators/cumprod_op.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/cumprod_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace operators { @@ -87,16 +88,3 @@ REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker, ops::CumprodGradOpMaker); REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp); - -REGISTER_OP_CPU_KERNEL( - cumprod, ops::CumprodOpCPUKernel, ops::CumprodOpCPUKernel, - ops::CumprodOpCPUKernel, ops::CumprodOpCPUKernel, - ops::CumprodOpCPUKernel>, - ops::CumprodOpCPUKernel>); - -REGISTER_OP_CPU_KERNEL( - cumprod_grad, ops::CumprodGradOpCPUKernel, - ops::CumprodGradOpCPUKernel, ops::CumprodGradOpCPUKernel, - ops::CumprodGradOpCPUKernel, - ops::CumprodGradOpCPUKernel>, - ops::CumprodGradOpCPUKernel>); diff --git a/paddle/fluid/operators/cumprod_op.cu b/paddle/fluid/operators/cumprod_op.cu deleted file mode 100644 index f792d683291..00000000000 --- a/paddle/fluid/operators/cumprod_op.cu +++ /dev/null @@ -1,369 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "paddle/fluid/operators/cumprod_op.h" -#include "paddle/fluid/operators/math/inclusive_scan.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -namespace paddle { -namespace operators { - -template -struct MultiplyFunctor { - HOSTDEVICE T operator()(T a, T b) const { return a * b; } -}; - -template -class CumprodOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - auto dim = ctx.Attr("dim"); - size_t outer_dim, mid_dim, inner_dim; - GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim); - - const auto *x_data = x->data(); - auto *y_data = y->mutable_data(ctx.GetPlace()); - const auto &dev_ctx = - ctx.template device_context(); - math::InclusiveScan>( - x_data, y_data, outer_dim, mid_dim, inner_dim, static_cast(1), - MultiplyFunctor(), /*reverse=*/false, dev_ctx); - } -}; - -template -struct IsZeroFunctor { - HOSTDEVICE bool operator()(T x) const { return x == static_cast(0); } -}; - -template -struct CumprodGradFunctorExceptFirstZero { - HOSTDEVICE CumprodGradFunctorExceptFirstZero( - const T *x, const T *y, const T *dy_mul_y_reversed_cumsum, - const uint8_t *zero_mask, size_t mid_dim, size_t inner_dim, T *dx, - int64_t *first_zero_idx, T *x_filled_one) - : x_(x), - y_(y), - dy_mul_y_reversed_cumsum_(dy_mul_y_reversed_cumsum), - zero_mask_(zero_mask), - mid_dim_(mid_dim), - inner_dim_(inner_dim), - dx_(dx), - first_zero_idx_(first_zero_idx), - x_filled_one_(x_filled_one) {} - - HOSTDEVICE void operator()(size_t idx) const { - auto inner_idx = idx % inner_dim_; - auto outer_idx = idx / (mid_dim_ * inner_dim_); - auto mid_idx = (idx - inner_idx) / inner_dim_ % mid_dim_; - auto mask = zero_mask_[idx]; - bool should_fill_one = true; - - if (mask == 0) { - dx_[idx] = dy_mul_y_reversed_cumsum_[idx] / x_[idx]; - if (mid_idx == mid_dim_ - 1) { - // record first zero position as -1, i.e., no zero - first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = -1; - } - } else if (mid_idx > 0) { // mask > 0 - if (zero_mask_[idx - inner_dim_] > 0) { // not first zero - dx_[idx] = 0; - should_fill_one = false; - } else { - // idx is the first zero position, it should be recorded - dx_[idx] = y_[idx - inner_dim_]; - first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = mid_idx; - } - } else { // the first zero position is index 0 - dx_[idx] = 1; - first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0; - } - - x_filled_one_[idx] = should_fill_one ? 1 : x_[idx]; - } - - private: - const T *x_; - const T *y_; - const T *dy_mul_y_reversed_cumsum_; - const uint8_t *zero_mask_; - size_t mid_dim_; - size_t inner_dim_; - T *dx_; - int64_t *first_zero_idx_; - T *x_filled_one_; -}; - -template -struct FillFirstZeroPositionGradFunctor { - HOSTDEVICE FillFirstZeroPositionGradFunctor(const int64_t *first_zero_idx, - const T *grad_value, - size_t mid_dim, size_t inner_dim, - T *dx) - : first_zero_idx_(first_zero_idx), - grad_value_(grad_value), - mid_dim_(mid_dim), - inner_dim_(inner_dim), - dx_(dx) {} - - HOSTDEVICE void operator()(size_t idx) const { - auto outer_idx = idx / inner_dim_; - auto inner_idx = idx % inner_dim_; - auto mid_idx = first_zero_idx_[idx]; - if (mid_idx >= 0) { - auto full_idx = - outer_idx * mid_dim_ * inner_dim_ + mid_idx * inner_dim_ + inner_idx; - dx_[full_idx] *= grad_value_[full_idx]; - } - } - - private: - const int64_t *first_zero_idx_; - const T *grad_value_; - size_t mid_dim_; - size_t inner_dim_; - T *dx_; -}; - -/* -Reference to -https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp -input: x, y, dL/dy -output: dL/dx -dL/dx[i] = sum{0<=j k, dL/dx[i] = 0; -i < k, dL/dx[i] = 1/x[i]*sum{i<=j k - dx[i] = 0; - x_filled_one[i] = x[i]; - } - } - } -} -T = reversed_cumsum(dy[j]*cumprod(x_filled_one[j])); -if (zero_index != -1) { - dx[zero_index] *= T[zero_index]; -} -*/ - -template -class CumprodGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - const auto *x = ctx.Input("X"); - const auto *y = ctx.Input("Out"); - const auto *dy = - ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto dim = ctx.Attr("dim"); - - size_t outer_dim, mid_dim, inner_dim; - GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim); - if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return; - - size_t numel = outer_dim * mid_dim * inner_dim; - - const auto *x_data = x->data(); - const auto *y_data = y->data(); - const auto *dy_data = dy->data(); - - auto place = ctx.GetPlace(); - const auto &dev_ctx = - ctx.template device_context(); - auto *dx_data = dx->mutable_data(place); - - // deal with complex - const T *x_data_deal; - const T *y_data_deal; - memory::AllocationPtr x_conj; - memory::AllocationPtr y_conj; - if (framework::IsComplex::value) { - x_conj = memory::Alloc(place, numel * sizeof(T)); - auto *x_data_conj = reinterpret_cast(x_conj->ptr()); - y_conj = memory::Alloc(place, numel * sizeof(T)); - auto *y_data_conj = reinterpret_cast(y_conj->ptr()); - - platform::ForRange for_range_x(dev_ctx, - numel); - phi::funcs::ConjFunctor functor_x(x_data, numel, x_data_conj); - for_range_x(functor_x); - - platform::ForRange for_range_y(dev_ctx, - numel); - phi::funcs::ConjFunctor functor_y(y_data, numel, y_data_conj); - for_range_y(functor_y); - x_data_deal = x_data_conj; - y_data_deal = y_data_conj; - } else { - x_data_deal = x_data; - y_data_deal = y_data; - } - -// Step 1: find cummax-ed zero mask of x -#ifdef PADDLE_WITH_CUDA - const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream()); -#else - const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); -#endif - auto zero_mask_without_cummax = - memory::Alloc(place, numel * sizeof(uint8_t)); - auto *zero_mask_without_cummax_data = - reinterpret_cast(zero_mask_without_cummax->ptr()); - thrust::transform( - exec_policy, thrust::device_pointer_cast(x_data_deal), - thrust::device_pointer_cast(x_data_deal) + numel, - thrust::device_pointer_cast(zero_mask_without_cummax_data), - IsZeroFunctor()); - - auto zero_mask = memory::Alloc(place, numel * sizeof(uint8_t)); - auto *zero_mask_data = reinterpret_cast(zero_mask->ptr()); - math::InclusiveScan( - zero_mask_without_cummax_data, zero_mask_data, outer_dim, mid_dim, - inner_dim, static_cast(0), cub::Max(), /*reverse=*/false, - dev_ctx); - zero_mask_without_cummax = nullptr; - - // Step 2: calculate reversed cumsum(dy * y) - auto dy_mul_y = memory::Alloc(place, numel * sizeof(T)); - auto *dy_mul_y_data = reinterpret_cast(dy_mul_y->ptr()); - thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data), - thrust::device_pointer_cast(dy_data) + numel, - thrust::device_pointer_cast(y_data_deal), - thrust::device_pointer_cast(dy_mul_y_data), - MultiplyFunctor()); - - auto dy_mul_y_reversed_cumsum = memory::Alloc(place, numel * sizeof(T)); - auto *dy_mul_y_reversed_cumsum_data = - reinterpret_cast(dy_mul_y_reversed_cumsum->ptr()); - math::InclusiveScan( - dy_mul_y_data, dy_mul_y_reversed_cumsum_data, outer_dim, mid_dim, - inner_dim, static_cast(0), cub::Sum(), /*reverse=*/true, dev_ctx); - - // Step 3: calculate the gradient value except the first zero position. - // The gradient value of the first zero position is filled with out[idx-1], - // while the gradient value of the other positions are calculated out - // completely. This functor also: - // (1) find the first zero index, i.e., first_zero_idx_data. - // (2) fill x_filled_one, which satifies - // x_filled_one[i] = x[i], i > pos - // x_filled_one[i] = 1, i <= pos - auto first_zero_idx = - memory::Alloc(place, outer_dim * inner_dim * sizeof(int64_t)); - auto *first_zero_idx_data = - reinterpret_cast(first_zero_idx->ptr()); - auto *x_filled_one_data = dy_mul_y_data; // reuse former allocated memory - platform::ForRange for_range(dev_ctx, numel); - CumprodGradFunctorExceptFirstZero functor_except_first_zero( - x_data_deal, y_data_deal, dy_mul_y_reversed_cumsum_data, zero_mask_data, - mid_dim, inner_dim, dx_data, first_zero_idx_data, x_filled_one_data); - for_range(functor_except_first_zero); - - // Step 4: calculate cumprod of x_filled_one - auto *x_filled_one_cumprod_data = - dy_mul_y_reversed_cumsum_data; // reuse former allocated memory - math::InclusiveScan>( - x_filled_one_data, x_filled_one_cumprod_data, outer_dim, mid_dim, - inner_dim, static_cast(1), MultiplyFunctor(), /*reverse=*/false, - dev_ctx); - - // Step 5: calculate reversed cumsum(dy * x_filled_one_cumprod) - auto *dy_mul_x_filled_one_cumprod = - dy_mul_y_data; // reuse former allocated memory - thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data), - thrust::device_pointer_cast(dy_data) + numel, - thrust::device_pointer_cast(x_filled_one_cumprod_data), - thrust::device_pointer_cast(dy_mul_x_filled_one_cumprod), - MultiplyFunctor()); - auto *dy_mul_x_filled_one_cumprod_reversed_cumsum = - dy_mul_y_reversed_cumsum_data; // reuse former allocated memory - math::InclusiveScan( - dy_mul_x_filled_one_cumprod, - dy_mul_x_filled_one_cumprod_reversed_cumsum, outer_dim, mid_dim, - inner_dim, static_cast(0), cub::Sum(), - /*reverse=*/true, dev_ctx); - - // Step 6: fill zero pos gradient value - platform::ForRange - for_range_fill_zero_pos_grad(dev_ctx, outer_dim * inner_dim); - FillFirstZeroPositionGradFunctor fill_first_zero_pos_grad_functor( - first_zero_idx_data, dy_mul_x_filled_one_cumprod_reversed_cumsum, - mid_dim, inner_dim, dx_data); - for_range_fill_zero_pos_grad(fill_first_zero_pos_grad_functor); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - cumprod, ops::CumprodOpCUDAKernel, ops::CumprodOpCUDAKernel, - ops::CumprodOpCUDAKernel, ops::CumprodOpCUDAKernel, - ops::CumprodOpCUDAKernel>, - ops::CumprodOpCUDAKernel>); - -REGISTER_OP_CUDA_KERNEL( - cumprod_grad, ops::CumprodGradOpCUDAKernel, - ops::CumprodGradOpCUDAKernel, ops::CumprodGradOpCUDAKernel, - ops::CumprodGradOpCUDAKernel, - ops::CumprodGradOpCUDAKernel>, - ops::CumprodGradOpCUDAKernel>); diff --git a/paddle/fluid/operators/cumprod_op.h b/paddle/fluid/operators/cumprod_op.h deleted file mode 100644 index 74ed2008ae9..00000000000 --- a/paddle/fluid/operators/cumprod_op.h +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -static void GetCumprodDimInfo(const framework::DDim& dim, int cumprod_dim, - size_t* outer_dim, size_t* mid_dim, - size_t* inner_dim) { - PADDLE_ENFORCE_GE( - cumprod_dim, -dim.size(), - platform::errors::InvalidArgument( - "The input dim of CumprodOp should be larger than the opposite " - "rank of input x which is %d.But received dim=%d", - -dim.size(), cumprod_dim)); - PADDLE_ENFORCE_LT(cumprod_dim, dim.size(), - platform::errors::InvalidArgument( - "The input dim of CumprodOp should be smaller than the " - "rank of input x which is %d.But received dim=%d", - dim.size(), cumprod_dim)); - if (cumprod_dim < 0) cumprod_dim += dim.size(); - - *outer_dim = 1; - for (int i = 0; i < cumprod_dim; ++i) { - *outer_dim *= dim[i]; - } - *mid_dim = dim[cumprod_dim]; - *inner_dim = 1; - for (int i = cumprod_dim + 1; i < dim.size(); ++i) { - *inner_dim *= dim[i]; - } -} - -template -class CumprodOpCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* out = context.Output("Out"); - int dim = context.Attr("dim"); - - auto* x_data = x->data(); - auto* out_data = out->mutable_data(context.GetPlace()); - framework::DDim shape = x->dims(); - - size_t outer_dim = 1; - size_t mid_dim = 1; - size_t inner_dim = 1; - GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim); - - for (size_t i = 0; i < outer_dim; i++) { - for (size_t j = 0; j < mid_dim; j++) { - for (size_t k = 0; k < inner_dim; k++) { - size_t pos = i * mid_dim * inner_dim + j * inner_dim + k; - if (j == 0) { - out_data[pos] = x_data[pos]; - } else { - out_data[pos] = out_data[pos - inner_dim] * x_data[pos]; - } - } - } - } - } -}; - -template -class CumprodGradOpCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const { - const Tensor* d_out = context.Input(framework::GradVarName("Out")); - const Tensor* x = context.Input("X"); - const Tensor* out = context.Input("Out"); - - int dim = context.Attr("dim"); - framework::DDim shape = x->dims(); - Tensor* d_x = context.Output(framework::GradVarName("X")); - - auto* d_out_data = d_out->data(); - auto* x_data = x->data(); - auto* out_data = out->data(); - auto* d_x_data = d_x->mutable_data(context.GetPlace()); - - auto place = context.GetPlace(); - const auto& dev_ctx = - context.template device_context(); - - size_t outer_dim = 1; - size_t mid_dim = 1; - size_t inner_dim = 1; - GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim); - size_t numel = outer_dim * mid_dim * inner_dim; - - // deal with complex - const T* x_data_deal; - const T* out_data_deal; - memory::AllocationPtr x_conj; - memory::AllocationPtr out_conj; - if (framework::IsComplex::value) { - x_conj = memory::Alloc(place, numel * sizeof(T)); - auto* x_data_conj = reinterpret_cast(x_conj->ptr()); - out_conj = memory::Alloc(place, numel * sizeof(T)); - auto* out_data_conj = reinterpret_cast(out_conj->ptr()); - - platform::ForRange for_range_x(dev_ctx, - numel); - phi::funcs::ConjFunctor functor_x(x_data, numel, x_data_conj); - for_range_x(functor_x); - - platform::ForRange for_range_out(dev_ctx, - numel); - phi::funcs::ConjFunctor functor_out(out_data, numel, out_data_conj); - for_range_out(functor_out); - - x_data_deal = x_data_conj; - out_data_deal = out_data_conj; - } else { - x_data_deal = x_data; - out_data_deal = out_data; - } - - for (size_t i = 0; i < outer_dim; i++) { - for (size_t k = 0; k < inner_dim; k++) { - for (size_t j = 0; j < mid_dim; j++) { - size_t index = i * mid_dim * inner_dim + j * inner_dim + k; - d_x_data[index] = 0; - for (size_t n = 0; n < mid_dim; n++) { - size_t pos = i * mid_dim * inner_dim + n * inner_dim + k; - T elem; - if (j == 0) { - elem = d_out_data[pos]; - } else { - elem = d_out_data[pos] * out_data_deal[index - inner_dim]; - } - if (pos > index) { - for (size_t m = index + inner_dim; m <= pos; m += inner_dim) { - elem *= x_data_deal[m]; - } - } else if (pos < index) { - elem = static_cast(0); - } - d_x_data[index] += elem; - } - } - } - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/isclose_op.cc b/paddle/fluid/operators/isclose_op.cc index 0ae7a9fa02f..8668de4d3a6 100644 --- a/paddle/fluid/operators/isclose_op.cc +++ b/paddle/fluid/operators/isclose_op.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/isclose_op.h" #include #include #include "paddle/fluid/framework/op_registry.h" @@ -23,45 +22,6 @@ namespace paddle { namespace operators { -template -struct GetTensorValue { - T operator()(const platform::CPUDeviceContext& dev_ctx, - const framework::Tensor& tensor) const { - return *(tensor.data()); - } -}; - -template -struct IscloseFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& other, - const double rtol, const double atol, bool equal_nan, - framework::Tensor* output) { - auto* in_a = in.data(); - auto* in_b = other.data(); - auto* out_data = output->mutable_data(ctx.GetPlace()); - auto num = in.numel(); - // *out_data = true; - for (int i = 0; i < num; i++) { - out_data[i] = true; - } - for (int i = 0; i < num; i++) { - const T a = in_a[i], b = in_b[i]; - bool val; - if (std::isnan(a) || std::isnan(b)) { - val = equal_nan && std::isnan(a) == std::isnan(b); - } else { - T left = (a > b ? a - b : b - a); - T right = atol + (b > 0 ? rtol * b : (-rtol) * b); - T diff = (left > right ? left - right : right - left); - val = a == b || left <= right || diff <= 1e-15; - } - // *out_data &= val; - out_data[i] = val; - } - } -}; - class IscloseOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -154,12 +114,9 @@ class IscloseOpVarTypeInference : public framework::VarTypeInference { } // namespace paddle namespace ops = paddle::operators; -using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR( isclose, ops::IscloseOp, ops::IscloseOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::IscloseOpVarTypeInference); -REGISTER_OP_CPU_KERNEL(isclose, ops::IscloseKernel, - ops::IscloseKernel); diff --git a/paddle/fluid/operators/isclose_op.cu b/paddle/fluid/operators/isclose_op.cu deleted file mode 100644 index 09710ba0c69..00000000000 --- a/paddle/fluid/operators/isclose_op.cu +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/isclose_op.h" - -namespace paddle { -namespace operators { - -template -struct GetTensorValue { - T operator()(const platform::CUDADeviceContext& dev_ctx, - const framework::Tensor& tensor) const { - const T* data = tensor.data(); - T value; - const auto gpu_place = dev_ctx.GetPlace(); - memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T), - dev_ctx.stream()); - return value; - } -}; - -template -__global__ void IscloseCUDAKernel(const T* in_data, const T* other_data, - const double rtol, const double atol, - bool equal_nan, int num, bool* out_data) { - unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x; - bool val; - for (int i = idx; i < num; i += blockDim.x * gridDim.x) { - const T a = in_data[i], b = other_data[i]; - if (isnan(a) || isnan(b)) { - val = equal_nan && isnan(a) == isnan(b); - } else { - T left = (a > b ? a - b : b - a); - T right = atol + (b > 0 ? rtol * b : (-rtol) * b); - T diff = (left > right ? left - right : right - left); - val = a == b || left <= right || diff <= 1e-15; - } - out_data[i] = val; - // if (!val) *out_data = false; - } -} - -template -struct IscloseFunctor { - void operator()(const platform::CUDADeviceContext& dev_ctx, - const framework::Tensor& in, const framework::Tensor& other, - const double rtol, const double atol, bool equal_nan, - framework::Tensor* output) { - int num = in.numel(); - const T* in_data = in.data(); - const T* other_data = other.data(); - bool* out_data = output->mutable_data(dev_ctx.GetPlace()); - int block = 1024; - int grid = (block - 1 + num) / block; - grid = (grid > block) ? block : grid; -#ifdef PADDLE_WITH_HIP - hipMemset(out_data, true, num * sizeof(bool)); -#else - cudaMemset(out_data, true, num * sizeof(bool)); -#endif - IscloseCUDAKernel<<>>( - in_data, other_data, rtol, atol, equal_nan, num, out_data); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CUDA = paddle::platform::CUDADeviceContext; -REGISTER_OP_CUDA_KERNEL(isclose, ops::IscloseKernel, - ops::IscloseKernel); diff --git a/paddle/fluid/operators/isclose_op.h b/paddle/fluid/operators/isclose_op.h deleted file mode 100644 index cde5d2afbf0..00000000000 --- a/paddle/fluid/operators/isclose_op.h +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -template -struct GetTensorValue { - T operator()(const platform::DeviceContext& ctx, - const framework::Tensor& tensor) const; -}; - -template -struct IscloseFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& other, const float rtol, - const float atol, bool equal_nan, framework::Tensor* output); -}; - -template -class IscloseKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - // get attrs - bool equal_nan = ctx.Attr("equal_nan"); - // get input/output - const auto* input = ctx.Input("Input"); - const auto* other = ctx.Input("Other"); - auto* out = ctx.Output("Out"); - - double rtol_v = std::stod(ctx.Attr("rtol")); - double atol_v = std::stod(ctx.Attr("atol")); - - auto& dev_ctx = ctx.template device_context(); - GetTensorValue get_tensor_value; - if (ctx.HasInput("Rtol")) { - const auto* rtol = ctx.Input("Rtol"); - PADDLE_ENFORCE_EQ( - rtol->numel(), 1, - platform::errors::InvalidArgument( - "Input(Rtol) size must be 1, but get %d.", rtol->numel())); - PADDLE_ENFORCE_EQ( - framework::TransToProtoVarType(rtol->dtype()), - framework::proto::VarType::FP64, - platform::errors::InvalidArgument( - "Input(Rtol) type must be double, but get %s.", - framework::DataTypeToString( - framework::TransToProtoVarType(rtol->dtype())))); - rtol_v = get_tensor_value(dev_ctx, *rtol); - } - if (ctx.HasInput("Atol")) { - const auto* atol = ctx.Input("Atol"); - PADDLE_ENFORCE_EQ( - atol->numel(), 1, - platform::errors::InvalidArgument( - "Input(Atol) size must be 1, but get %d", atol->numel())); - PADDLE_ENFORCE_EQ( - framework::TransToProtoVarType(atol->dtype()), - framework::proto::VarType::FP64, - platform::errors::InvalidArgument( - "Input(Atol) type must be double, but get %s", - framework::DataTypeToString( - framework::TransToProtoVarType(atol->dtype())))); - atol_v = get_tensor_value(dev_ctx, *atol); - } - - IscloseFunctor()(dev_ctx, *input, *other, rtol_v, atol_v, - equal_nan, out); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index a78d8ec1014..dcd98054b05 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -9,7 +9,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/kldiv_loss_op.h" #include #include #include "paddle/fluid/framework/op_registry.h" @@ -177,10 +176,3 @@ REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker, ops::KLDivLossOpGradMaker); REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad, ops::KLDivLossGradNoNeedBufferVarInferer); -REGISTER_OP_CPU_KERNEL( - kldiv_loss, ops::KLDivLossKernel, - ops::KLDivLossKernel); -REGISTER_OP_CPU_KERNEL( - kldiv_loss_grad, - ops::KLDivLossGradKernel, - ops::KLDivLossGradKernel); diff --git a/paddle/fluid/operators/kldiv_loss_op.cu b/paddle/fluid/operators/kldiv_loss_op.cu deleted file mode 100644 index 5226cb8c08e..00000000000 --- a/paddle/fluid/operators/kldiv_loss_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/kldiv_loss_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - kldiv_loss, - ops::KLDivLossKernel, - ops::KLDivLossKernel); -REGISTER_OP_CUDA_KERNEL( - kldiv_loss_grad, - ops::KLDivLossGradKernel, - ops::KLDivLossGradKernel); diff --git a/paddle/fluid/operators/kldiv_loss_op.h b/paddle/fluid/operators/kldiv_loss_op.h deleted file mode 100644 index 5a6ef06f5eb..00000000000 --- a/paddle/fluid/operators/kldiv_loss_op.h +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/hostdevice.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using Array1 = Eigen::DSizes; - -template -struct KLDivLossForward { - HOSTDEVICE KLDivLossForward() {} - - HOSTDEVICE T operator()(const T& target, const T& input) const { - if (target <= 0) { - return 0; - } else { - return target * (std::log(target) - input); - } - } -}; - -template -struct KLDivLossBackward { - HOSTDEVICE KLDivLossBackward() {} - - HOSTDEVICE T operator()(const T& target, const T& grad) const { - if (target <= 0) { - return 0; - } else { - return static_cast(-1.) * grad; - } - } -}; - -template -class KLDivLossKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& place = *ctx.template device_context().eigen_device(); - auto* input = ctx.Input("X"); - auto* target = ctx.Input("Target"); - auto* loss = ctx.Output("Loss"); - auto reduction = ctx.Attr("reduction"); - - const int n = input->dims()[0]; - - loss->mutable_data(ctx.GetPlace()); - auto input_t = framework::EigenVector::Flatten(*input); - auto target_t = framework::EigenVector::Flatten(*target); - auto loss_t = framework::EigenVector::Flatten(*loss); - auto output = target_t.binaryExpr(input_t, KLDivLossForward()); - if ("none" == reduction) { - loss_t.device(place) = output; - } else if ("batchmean" == reduction) { - auto output_sum = output.sum(); - if (n > 0) { - loss_t.device(place) = output_sum / output_sum.constant(n); - } else { - loss_t.device(place) = output_sum; - } - } else if ("mean" == reduction) { - loss_t.device(place) = output.mean(); - } else if ("sum" == reduction) { - loss_t.device(place) = output.sum(); - } - } -}; - -template -class KLDivLossGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& place = *ctx.template device_context().eigen_device(); - auto* target = ctx.Input("Target"); - auto reduction = ctx.Attr("reduction"); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); - - const int n = input_grad->dims()[0]; - const int numel = input_grad->numel(); - const int expand = numel / loss_grad->numel(); - - input_grad->mutable_data(ctx.GetPlace()); - - auto target_t = framework::EigenVector::Flatten(*target); - - auto input_grad_t = framework::EigenVector::Flatten(*input_grad); - auto loss_grad_t = framework::EigenVector::Flatten(*loss_grad); - - auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand)); - auto grad_t = target_t * loss_grad_expand; - input_grad_t.device(place) = - target_t.binaryExpr(grad_t, KLDivLossBackward()); - - if ("mean" == reduction) { - input_grad_t.device(place) = input_grad_t / static_cast(numel); - } else if ("batchmean" == reduction) { - input_grad_t.device(place) = input_grad_t / static_cast(n); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/kldiv_loss_op_npu.cc b/paddle/fluid/operators/kldiv_loss_op_npu.cc index 322ae5df4cb..eac181489aa 100644 --- a/paddle/fluid/operators/kldiv_loss_op_npu.cc +++ b/paddle/fluid/operators/kldiv_loss_op_npu.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the Licnse. */ -#include "paddle/fluid/operators/kldiv_loss_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/lgamma_op.cc b/paddle/fluid/operators/lgamma_op.cc index 148fb05afcf..72c6b41efa9 100644 --- a/paddle/fluid/operators/lgamma_op.cc +++ b/paddle/fluid/operators/lgamma_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/lgamma_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -35,16 +38,6 @@ $$out = log\Gamma(x)$$ class LgammaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Lgamma"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Lgamma"); - - auto in_dims = ctx->GetInputDim("X"); - - ctx->SetOutputDim("Out", in_dims); - ctx->ShareLoD("X", "Out"); - } }; template @@ -83,17 +76,12 @@ class LgammaGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(lgamma, LgammaInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR(lgamma, ops::LgammaOp, ops::LgammaOpMaker, ops::LgammaGradMaker, - ops::LgammaGradMaker); + ops::LgammaGradMaker, + LgammaInferShapeFunctor); REGISTER_OPERATOR(lgamma_grad, ops::LgammaGradOp); - -REGISTER_OP_CPU_KERNEL( - lgamma, ops::LgammaKernel, - ops::LgammaKernel) - -REGISTER_OP_CPU_KERNEL( - lgamma_grad, - ops::LgammaGradKernel, - ops::LgammaGradKernel); diff --git a/paddle/fluid/operators/lgamma_op.cu b/paddle/fluid/operators/lgamma_op.cu deleted file mode 100644 index b9f273727b0..00000000000 --- a/paddle/fluid/operators/lgamma_op.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/lgamma_op.h" - -namespace paddle { -namespace operators { - -template -struct CudaLgammaFunctor { - __device__ __forceinline__ T operator()(const T x) const { - return Eigen::numext::lgamma(x); - } -}; - -template -class LgammaKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - - auto& dev_ctx = context.device_context(); - std::vector ins = {x}; - std::vector outs = {out}; - auto functor = CudaLgammaFunctor(); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - lgamma, ops::LgammaKernel, - ops::LgammaKernel); - -REGISTER_OP_CUDA_KERNEL( - lgamma_grad, - ops::LgammaGradKernel, - ops::LgammaGradKernel); diff --git a/paddle/fluid/operators/lgamma_op.h b/paddle/fluid/operators/lgamma_op.h deleted file mode 100644 index 674054e7457..00000000000 --- a/paddle/fluid/operators/lgamma_op.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/for_range.h" - -namespace paddle { -namespace operators { - -template -struct LgammaFunctor { - LgammaFunctor(const T* input, T* output, int64_t numel) - : input_(input), output_(output), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - output_[idx] = Eigen::numext::lgamma(input_[idx]); - } - - private: - const T* input_; - T* output_; - int64_t numel_; -}; - -template -struct LgammaGradFunctor { - LgammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel) - : dout_(dout), x_(x), output_(output), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]); - } - - private: - const T* dout_; - const T* x_; - T* output_; - int64_t numel_; -}; - -using Tensor = framework::Tensor; - -template -class LgammaKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* out = context.Output("Out"); - - auto numel = x->numel(); - auto* x_data = x->data(); - auto* out_data = out->mutable_data(context.GetPlace(), - size_t(x->numel() * sizeof(T))); - - auto& dev_ctx = context.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - LgammaFunctor functor(x_data, out_data, numel); - for_range(functor); - } -}; - -template -class LgammaGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - const framework::Tensor* x = ctx.Input("X"); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); - - auto numel = d_out->numel(); - auto* dout_data = d_out->data(); - auto* x_data = x->data(); - auto* dx_data = d_x->mutable_data( - ctx.GetPlace(), static_cast(numel * sizeof(T))); - - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - LgammaGradFunctor functor(dout_data, x_data, dx_data, numel); - for_range(functor); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/inclusive_scan.h b/paddle/fluid/operators/math/inclusive_scan.h index 9994ccc10cb..b77e2345036 100644 --- a/paddle/fluid/operators/math/inclusive_scan.h +++ b/paddle/fluid/operators/math/inclusive_scan.h @@ -34,10 +34,10 @@ namespace paddle { namespace operators { namespace math { -template +template static void CubInclusiveScan(InputIterator x_iter, OutputIterator y_iter, - size_t n, BinaryOp op, - const platform::CUDADeviceContext &dev_ctx) { + size_t n, BinaryOp op, const Context &dev_ctx) { memory::AllocationPtr allocation; void *temp_storage = nullptr; size_t temp_storage_bytes = 0; @@ -185,11 +185,10 @@ static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y, } } -template +template static void InclusiveScanInnerDim(const T *x, T *y, size_t outer_dim, size_t inner_dim, T init, BinaryOp op, - bool reverse, - const platform::CUDADeviceContext &dev_ctx) { + bool reverse, const Context &dev_ctx) { constexpr size_t kThreadNumX = 16; constexpr size_t kThreadNumY = 32; @@ -209,10 +208,10 @@ static void InclusiveScanInnerDim(const T *x, T *y, size_t outer_dim, } } -template +template void InclusiveScan(const T *x, T *y, size_t outer_dim, size_t mid_dim, size_t inner_dim, T init, BinaryOp op, bool reverse, - const platform::CUDADeviceContext &dev_ctx) { + const Context &dev_ctx) { if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return; if (outer_dim == 1 && inner_dim == 1) { @@ -224,8 +223,7 @@ void InclusiveScan(const T *x, T *y, size_t outer_dim, size_t mid_dim, CubInclusiveScan(x, y, mid_dim, op, dev_ctx); } } else if (inner_dim != 1) { - platform::ForRange for_range( - dev_ctx, outer_dim * inner_dim); + platform::ForRange for_range(dev_ctx, outer_dim * inner_dim); if (reverse) { for_range( InclusiveScanOuterOrMidDimFunctor( diff --git a/paddle/phi/api/lib/utils/CMakeLists.txt b/paddle/phi/api/lib/utils/CMakeLists.txt index 6d056b54b70..271a58222f0 100644 --- a/paddle/phi/api/lib/utils/CMakeLists.txt +++ b/paddle/phi/api/lib/utils/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS -tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits) +tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits scalar) diff --git a/paddle/phi/common/CMakeLists.txt b/paddle/phi/common/CMakeLists.txt index 85a1424ee34..0947870dcd3 100644 --- a/paddle/phi/common/CMakeLists.txt +++ b/paddle/phi/common/CMakeLists.txt @@ -1 +1,2 @@ cc_library(phi_place SRCS place.cc) +cc_library(scalar SRCS scalar.cc) diff --git a/paddle/phi/common/scalar.cc b/paddle/phi/common/scalar.cc new file mode 100644 index 00000000000..5cd55c1e88b --- /dev/null +++ b/paddle/phi/common/scalar.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/common/scalar.h" + +#include "paddle/phi/core/enforce.h" + +namespace paddle { +namespace experimental { + +// NOTE(xiongkun): why we put definition here? +// test_custom_op can't include enforce.h, because enforce.h includes gflags. +// so we decouple the include dependence of enforce.h by link. +void ThrowTensorConvertError(int num) { + PADDLE_ENFORCE_EQ(num, + 1, + phi::errors::InvalidArgument( + "The Scalar only supports Tensor with 1 element, but " + "now Tensor has `%d` elements", + num)); +} + +} // namespace experimental +} // namespace paddle diff --git a/paddle/phi/common/scalar.h b/paddle/phi/common/scalar.h index 72cef89d300..5134f4eb726 100644 --- a/paddle/phi/common/scalar.h +++ b/paddle/phi/common/scalar.h @@ -19,9 +19,12 @@ limitations under the License. */ #include "paddle/phi/api/ext/exception.h" #include "paddle/phi/api/include/tensor.h" + namespace paddle { namespace experimental { +void ThrowTensorConvertError(int); + template class ScalarBase { public: @@ -104,11 +107,7 @@ class ScalarBase { // The Tensor must have one dim ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT is_from_tensor_ = true; - PD_CHECK( - tensor.numel() == 1, - "The Scalar only supports Tensor with 1 element, but now Tensor has `", - tensor.numel(), - "` element."); + ThrowTensorConvertError(tensor.numel()); switch (dtype_) { case DataType::FLOAT32: data_.f32 = tensor.template data()[0]; @@ -156,6 +155,8 @@ class ScalarBase { CopyScalar(other, this); } + // NOTE(xiongkun): some op need to judge the dtype of the Scalar, we expose a + // interface. bool FromTensor() const { return is_from_tensor_; } void SetFromTensor(bool from_tensor) { is_from_tensor_ = from_tensor; } diff --git a/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc b/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc new file mode 100644 index 00000000000..a25f9650fc5 --- /dev/null +++ b/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cumprod_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/cumprod.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +// NOTE(@xiongkun): use of IsComplex<> +#include "paddle/fluid/framework/data_type.h" + +namespace phi { +template +void CumprodGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& d_out, + int dim, + DenseTensor* d_x) { + DDim shape = x.dims(); + + auto* d_out_data = d_out.data(); + auto* x_data = x.data(); + auto* out_data = out.data(); + auto* d_x_data = dev_ctx.template Alloc(d_x); + + size_t outer_dim = 1; + size_t mid_dim = 1; + size_t inner_dim = 1; + GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim); + size_t numel = outer_dim * mid_dim * inner_dim; + + // deal with complex + const T* x_data_deal; + const T* out_data_deal; + Allocator::AllocationPtr x_conj; + Allocator::AllocationPtr out_conj; + if (paddle::framework::IsComplex::value) { + x_conj = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(T)); + auto* x_data_conj = reinterpret_cast(x_conj->ptr()); + out_conj = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(T)); + auto* out_data_conj = reinterpret_cast(out_conj->ptr()); + + phi::funcs::ForRange for_range_x(dev_ctx, numel); + phi::funcs::ConjFunctor functor_x(x_data, numel, x_data_conj); + for_range_x(functor_x); + + phi::funcs::ForRange for_range_out(dev_ctx, numel); + phi::funcs::ConjFunctor functor_out(out_data, numel, out_data_conj); + for_range_out(functor_out); + + x_data_deal = x_data_conj; + out_data_deal = out_data_conj; + } else { + x_data_deal = x_data; + out_data_deal = out_data; + } + + for (size_t i = 0; i < outer_dim; i++) { + for (size_t k = 0; k < inner_dim; k++) { + for (size_t j = 0; j < mid_dim; j++) { + size_t index = i * mid_dim * inner_dim + j * inner_dim + k; + d_x_data[index] = 0; + for (size_t n = 0; n < mid_dim; n++) { + size_t pos = i * mid_dim * inner_dim + n * inner_dim + k; + T elem; + if (j == 0) { + elem = d_out_data[pos]; + } else { + elem = d_out_data[pos] * out_data_deal[index - inner_dim]; + } + if (pos > index) { + for (size_t m = index + inner_dim; m <= pos; m += inner_dim) { + elem *= x_data_deal[m]; + } + } else if (pos < index) { + elem = static_cast(0); + } + d_x_data[index] += elem; + } + } + } + } +} +} // namespace phi +PD_REGISTER_KERNEL(cumprod_grad, + CPU, + ALL_LAYOUT, + phi::CumprodGradKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/cumprod_kernel.cc b/paddle/phi/kernels/cpu/cumprod_kernel.cc new file mode 100644 index 00000000000..aea338027f5 --- /dev/null +++ b/paddle/phi/kernels/cpu/cumprod_kernel.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cumprod_kernel.h" + +#include +#include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/cumprod.h" + +namespace phi { +template +void CumprodKernel(const Context& dev_ctx, + const DenseTensor& input, + int dim, + DenseTensor* out) { + const DenseTensor* x = &input; + auto* x_data = x->data(); + auto* out_data = dev_ctx.template Alloc(out); + DDim shape = x->dims(); + + size_t outer_dim = 1; + size_t mid_dim = 1; + size_t inner_dim = 1; + GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim); + + for (size_t i = 0; i < outer_dim; i++) { + for (size_t j = 0; j < mid_dim; j++) { + for (size_t k = 0; k < inner_dim; k++) { + size_t pos = i * mid_dim * inner_dim + j * inner_dim + k; + if (j == 0) { + out_data[pos] = x_data[pos]; + } else { + out_data[pos] = out_data[pos - inner_dim] * x_data[pos]; + } + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cumprod, + CPU, + ALL_LAYOUT, + phi::CumprodKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/isclose_kernel.cc b/paddle/phi/kernels/cpu/isclose_kernel.cc new file mode 100644 index 00000000000..633c6ba093e --- /dev/null +++ b/paddle/phi/kernels/cpu/isclose_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/isclose_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/isclose_kernel_impl.h" + +PD_REGISTER_KERNEL( + isclose, CPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/kldiv_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/kldiv_loss_grad_kernel.cc new file mode 100644 index 00000000000..f9399d38d71 --- /dev/null +++ b/paddle/phi/kernels/cpu/kldiv_loss_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/kldiv_loss_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + kldiv_loss_grad, CPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/cpu/kldiv_loss_kernel.cc b/paddle/phi/kernels/cpu/kldiv_loss_kernel.cc new file mode 100644 index 00000000000..c462b8ec32c --- /dev/null +++ b/paddle/phi/kernels/cpu/kldiv_loss_kernel.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/kldiv_loss_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h" + +namespace phi {} // namespace phi + +PD_REGISTER_KERNEL( + kldiv_loss, CPU, ALL_LAYOUT, phi::KLDivLossKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/lgamma_grad_kernel.cc b/paddle/phi/kernels/cpu/lgamma_grad_kernel.cc new file mode 100644 index 00000000000..116fa3f8d3f --- /dev/null +++ b/paddle/phi/kernels/cpu/lgamma_grad_kernel.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lgamma_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h" +PD_REGISTER_KERNEL( + lgamma_grad, CPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/lgamma_kernel.cc b/paddle/phi/kernels/cpu/lgamma_kernel.cc new file mode 100644 index 00000000000..d0226894089 --- /dev/null +++ b/paddle/phi/kernels/cpu/lgamma_kernel.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lgamma_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +template +struct LgammaFunctor { + LgammaFunctor(const T* input, T* output, int64_t numel) + : input_(input), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = Eigen::numext::lgamma(input_[idx]); + } + + private: + const T* input_; + T* output_; + int64_t numel_; +}; + +template +void LgammaKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = dev_ctx.template Alloc(out); + phi::funcs::ForRange for_range(dev_ctx, numel); + LgammaFunctor functor(x_data, out_data, numel); + for_range(functor); +} +} // namespace phi + +PD_REGISTER_KERNEL(lgamma, CPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {} diff --git a/paddle/phi/kernels/cumprod_grad_kernel.h b/paddle/phi/kernels/cumprod_grad_kernel.h new file mode 100644 index 00000000000..b3cb17b28e0 --- /dev/null +++ b/paddle/phi/kernels/cumprod_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CumprodGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + int dim, + DenseTensor* dx); +} // phi diff --git a/paddle/phi/kernels/cumprod_kernel.h b/paddle/phi/kernels/cumprod_kernel.h new file mode 100644 index 00000000000..96d76cb0f43 --- /dev/null +++ b/paddle/phi/kernels/cumprod_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CumprodKernel(const Context& dev_ctx, + const DenseTensor& x, + int dim, + DenseTensor* out); +} // phi diff --git a/paddle/phi/kernels/funcs/cumprod.h b/paddle/phi/kernels/funcs/cumprod.h new file mode 100644 index 00000000000..ac40523c1c4 --- /dev/null +++ b/paddle/phi/kernels/funcs/cumprod.h @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +static void GetCumprodDimInfo(const DDim& dim, + int cumprod_dim, + size_t* outer_dim, + size_t* mid_dim, + size_t* inner_dim) { + PADDLE_ENFORCE_GE( + cumprod_dim, + -dim.size(), + phi::errors::InvalidArgument( + "The input dim of CumprodOp should be larger than the opposite " + "rank of input x which is %d.But received dim=%d", + -dim.size(), + cumprod_dim)); + PADDLE_ENFORCE_LT(cumprod_dim, + dim.size(), + phi::errors::InvalidArgument( + "The input dim of CumprodOp should be smaller than the " + "rank of input x which is %d.But received dim=%d", + dim.size(), + cumprod_dim)); + if (cumprod_dim < 0) cumprod_dim += dim.size(); + + *outer_dim = 1; + for (int i = 0; i < cumprod_dim; ++i) { + *outer_dim *= dim[i]; + } + *mid_dim = dim[cumprod_dim]; + *inner_dim = 1; + for (int i = cumprod_dim + 1; i < dim.size(); ++i) { + *inner_dim *= dim[i]; + } +} +} // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index f9e66836a62..ac262fe2d57 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -67,6 +67,11 @@ struct InverseMultiplyFunctor { } }; +template +struct IsZeroFunctor { + HOSTDEVICE bool operator()(T x) const { return x == static_cast(0); } +}; + // Divide #define DIV_ERROR_INFO \ "InvalidArgumentError: Integer division by zero encountered in " \ diff --git a/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu b/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu new file mode 100644 index 00000000000..6e871246292 --- /dev/null +++ b/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu @@ -0,0 +1,320 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cumprod_grad_kernel.h" + +#include +#include "paddle/fluid/operators/math/inclusive_scan.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/cumprod.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +// NOTE(@xiongkun): use of IsComplex<> +#include "paddle/fluid/framework/data_type.h" + +namespace phi { + +template +struct CumprodGradFunctorExceptFirstZero { + HOSTDEVICE CumprodGradFunctorExceptFirstZero( + const T *x, + const T *y, + const T *dy_mul_y_reversed_cumsum, + const uint8_t *zero_mask, + size_t mid_dim, + size_t inner_dim, + T *dx, + int64_t *first_zero_idx, + T *x_filled_one) + : x_(x), + y_(y), + dy_mul_y_reversed_cumsum_(dy_mul_y_reversed_cumsum), + zero_mask_(zero_mask), + mid_dim_(mid_dim), + inner_dim_(inner_dim), + dx_(dx), + first_zero_idx_(first_zero_idx), + x_filled_one_(x_filled_one) {} + + HOSTDEVICE void operator()(size_t idx) const { + auto inner_idx = idx % inner_dim_; + auto outer_idx = idx / (mid_dim_ * inner_dim_); + auto mid_idx = (idx - inner_idx) / inner_dim_ % mid_dim_; + auto mask = zero_mask_[idx]; + bool should_fill_one = true; + + if (mask == 0) { + dx_[idx] = dy_mul_y_reversed_cumsum_[idx] / x_[idx]; + if (mid_idx == mid_dim_ - 1) { + // record first zero position as -1, i.e., no zero + first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = -1; + } + } else if (mid_idx > 0) { // mask > 0 + if (zero_mask_[idx - inner_dim_] > 0) { // not first zero + dx_[idx] = 0; + should_fill_one = false; + } else { + // idx is the first zero position, it should be recorded + dx_[idx] = y_[idx - inner_dim_]; + first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = mid_idx; + } + } else { // the first zero position is index 0 + dx_[idx] = 1; + first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0; + } + + x_filled_one_[idx] = should_fill_one ? 1 : x_[idx]; + } + + private: + const T *x_; + const T *y_; + const T *dy_mul_y_reversed_cumsum_; + const uint8_t *zero_mask_; + size_t mid_dim_; + size_t inner_dim_; + T *dx_; + int64_t *first_zero_idx_; + T *x_filled_one_; +}; + +template +struct FillFirstZeroPositionGradFunctor { + HOSTDEVICE FillFirstZeroPositionGradFunctor(const int64_t *first_zero_idx, + const T *grad_value, + size_t mid_dim, + size_t inner_dim, + T *dx) + : first_zero_idx_(first_zero_idx), + grad_value_(grad_value), + mid_dim_(mid_dim), + inner_dim_(inner_dim), + dx_(dx) {} + + HOSTDEVICE void operator()(size_t idx) const { + auto outer_idx = idx / inner_dim_; + auto inner_idx = idx % inner_dim_; + auto mid_idx = first_zero_idx_[idx]; + if (mid_idx >= 0) { + auto full_idx = + outer_idx * mid_dim_ * inner_dim_ + mid_idx * inner_dim_ + inner_idx; + dx_[full_idx] *= grad_value_[full_idx]; + } + } + + private: + const int64_t *first_zero_idx_; + const T *grad_value_; + size_t mid_dim_; + size_t inner_dim_; + T *dx_; +}; + +template +void CumprodGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &out, + const DenseTensor &dout, + int dim, + DenseTensor *dx) { + const auto *y = &out; + const auto *dy = &dout; + + size_t outer_dim, mid_dim, inner_dim; + GetCumprodDimInfo(x.dims(), dim, &outer_dim, &mid_dim, &inner_dim); + if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return; + + size_t numel = outer_dim * mid_dim * inner_dim; + + const auto *x_data = x.data(); + const auto *y_data = y->data(); + const auto *dy_data = dy->data(); + + auto place = dev_ctx.GetPlace(); + auto *dx_data = dev_ctx.template Alloc(dx); + + // deal with complex + const T *x_data_deal; + const T *y_data_deal; + Allocator::AllocationPtr x_conj; + Allocator::AllocationPtr y_conj; + if (paddle::framework::IsComplex::value) { + x_conj = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(T)); + auto *x_data_conj = reinterpret_cast(x_conj->ptr()); + y_conj = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(T)); + auto *y_data_conj = reinterpret_cast(y_conj->ptr()); + + phi::funcs::ForRange for_range_x(dev_ctx, numel); + phi::funcs::ConjFunctor functor_x(x_data, numel, x_data_conj); + for_range_x(functor_x); + + phi::funcs::ForRange for_range_y(dev_ctx, numel); + phi::funcs::ConjFunctor functor_y(y_data, numel, y_data_conj); + for_range_y(functor_y); + x_data_deal = x_data_conj; + y_data_deal = y_data_conj; + } else { + x_data_deal = x_data; + y_data_deal = y_data; + } + +// Step 1: find cummax-ed zero mask of x +#ifdef PADDLE_WITH_CUDA + const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream()); +#else + const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); +#endif + auto zero_mask_without_cummax = + const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(uint8_t)); + auto *zero_mask_without_cummax_data = + reinterpret_cast(zero_mask_without_cummax->ptr()); + thrust::transform(exec_policy, + thrust::device_pointer_cast(x_data_deal), + thrust::device_pointer_cast(x_data_deal) + numel, + thrust::device_pointer_cast(zero_mask_without_cummax_data), + funcs::IsZeroFunctor()); + + auto zero_mask = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(uint8_t)); + auto *zero_mask_data = reinterpret_cast(zero_mask->ptr()); + paddle::operators::math::InclusiveScan( + zero_mask_without_cummax_data, + zero_mask_data, + outer_dim, + mid_dim, + inner_dim, + static_cast(0), + cub::Max(), + /*reverse=*/false, + dev_ctx); + zero_mask_without_cummax = nullptr; + + // Step 2: calculate reversed cumsum(dy * y) + auto dy_mul_y = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(T)); + auto *dy_mul_y_data = reinterpret_cast(dy_mul_y->ptr()); + thrust::transform(exec_policy, + thrust::device_pointer_cast(dy_data), + thrust::device_pointer_cast(dy_data) + numel, + thrust::device_pointer_cast(y_data_deal), + thrust::device_pointer_cast(dy_mul_y_data), + funcs::MultiplyFunctor()); + + auto dy_mul_y_reversed_cumsum = + const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(T)); + auto *dy_mul_y_reversed_cumsum_data = + reinterpret_cast(dy_mul_y_reversed_cumsum->ptr()); + paddle::operators::math::InclusiveScan( + dy_mul_y_data, + dy_mul_y_reversed_cumsum_data, + outer_dim, + mid_dim, + inner_dim, + static_cast(0), + cub::Sum(), + /*reverse=*/true, + dev_ctx); + + // Step 3: calculate the gradient value except the first zero position. + // The gradient value of the first zero position is filled with out[idx-1], + // while the gradient value of the other positions are calculated out + // completely. This functor also: + // (1) find the first zero index, i.e., first_zero_idx_data. + // (2) fill x_filled_one, which satifies + // x_filled_one[i] = x[i], i > pos + // x_filled_one[i] = 1, i <= pos + auto first_zero_idx = const_cast(dev_ctx.GetAllocator()) + .Allocate(numel * sizeof(int64_t)); + auto *first_zero_idx_data = + reinterpret_cast(first_zero_idx->ptr()); + auto *x_filled_one_data = dy_mul_y_data; // reuse former allocated memory + phi::funcs::ForRange for_range(dev_ctx, numel); + CumprodGradFunctorExceptFirstZero functor_except_first_zero( + x_data_deal, + y_data_deal, + dy_mul_y_reversed_cumsum_data, + zero_mask_data, + mid_dim, + inner_dim, + dx_data, + first_zero_idx_data, + x_filled_one_data); + for_range(functor_except_first_zero); + + // Step 4: calculate cumprod of x_filled_one + auto *x_filled_one_cumprod_data = + dy_mul_y_reversed_cumsum_data; // reuse former allocated memory + paddle::operators::math::InclusiveScan>( + x_filled_one_data, + x_filled_one_cumprod_data, + outer_dim, + mid_dim, + inner_dim, + static_cast(1), + funcs::MultiplyFunctor(), + /*reverse=*/false, + dev_ctx); + + // Step 5: calculate reversed cumsum(dy * x_filled_one_cumprod) + auto *dy_mul_x_filled_one_cumprod = + dy_mul_y_data; // reuse former allocated memory + thrust::transform(exec_policy, + thrust::device_pointer_cast(dy_data), + thrust::device_pointer_cast(dy_data) + numel, + thrust::device_pointer_cast(x_filled_one_cumprod_data), + thrust::device_pointer_cast(dy_mul_x_filled_one_cumprod), + funcs::MultiplyFunctor()); + auto *dy_mul_x_filled_one_cumprod_reversed_cumsum = + dy_mul_y_reversed_cumsum_data; // reuse former allocated memory + paddle::operators::math::InclusiveScan( + dy_mul_x_filled_one_cumprod, + dy_mul_x_filled_one_cumprod_reversed_cumsum, + outer_dim, + mid_dim, + inner_dim, + static_cast(0), + cub::Sum(), + /*reverse=*/true, + dev_ctx); + + // Step 6: fill zero pos gradient value + phi::funcs::ForRange for_range_fill_zero_pos_grad( + dev_ctx, outer_dim * inner_dim); + FillFirstZeroPositionGradFunctor fill_first_zero_pos_grad_functor( + first_zero_idx_data, + dy_mul_x_filled_one_cumprod_reversed_cumsum, + mid_dim, + inner_dim, + dx_data); + for_range_fill_zero_pos_grad(fill_first_zero_pos_grad_functor); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cumprod_grad, + GPU, + ALL_LAYOUT, + phi::CumprodGradKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/cumprod_kernel.cu b/paddle/phi/kernels/gpu/cumprod_kernel.cu new file mode 100644 index 00000000000..1bbf8972a24 --- /dev/null +++ b/paddle/phi/kernels/gpu/cumprod_kernel.cu @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cumprod_kernel.h" + +#include "paddle/fluid/operators/math/inclusive_scan.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/cumprod.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" + +namespace phi { + +template +void CumprodKernel(const Context &dev_ctx, + const DenseTensor &input, + int dim, + DenseTensor *out) { + const auto *x = &input; + auto *y = out; + size_t outer_dim, mid_dim, inner_dim; + GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim); + + const auto *x_data = x->data(); + auto *y_data = dev_ctx.template Alloc(y); + paddle::operators::math::InclusiveScan(x_data, + y_data, + outer_dim, + mid_dim, + inner_dim, + static_cast(1), + funcs::MultiplyFunctor(), + /*reverse=*/false, + dev_ctx); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cumprod, + GPU, + ALL_LAYOUT, + phi::CumprodKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/isclose_kernel.cu b/paddle/phi/kernels/gpu/isclose_kernel.cu new file mode 100644 index 00000000000..34774ec715c --- /dev/null +++ b/paddle/phi/kernels/gpu/isclose_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/isclose_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/isclose_kernel_impl.h" + +PD_REGISTER_KERNEL( + isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/kldiv_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/kldiv_loss_grad_kernel.cu new file mode 100644 index 00000000000..8ca53f021f0 --- /dev/null +++ b/paddle/phi/kernels/gpu/kldiv_loss_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/kldiv_loss_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h" +PD_REGISTER_KERNEL( + kldiv_loss_grad, GPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/kldiv_loss_kernel.cu b/paddle/phi/kernels/gpu/kldiv_loss_kernel.cu new file mode 100644 index 00000000000..9388ac7071c --- /dev/null +++ b/paddle/phi/kernels/gpu/kldiv_loss_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/kldiv_loss_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h" +PD_REGISTER_KERNEL( + kldiv_loss, GPU, ALL_LAYOUT, phi::KLDivLossKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu b/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu new file mode 100644 index 00000000000..3e4cd21a658 --- /dev/null +++ b/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lgamma_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h" +PD_REGISTER_KERNEL( + lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lgamma_kernel.cu b/paddle/phi/kernels/gpu/lgamma_kernel.cu new file mode 100644 index 00000000000..e94d67f4ce3 --- /dev/null +++ b/paddle/phi/kernels/gpu/lgamma_kernel.cu @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lgamma_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +namespace phi { +template +struct CudaLgammaFunctor { + __device__ __forceinline__ T operator()(const T x) const { + return Eigen::numext::lgamma(x); + } +}; +template +void LgammaKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + // XKTODO( add gpu kernel implementation. ) + dev_ctx.template Alloc(out); + std::vector ins = {&x}; + std::vector outs = {out}; + auto functor = CudaLgammaFunctor(); + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); +} +} // namespace phi + +PD_REGISTER_KERNEL(lgamma, GPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/isclose_kernel_impl.h b/paddle/phi/kernels/impl/isclose_kernel_impl.h new file mode 100644 index 00000000000..25247ceaff6 --- /dev/null +++ b/paddle/phi/kernels/impl/isclose_kernel_impl.h @@ -0,0 +1,176 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" + +// TODO(xiongkun): remove the header when decouple the memcpy function in phi. +#include "paddle/fluid/memory/memcpy.h" + +namespace phi { +using Tensor = DenseTensor; +template +struct GetTensorValue { + T operator()(const DeviceContext& ctx, const DenseTensor& tensor) const; +}; + +template +struct IscloseFunctor { + void operator()(const DeviceContext& ctx, + const DenseTensor& in, + const DenseTensor& other, + const float rtol, + const float atol, + bool equal_nan, + DenseTensor* output); +}; + +template +struct GetTensorValue { + T operator()(const phi::CPUContext& dev_ctx, + const DenseTensor& tensor) const { + return *(tensor.data()); + } +}; + +template +struct GetTensorValue { + T operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& tensor) const { + const T* data = tensor.data(); + T value; + const auto gpu_place = dev_ctx.GetPlace(); + paddle::memory::Copy( + phi::CPUPlace(), &value, gpu_place, data, sizeof(T), dev_ctx.stream()); + return value; + } +}; + +template +struct IscloseFunctor { + void operator()(const phi::CPUContext& ctx, + const DenseTensor& in, + const DenseTensor& other, + const double rtol, + const double atol, + bool equal_nan, + DenseTensor* output) { + auto* in_a = in.data(); + auto* in_b = other.data(); + auto* out_data = ctx.template Alloc(output); + auto num = in.numel(); + // *out_data = true; + for (int i = 0; i < num; i++) { + out_data[i] = true; + } + for (int i = 0; i < num; i++) { + const T a = in_a[i], b = in_b[i]; + bool val; + if (std::isnan(a) || std::isnan(b)) { + val = equal_nan && std::isnan(a) == std::isnan(b); + } else { + T left = (a > b ? a - b : b - a); + T right = atol + (b > 0 ? rtol * b : (-rtol) * b); + T diff = (left > right ? left - right : right - left); + val = a == b || left <= right || diff <= 1e-15; + } + // *out_data &= val; + out_data[i] = val; + } + } +}; + +#if defined(__NVCC__) || defined(__HIPCC__) +template +__global__ void IscloseCUDAKernel(const T* in_data, + const T* other_data, + const double rtol, + const double atol, + bool equal_nan, + int num, + bool* out_data) { + unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x; + bool val; + for (int i = idx; i < num; i += blockDim.x * gridDim.x) { + const T a = in_data[i], b = other_data[i]; + if (isnan(a) || isnan(b)) { + val = equal_nan && isnan(a) == isnan(b); + } else { + T left = (a > b ? a - b : b - a); + T right = atol + (b > 0 ? rtol * b : (-rtol) * b); + T diff = (left > right ? left - right : right - left); + val = a == b || left <= right || diff <= 1e-15; + } + out_data[i] = val; + // if (!val) *out_data = false; + } +} + +template +struct IscloseFunctor { + void operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& in, + const DenseTensor& other, + const double rtol, + const double atol, + bool equal_nan, + DenseTensor* output) { + int num = in.numel(); + const T* in_data = in.data(); + const T* other_data = other.data(); + bool* out_data = dev_ctx.template Alloc(output); + int block = 1024; + int grid = (block - 1 + num) / block; + grid = (grid > block) ? block : grid; +#ifdef PADDLE_WITH_HIP + hipMemset(out_data, true, num * sizeof(bool)); +#else + cudaMemset(out_data, true, num * sizeof(bool)); +#endif + IscloseCUDAKernel<<>>( + in_data, other_data, rtol, atol, equal_nan, num, out_data); + } +}; +#endif + +template +void IscloseKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& rtol, + const Scalar& atol, + bool equal_nan, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + atol.dtype(), + DataType::FLOAT64, + phi::errors::InvalidArgument("Input(Atol) type must be double")); + + PADDLE_ENFORCE_EQ( + rtol.dtype(), + DataType::FLOAT64, + phi::errors::InvalidArgument("Input(Rtol) type must be double")); + + IscloseFunctor()( + dev_ctx, x, y, rtol.to(), atol.to(), equal_nan, out); +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h b/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h new file mode 100644 index 00000000000..1ae90960ef4 --- /dev/null +++ b/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { +using Array1 = Eigen::DSizes; +template +struct KLDivLossBackward { + HOSTDEVICE KLDivLossBackward() {} + + HOSTDEVICE T operator()(const T& target, const T& grad) const { + if (target <= 0) { + return 0; + } else { + return static_cast(-1.) * grad; + } + } +}; + +template +void KLDivLossGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& label, + const DenseTensor& d_out, + const std::string& reduction, + DenseTensor* d_x) { + auto& place = *dev_ctx.eigen_device(); + auto* target = &label; + auto* input_grad = d_x; + auto* loss_grad = &d_out; + + const int n = input_grad->dims()[0]; + const int numel = input_grad->numel(); + const int expand = numel / loss_grad->numel(); + + dev_ctx.template Alloc(input_grad); + + auto target_t = phi::EigenVector::Flatten(*target); + + auto input_grad_t = phi::EigenVector::Flatten(*input_grad); + auto loss_grad_t = phi::EigenVector::Flatten(*loss_grad); + + auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand)); + auto grad_t = target_t * loss_grad_expand; + input_grad_t.device(place) = + target_t.binaryExpr(grad_t, KLDivLossBackward()); + + if ("mean" == reduction) { + input_grad_t.device(place) = input_grad_t / static_cast(numel); + } else if ("batchmean" == reduction) { + input_grad_t.device(place) = input_grad_t / static_cast(n); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h b/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h new file mode 100644 index 00000000000..ecd23bbfc1c --- /dev/null +++ b/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { +using Array1 = Eigen::DSizes; +template +struct KLDivLossForward { + HOSTDEVICE KLDivLossForward() {} + + HOSTDEVICE T operator()(const T& target, const T& input) const { + if (target <= 0) { + return 0; + } else { + return target * (std::log(target) - input); + } + } +}; +template +void KLDivLossKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& label, + const std::string& reduction, + DenseTensor* out) { + auto& place = *(dev_ctx.eigen_device()); + auto* input = &x; + auto* target = &label; + auto* loss = out; + + const int n = input->dims()[0]; + dev_ctx.template Alloc(loss); + + auto input_t = phi::EigenVector::Flatten(*input); + auto target_t = phi::EigenVector::Flatten(*target); + auto loss_t = phi::EigenVector::Flatten(*loss); + auto output = target_t.binaryExpr(input_t, KLDivLossForward()); + if ("none" == reduction) { + loss_t.device(place) = output; + } else if ("batchmean" == reduction) { + auto output_sum = output.sum(); + if (n > 0) { + loss_t.device(place) = output_sum / output_sum.constant(n); + } else { + loss_t.device(place) = output_sum; + } + } else if ("mean" == reduction) { + loss_t.device(place) = output.mean(); + } else if ("sum" == reduction) { + loss_t.device(place) = output.sum(); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h b/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h new file mode 100644 index 00000000000..a1b33f5a331 --- /dev/null +++ b/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/kernels/funcs/for_range.h" +namespace phi { +template +struct LgammaGradFunctor { + LgammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel) + : dout_(dout), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]); + } + + private: + const T* dout_; + const T* x_; + T* output_; + int64_t numel_; +}; +template +void LgammaGradKernel(const Context& dev_ctx, + const DenseTensor& d_out, + const DenseTensor& x, + DenseTensor* d_x) { + auto numel = d_out.numel(); + auto* dout_data = d_out.data(); + auto* x_data = x.data(); + auto* dx_data = + dev_ctx.template Alloc(d_x, static_cast(numel * sizeof(T))); + phi::funcs::ForRange for_range(dev_ctx, numel); + LgammaGradFunctor functor(dout_data, x_data, dx_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/paddle/phi/kernels/isclose_kernel.h b/paddle/phi/kernels/isclose_kernel.h new file mode 100644 index 00000000000..8c468da0550 --- /dev/null +++ b/paddle/phi/kernels/isclose_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void IscloseKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& rtol, + const Scalar& atol, + bool equal_nan, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/kldiv_loss_grad_kernel.h b/paddle/phi/kernels/kldiv_loss_grad_kernel.h new file mode 100644 index 00000000000..8f53898fa68 --- /dev/null +++ b/paddle/phi/kernels/kldiv_loss_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +// XKTODO (change name) +void KLDivLossGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& label, + const DenseTensor& d_out, + const std::string& reduction, + DenseTensor* d_x); +} // namespace phi diff --git a/paddle/phi/kernels/kldiv_loss_kernel.h b/paddle/phi/kernels/kldiv_loss_kernel.h new file mode 100644 index 00000000000..103780ab747 --- /dev/null +++ b/paddle/phi/kernels/kldiv_loss_kernel.h @@ -0,0 +1,29 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void KLDivLossKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& label, + const std::string& reduction, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/lgamma_grad_kernel.h b/paddle/phi/kernels/lgamma_grad_kernel.h new file mode 100644 index 00000000000..94173cc29c7 --- /dev/null +++ b/paddle/phi/kernels/lgamma_grad_kernel.h @@ -0,0 +1,27 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LgammaGradKernel(const Context& dev_ctx, + const DenseTensor& d_out, + const DenseTensor& x, + DenseTensor* d_x); +} // namespace phi diff --git a/paddle/phi/kernels/lgamma_kernel.h b/paddle/phi/kernels/lgamma_kernel.h new file mode 100644 index 00000000000..f61b3a1ce85 --- /dev/null +++ b/paddle/phi/kernels/lgamma_kernel.h @@ -0,0 +1,26 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LgammaKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/ops/compat/cumprod_sig.cc b/paddle/phi/ops/compat/cumprod_sig.cc new file mode 100644 index 00000000000..59b4eabfa47 --- /dev/null +++ b/paddle/phi/ops/compat/cumprod_sig.cc @@ -0,0 +1,29 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CumprodGradGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("cumprod_grad", + {"X", "Out", GradVarName("Out")}, + {"dim"}, + {GradVarName("X")}); +} + +} // namespace phi +PD_REGISTER_ARG_MAPPING_FN(cumprod_grad, phi::CumprodGradGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/isclose_sig.cc b/paddle/phi/ops/compat/isclose_sig.cc new file mode 100644 index 00000000000..08632e99095 --- /dev/null +++ b/paddle/phi/ops/compat/isclose_sig.cc @@ -0,0 +1,50 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature IscloseOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Rtol")) { + if (ctx.HasInput("Atol")) { + return KernelSignature("isclose", + {"Input", "Other"}, + {"Rtol", "Atol", "equal_nan"}, + {"Out"}); + + } else { + return KernelSignature("isclose", + {"Input", "Other"}, + {"Rtol", "atol", "equal_nan"}, + {"Out"}); + } + } else { + if (ctx.HasInput("Atol")) { + return KernelSignature("isclose", + {"Input", "Other"}, + {"rtol", "Atol", "equal_nan"}, + {"Out"}); + } else { + return KernelSignature("isclose", + {"Input", "Other"}, + {"rtol", "atol", "equal_nan"}, + {"Out"}); + } + } +} + +} // namespace phi +PD_REGISTER_ARG_MAPPING_FN(isclose, phi::IscloseOpArgumentMapping); diff --git a/paddle/phi/ops/compat/kldiv_loss_sig.cc b/paddle/phi/ops/compat/kldiv_loss_sig.cc new file mode 100644 index 00000000000..22d2f074e9f --- /dev/null +++ b/paddle/phi/ops/compat/kldiv_loss_sig.cc @@ -0,0 +1,30 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature KLDivLossGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("kldiv_loss_grad", + {"X", "Target", GradVarName("Loss")}, + {"reduction"}, + {GradVarName("X")}); +} + +} // namespace phi +PD_REGISTER_ARG_MAPPING_FN(kldiv_loss_grad, + phi::KLDivLossGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/lgamma_sig.cc b/paddle/phi/ops/compat/lgamma_sig.cc new file mode 100644 index 00000000000..968ad4923ba --- /dev/null +++ b/paddle/phi/ops/compat/lgamma_sig.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "lgamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")}); +} + +} // namespace phi +PD_REGISTER_ARG_MAPPING_FN(lgamma_grad, phi::LgammaGradOpArgumentMapping); -- GitLab