Unverified commit 64223620, authored by xiongkun, committed by GitHub

[phi] Transfer lgamma, kldiv_loss, isclose, cumprod kernels into phi and pass the tests of these four kernels (#39770)

* transfer and pass the lgamma unittest

* merge and pass the test

* transfer kldiv_loss and kldiv_loss_grad; pass the unittest

* transfer the isclose and cumprod kernels

* change PT_REGISTER -> PD_REGISTER

* fix by code review

* fix by code review

* fix

* remove enforce include dependence from scalar

* fix

* fix by code review

* fix by code review
Parent 7039f61e
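In practice, the migration replaces fluid's per-device kernel classes and their REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL registrations with free functions templated on a Context, registered once through PD_REGISTER_KERNEL. A minimal sketch of the two styles, based on the cumprod changes shown further down in this diff (dtype lists abbreviated):

// Old fluid style: a device-specific OpKernel class, registered per device.
REGISTER_OP_CPU_KERNEL(cumprod, ops::CumprodOpCPUKernel<float>,
                       ops::CumprodOpCPUKernel<double>);

// New phi style: a device-agnostic function template plus one registration macro.
namespace phi {
template <typename T, typename Context>
void CumprodKernel(const Context& dev_ctx, const DenseTensor& x, int dim,
                   DenseTensor* out);
}  // namespace phi

PD_REGISTER_KERNEL(cumprod, CPU, ALL_LAYOUT, phi::CumprodKernel, float, double) {}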
......@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cumprod_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
......@@ -87,16 +88,3 @@ REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
ops::CumprodGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);
REGISTER_OP_CPU_KERNEL(
cumprod, ops::CumprodOpCPUKernel<float>, ops::CumprodOpCPUKernel<double>,
ops::CumprodOpCPUKernel<int>, ops::CumprodOpCPUKernel<int64_t>,
ops::CumprodOpCPUKernel<paddle::platform::complex<float>>,
ops::CumprodOpCPUKernel<paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
cumprod_grad, ops::CumprodGradOpCPUKernel<float>,
ops::CumprodGradOpCPUKernel<double>, ops::CumprodGradOpCPUKernel<int>,
ops::CumprodGradOpCPUKernel<int64_t>,
ops::CumprodGradOpCPUKernel<paddle::platform::complex<float>>,
ops::CumprodGradOpCPUKernel<paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/transform.h>
#include "paddle/fluid/operators/cumprod_op.h"
#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
template <typename T>
struct MultiplyFunctor {
HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
class CumprodOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<framework::Tensor>("X");
auto *y = ctx.Output<framework::Tensor>("Out");
auto dim = ctx.Attr<int>("dim");
size_t outer_dim, mid_dim, inner_dim;
GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim);
const auto *x_data = x->data<T>();
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::InclusiveScan<T, MultiplyFunctor<T>>(
x_data, y_data, outer_dim, mid_dim, inner_dim, static_cast<T>(1),
MultiplyFunctor<T>(), /*reverse=*/false, dev_ctx);
}
};
template <typename T>
struct IsZeroFunctor {
HOSTDEVICE bool operator()(T x) const { return x == static_cast<T>(0); }
};
template <typename T>
struct CumprodGradFunctorExceptFirstZero {
HOSTDEVICE CumprodGradFunctorExceptFirstZero(
const T *x, const T *y, const T *dy_mul_y_reversed_cumsum,
const uint8_t *zero_mask, size_t mid_dim, size_t inner_dim, T *dx,
int64_t *first_zero_idx, T *x_filled_one)
: x_(x),
y_(y),
dy_mul_y_reversed_cumsum_(dy_mul_y_reversed_cumsum),
zero_mask_(zero_mask),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
dx_(dx),
first_zero_idx_(first_zero_idx),
x_filled_one_(x_filled_one) {}
HOSTDEVICE void operator()(size_t idx) const {
auto inner_idx = idx % inner_dim_;
auto outer_idx = idx / (mid_dim_ * inner_dim_);
auto mid_idx = (idx - inner_idx) / inner_dim_ % mid_dim_;
auto mask = zero_mask_[idx];
bool should_fill_one = true;
if (mask == 0) {
dx_[idx] = dy_mul_y_reversed_cumsum_[idx] / x_[idx];
if (mid_idx == mid_dim_ - 1) {
// record first zero position as -1, i.e., no zero
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = -1;
}
} else if (mid_idx > 0) { // mask > 0
if (zero_mask_[idx - inner_dim_] > 0) { // not first zero
dx_[idx] = 0;
should_fill_one = false;
} else {
// idx is the first zero position, it should be recorded
dx_[idx] = y_[idx - inner_dim_];
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = mid_idx;
}
} else { // the first zero position is index 0
dx_[idx] = 1;
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
}
x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
}
private:
const T *x_;
const T *y_;
const T *dy_mul_y_reversed_cumsum_;
const uint8_t *zero_mask_;
size_t mid_dim_;
size_t inner_dim_;
T *dx_;
int64_t *first_zero_idx_;
T *x_filled_one_;
};
template <typename T>
struct FillFirstZeroPositionGradFunctor {
HOSTDEVICE FillFirstZeroPositionGradFunctor(const int64_t *first_zero_idx,
const T *grad_value,
size_t mid_dim, size_t inner_dim,
T *dx)
: first_zero_idx_(first_zero_idx),
grad_value_(grad_value),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
dx_(dx) {}
HOSTDEVICE void operator()(size_t idx) const {
auto outer_idx = idx / inner_dim_;
auto inner_idx = idx % inner_dim_;
auto mid_idx = first_zero_idx_[idx];
if (mid_idx >= 0) {
auto full_idx =
outer_idx * mid_dim_ * inner_dim_ + mid_idx * inner_dim_ + inner_idx;
dx_[full_idx] *= grad_value_[full_idx];
}
}
private:
const int64_t *first_zero_idx_;
const T *grad_value_;
size_t mid_dim_;
size_t inner_dim_;
T *dx_;
};
/*
Reference to
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp
input: x, y, dL/dy
output: dL/dx
dL/dx[i] = sum{0<=j<n} (dL/dy[j])*(dy[j]/dx[i]) (1)
= sum{0<=j<n} (dL/dy[j])*(d(x[0]*x[1]*...*x[j])/dx[i])
if x[i] != 0, dL/dx[i] = sum{i<=j<n} (dL/dy[j])*(y[j]/x[i]) (2)
if x[i] == 0, formula (2) cannot be applied directly.
Suppose k is the index of the first zero element; then the formula becomes:
i > k, dL/dx[i] = 0;
i < k, dL/dx[i] = 1/x[i]*sum{i<=j<n} (dL/dy[j]*y[j])
i = k, dL/dx[i] = y[i-1]*sum{i<=j<n} (dL/dy[j])*(x[i+1]*...*x[j])
First, we describe the main idea of the solution.
We need to determine the relationship between i (the current index) and k (the
index of the first zero element).
To mark that relationship, we introduce zero_mask, and we also need to
record the index of the first zero element.
zero_mask = cummax(x[i] == 0); //marks whether any zero has appeared up to index i.
zero_index = -1; //stores the index of the first zero element.
e.g. x = [1, 4, 5, 0, 2, 3, 0];
zero_mask = [0, 0, 0, 1, 1, 1, 1];
zero_index = 3;
When i < k, we need sum{i<=j<n}(dy[j]*y[j]); a reversed cumsum computes it for
every i at once.
R = reversed_cumsum(dy[j]*y[j]); //R[i] stores sum{i<=j<n}(dy[j]*y[j]).
The remaining products x[k+1]*x[k+2]*...*x[j] needed for i = k are handled below.
When i = k, we need the products x[i+1]*...*x[j] = prod{k<w<=j}(x[w]).
To compute them, we introduce x_filled_one, which replaces the entries at
indices 0 ~ k with 1.
e.g. x = [1, 4, 5, 0, 2, 3, 0];
x_filled_one = [1, 1, 1, 1, 2, 3, 0];
Thus, cumprod(x_filled_one)[j] gives prod{k<w<=j}(x[w]) (an empty product,
i.e. 1, when j = k).
The detailed implementation is:
for (int i = 0; i < numel; i++) {
if (zero_mask[i] == 0) { //case i < k
dx[i] = R[i] / x[i];
x_filled_one[i] = 1;
} else {
if (i == 0) { //case i = k
dx[i] = 1;
zero_index = i;
x_filled_one[i] = 1;
} else {
if (zero_mask[i-1] == 0) { //case i = k
dx[i] = y[i-1];
zero_index = i;
x_filled_one[i] = 1;
} else { //case i > k
dx[i] = 0;
x_filled_one[i] = x[i];
}
}
}
}
T = reversed_cumsum(dy[j]*cumprod(x_filled_one[j]));
if (zero_index != -1) {
dx[zero_index] *= T[zero_index];
}
*/
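// Worked example (illustrative, not part of the original comment): with
// x = [1, 4, 5, 0, 2, 3, 0] we have k = 3 and y = cumprod(x) = [1, 4, 20, 0, 0, 0, 0].
// Then x_filled_one = [1, 1, 1, 1, 2, 3, 0] and
// cumprod(x_filled_one) = [1, 1, 1, 1, 2, 6, 0], so the gradient at the first zero is
// dL/dx[3] = y[2] * (dy[3]*1 + dy[4]*2 + dy[5]*6 + dy[6]*0)
//          = 20 * (dy[3] + 2*dy[4] + 6*dy[5]).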
template <typename T>
class CumprodGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<framework::Tensor>("X");
const auto *y = ctx.Input<framework::Tensor>("Out");
const auto *dy =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto dim = ctx.Attr<int>("dim");
size_t outer_dim, mid_dim, inner_dim;
GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim);
if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
size_t numel = outer_dim * mid_dim * inner_dim;
const auto *x_data = x->data<T>();
const auto *y_data = y->data<T>();
const auto *dy_data = dy->data<T>();
auto place = ctx.GetPlace();
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto *dx_data = dx->mutable_data<T>(place);
// deal with complex
const T *x_data_deal;
const T *y_data_deal;
memory::AllocationPtr x_conj;
memory::AllocationPtr y_conj;
if (framework::IsComplex<T>::value) {
x_conj = memory::Alloc(place, numel * sizeof(T));
auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr());
y_conj = memory::Alloc(place, numel * sizeof(T));
auto *y_data_conj = reinterpret_cast<T *>(y_conj->ptr());
platform::ForRange<platform::CUDADeviceContext> for_range_x(dev_ctx,
numel);
phi::funcs::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
for_range_x(functor_x);
platform::ForRange<platform::CUDADeviceContext> for_range_y(dev_ctx,
numel);
phi::funcs::ConjFunctor<T> functor_y(y_data, numel, y_data_conj);
for_range_y(functor_y);
x_data_deal = x_data_conj;
y_data_deal = y_data_conj;
} else {
x_data_deal = x_data;
y_data_deal = y_data;
}
// Step 1: find cummax-ed zero mask of x
#ifdef PADDLE_WITH_CUDA
const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
auto zero_mask_without_cummax =
memory::Alloc(place, numel * sizeof(uint8_t));
auto *zero_mask_without_cummax_data =
reinterpret_cast<uint8_t *>(zero_mask_without_cummax->ptr());
thrust::transform(
exec_policy, thrust::device_pointer_cast(x_data_deal),
thrust::device_pointer_cast(x_data_deal) + numel,
thrust::device_pointer_cast(zero_mask_without_cummax_data),
IsZeroFunctor<T>());
auto zero_mask = memory::Alloc(place, numel * sizeof(uint8_t));
auto *zero_mask_data = reinterpret_cast<uint8_t *>(zero_mask->ptr());
math::InclusiveScan<uint8_t, cub::Max>(
zero_mask_without_cummax_data, zero_mask_data, outer_dim, mid_dim,
inner_dim, static_cast<uint8_t>(0), cub::Max(), /*reverse=*/false,
dev_ctx);
zero_mask_without_cummax = nullptr;
// Step 2: calculate reversed cumsum(dy * y)
auto dy_mul_y = memory::Alloc(place, numel * sizeof(T));
auto *dy_mul_y_data = reinterpret_cast<T *>(dy_mul_y->ptr());
thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data),
thrust::device_pointer_cast(dy_data) + numel,
thrust::device_pointer_cast(y_data_deal),
thrust::device_pointer_cast(dy_mul_y_data),
MultiplyFunctor<T>());
auto dy_mul_y_reversed_cumsum = memory::Alloc(place, numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
math::InclusiveScan<T, cub::Sum>(
dy_mul_y_data, dy_mul_y_reversed_cumsum_data, outer_dim, mid_dim,
inner_dim, static_cast<T>(0), cub::Sum(), /*reverse=*/true, dev_ctx);
// Step 3: calculate the gradient value except the first zero position.
// The gradient value of the first zero position is filled with out[idx-1],
// while the gradient values of the other positions are computed completely.
// This functor also:
// (1) finds the first zero index, i.e., first_zero_idx_data.
// (2) fills x_filled_one, which satisfies
// x_filled_one[i] = x[i], i > pos
// x_filled_one[i] = 1, i <= pos
auto first_zero_idx =
memory::Alloc(place, outer_dim * inner_dim * sizeof(int64_t));
auto *first_zero_idx_data =
reinterpret_cast<int64_t *>(first_zero_idx->ptr());
auto *x_filled_one_data = dy_mul_y_data; // reuse former allocated memory
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, numel);
CumprodGradFunctorExceptFirstZero<T> functor_except_first_zero(
x_data_deal, y_data_deal, dy_mul_y_reversed_cumsum_data, zero_mask_data,
mid_dim, inner_dim, dx_data, first_zero_idx_data, x_filled_one_data);
for_range(functor_except_first_zero);
// Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
math::InclusiveScan<T, MultiplyFunctor<T>>(
x_filled_one_data, x_filled_one_cumprod_data, outer_dim, mid_dim,
inner_dim, static_cast<T>(1), MultiplyFunctor<T>(), /*reverse=*/false,
dev_ctx);
// Step 5: calculate reversed cumsum(dy * x_filled_one_cumprod)
auto *dy_mul_x_filled_one_cumprod =
dy_mul_y_data; // reuse former allocated memory
thrust::transform(exec_policy, thrust::device_pointer_cast(dy_data),
thrust::device_pointer_cast(dy_data) + numel,
thrust::device_pointer_cast(x_filled_one_cumprod_data),
thrust::device_pointer_cast(dy_mul_x_filled_one_cumprod),
MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
math::InclusiveScan<T, cub::Sum>(
dy_mul_x_filled_one_cumprod,
dy_mul_x_filled_one_cumprod_reversed_cumsum, outer_dim, mid_dim,
inner_dim, static_cast<T>(0), cub::Sum(),
/*reverse=*/true, dev_ctx);
// Step 6: fill zero pos gradient value
platform::ForRange<platform::CUDADeviceContext>
for_range_fill_zero_pos_grad(dev_ctx, outer_dim * inner_dim);
FillFirstZeroPositionGradFunctor<T> fill_first_zero_pos_grad_functor(
first_zero_idx_data, dy_mul_x_filled_one_cumprod_reversed_cumsum,
mid_dim, inner_dim, dx_data);
for_range_fill_zero_pos_grad(fill_first_zero_pos_grad_functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
cumprod, ops::CumprodOpCUDAKernel<float>, ops::CumprodOpCUDAKernel<double>,
ops::CumprodOpCUDAKernel<int>, ops::CumprodOpCUDAKernel<int64_t>,
ops::CumprodOpCUDAKernel<paddle::platform::complex<float>>,
ops::CumprodOpCUDAKernel<paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
cumprod_grad, ops::CumprodGradOpCUDAKernel<float>,
ops::CumprodGradOpCUDAKernel<double>, ops::CumprodGradOpCUDAKernel<int>,
ops::CumprodGradOpCUDAKernel<int64_t>,
ops::CumprodGradOpCUDAKernel<paddle::platform::complex<float>>,
ops::CumprodGradOpCUDAKernel<paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static void GetCumprodDimInfo(const framework::DDim& dim, int cumprod_dim,
size_t* outer_dim, size_t* mid_dim,
size_t* inner_dim) {
PADDLE_ENFORCE_GE(
cumprod_dim, -dim.size(),
platform::errors::InvalidArgument(
"The input dim of CumprodOp should be larger than the opposite "
"rank of input x which is %d.But received dim=%d",
-dim.size(), cumprod_dim));
PADDLE_ENFORCE_LT(cumprod_dim, dim.size(),
platform::errors::InvalidArgument(
"The input dim of CumprodOp should be smaller than the "
"rank of input x which is %d.But received dim=%d",
dim.size(), cumprod_dim));
if (cumprod_dim < 0) cumprod_dim += dim.size();
*outer_dim = 1;
for (int i = 0; i < cumprod_dim; ++i) {
*outer_dim *= dim[i];
}
*mid_dim = dim[cumprod_dim];
*inner_dim = 1;
for (int i = cumprod_dim + 1; i < dim.size(); ++i) {
*inner_dim *= dim[i];
}
}
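// Illustrative example (not part of the original header): for x.dims() = [2, 3, 4]
// and cumprod_dim = 1 (equivalently -2), outer_dim = 2, mid_dim = 3 and inner_dim = 4,
// and element (i, j, k) lives at offset i * mid_dim * inner_dim + j * inner_dim + k.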
template <typename T>
class CumprodOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
int dim = context.Attr<int>("dim");
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
framework::DDim shape = x->dims();
size_t outer_dim = 1;
size_t mid_dim = 1;
size_t inner_dim = 1;
GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
for (size_t i = 0; i < outer_dim; i++) {
for (size_t j = 0; j < mid_dim; j++) {
for (size_t k = 0; k < inner_dim; k++) {
size_t pos = i * mid_dim * inner_dim + j * inner_dim + k;
if (j == 0) {
out_data[pos] = x_data[pos];
} else {
out_data[pos] = out_data[pos - inner_dim] * x_data[pos];
}
}
}
}
}
};
template <typename T>
class CumprodGradOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const {
const Tensor* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
const Tensor* x = context.Input<Tensor>("X");
const Tensor* out = context.Input<Tensor>("Out");
int dim = context.Attr<int>("dim");
framework::DDim shape = x->dims();
Tensor* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* d_out_data = d_out->data<T>();
auto* x_data = x->data<T>();
auto* out_data = out->data<T>();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
auto place = context.GetPlace();
const auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
size_t outer_dim = 1;
size_t mid_dim = 1;
size_t inner_dim = 1;
GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
size_t numel = outer_dim * mid_dim * inner_dim;
// deal with complex
const T* x_data_deal;
const T* out_data_deal;
memory::AllocationPtr x_conj;
memory::AllocationPtr out_conj;
if (framework::IsComplex<T>::value) {
x_conj = memory::Alloc(place, numel * sizeof(T));
auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr());
out_conj = memory::Alloc(place, numel * sizeof(T));
auto* out_data_conj = reinterpret_cast<T*>(out_conj->ptr());
platform::ForRange<platform::CPUDeviceContext> for_range_x(dev_ctx,
numel);
phi::funcs::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
for_range_x(functor_x);
platform::ForRange<platform::CPUDeviceContext> for_range_out(dev_ctx,
numel);
phi::funcs::ConjFunctor<T> functor_out(out_data, numel, out_data_conj);
for_range_out(functor_out);
x_data_deal = x_data_conj;
out_data_deal = out_data_conj;
} else {
x_data_deal = x_data;
out_data_deal = out_data;
}
for (size_t i = 0; i < outer_dim; i++) {
for (size_t k = 0; k < inner_dim; k++) {
for (size_t j = 0; j < mid_dim; j++) {
size_t index = i * mid_dim * inner_dim + j * inner_dim + k;
d_x_data[index] = 0;
for (size_t n = 0; n < mid_dim; n++) {
size_t pos = i * mid_dim * inner_dim + n * inner_dim + k;
T elem;
if (j == 0) {
elem = d_out_data[pos];
} else {
elem = d_out_data[pos] * out_data_deal[index - inner_dim];
}
if (pos > index) {
for (size_t m = index + inner_dim; m <= pos; m += inner_dim) {
elem *= x_data_deal[m];
}
} else if (pos < index) {
elem = static_cast<T>(0);
}
d_x_data[index] += elem;
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/isclose_op.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
......@@ -23,45 +22,6 @@
namespace paddle {
namespace operators {
template <typename T>
struct GetTensorValue<platform::CPUDeviceContext, T> {
T operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor& tensor) const {
return *(tensor.data<T>());
}
};
template <typename T>
struct IscloseFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& other,
const double rtol, const double atol, bool equal_nan,
framework::Tensor* output) {
auto* in_a = in.data<T>();
auto* in_b = other.data<T>();
auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
auto num = in.numel();
// *out_data = true;
for (int i = 0; i < num; i++) {
out_data[i] = true;
}
for (int i = 0; i < num; i++) {
const T a = in_a[i], b = in_b[i];
bool val;
if (std::isnan(a) || std::isnan(b)) {
val = equal_nan && std::isnan(a) == std::isnan(b);
} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T diff = (left > right ? left - right : right - left);
val = a == b || left <= right || diff <= 1e-15;
}
// *out_data &= val;
out_data[i] = val;
}
}
};
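// Illustrative check (not from this commit): with rtol = 1e-05 and atol = 1e-08,
// a = 1.0 and b = 1.0000001 give left = 1e-07 and
// right = 1e-08 + 1e-05 * 1.0000001 ~ 1.001e-05, so left <= right and out = true.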
class IscloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -154,12 +114,9 @@ class IscloseOpVarTypeInference : public framework::VarTypeInference {
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
isclose, ops::IscloseOp, ops::IscloseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::IscloseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(isclose, ops::IscloseKernel<CPU, float>,
ops::IscloseKernel<CPU, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct GetTensorValue {
T operator()(const platform::DeviceContext& ctx,
const framework::Tensor& tensor) const;
};
template <typename DeviceContext, typename T>
struct IscloseFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& other, const float rtol,
const float atol, bool equal_nan, framework::Tensor* output);
};
template <typename DeviceContext, typename T>
class IscloseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// get attrs
bool equal_nan = ctx.Attr<bool>("equal_nan");
// get input/output
const auto* input = ctx.Input<Tensor>("Input");
const auto* other = ctx.Input<Tensor>("Other");
auto* out = ctx.Output<Tensor>("Out");
double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
double atol_v = std::stod(ctx.Attr<std::string>("atol"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
GetTensorValue<DeviceContext, double> get_tensor_value;
if (ctx.HasInput("Rtol")) {
const auto* rtol = ctx.Input<Tensor>("Rtol");
PADDLE_ENFORCE_EQ(
rtol->numel(), 1,
platform::errors::InvalidArgument(
"Input(Rtol) size must be 1, but get %d.", rtol->numel()));
PADDLE_ENFORCE_EQ(
framework::TransToProtoVarType(rtol->dtype()),
framework::proto::VarType::FP64,
platform::errors::InvalidArgument(
"Input(Rtol) type must be double, but get %s.",
framework::DataTypeToString(
framework::TransToProtoVarType(rtol->dtype()))));
rtol_v = get_tensor_value(dev_ctx, *rtol);
}
if (ctx.HasInput("Atol")) {
const auto* atol = ctx.Input<Tensor>("Atol");
PADDLE_ENFORCE_EQ(
atol->numel(), 1,
platform::errors::InvalidArgument(
"Input(Atol) size must be 1, but get %d", atol->numel()));
PADDLE_ENFORCE_EQ(
framework::TransToProtoVarType(atol->dtype()),
framework::proto::VarType::FP64,
platform::errors::InvalidArgument(
"Input(Atol) type must be double, but get %s",
framework::DataTypeToString(
framework::TransToProtoVarType(atol->dtype()))));
atol_v = get_tensor_value(dev_ctx, *atol);
}
IscloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v, atol_v,
equal_nan, out);
}
};
} // namespace operators
} // namespace paddle
......@@ -9,7 +9,6 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kldiv_loss_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
......@@ -177,10 +176,3 @@ REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
ops::KLDivLossGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
kldiv_loss_grad,
ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossForward {
HOSTDEVICE KLDivLossForward() {}
HOSTDEVICE T operator()(const T& target, const T& input) const {
if (target <= 0) {
return 0;
} else {
return target * (std::log(target) - input);
}
}
};
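// Illustrative value (not from this commit), using the usual KL-divergence
// convention that the input is a log-probability: target = 0.5 and
// input = log(0.25) ~ -1.3863 give 0.5 * (log(0.5) - log(0.25)) = 0.5 * log(2) ~ 0.3466.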
template <typename T>
struct KLDivLossBackward {
HOSTDEVICE KLDivLossBackward() {}
HOSTDEVICE T operator()(const T& target, const T& grad) const {
if (target <= 0) {
return 0;
} else {
return static_cast<T>(-1.) * grad;
}
}
};
template <typename DeviceContext, typename T>
class KLDivLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* input = ctx.Input<Tensor>("X");
auto* target = ctx.Input<Tensor>("Target");
auto* loss = ctx.Output<Tensor>("Loss");
auto reduction = ctx.Attr<std::string>("reduction");
const int n = input->dims()[0];
loss->mutable_data<T>(ctx.GetPlace());
auto input_t = framework::EigenVector<T>::Flatten(*input);
auto target_t = framework::EigenVector<T>::Flatten(*target);
auto loss_t = framework::EigenVector<T>::Flatten(*loss);
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
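// Illustrative reductions (not from this commit): for element-wise losses
// [0.1, 0.3, 0.2, 0.4] and batch size n = 2, "sum" gives 1.0, "mean" gives 0.25,
// "batchmean" gives 1.0 / 2 = 0.5, and "none" keeps the per-element values.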
if ("none" == reduction) {
loss_t.device(place) = output;
} else if ("batchmean" == reduction) {
auto output_sum = output.sum();
if (n > 0) {
loss_t.device(place) = output_sum / output_sum.constant(n);
} else {
loss_t.device(place) = output_sum;
}
} else if ("mean" == reduction) {
loss_t.device(place) = output.mean();
} else if ("sum" == reduction) {
loss_t.device(place) = output.sum();
}
}
};
template <typename DeviceContext, typename T>
class KLDivLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* target = ctx.Input<Tensor>("Target");
auto reduction = ctx.Attr<std::string>("reduction");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
const int n = input_grad->dims()[0];
const int numel = input_grad->numel();
const int expand = numel / loss_grad->numel();
input_grad->mutable_data<T>(ctx.GetPlace());
auto target_t = framework::EigenVector<T>::Flatten(*target);
auto input_grad_t = framework::EigenVector<T>::Flatten(*input_grad);
auto loss_grad_t = framework::EigenVector<T>::Flatten(*loss_grad);
auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
auto grad_t = target_t * loss_grad_expand;
input_grad_t.device(place) =
target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
if ("mean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
} else if ("batchmean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(n);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kldiv_loss_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/lgamma_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -35,16 +38,6 @@ $$out = log\Gamma(x)$$
class LgammaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Lgamma");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Lgamma");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", "Out");
}
};
template <typename T>
......@@ -83,17 +76,12 @@ class LgammaGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(lgamma, LgammaInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(lgamma, ops::LgammaOp, ops::LgammaOpMaker,
ops::LgammaGradMaker<paddle::framework::OpDesc>,
ops::LgammaGradMaker<paddle::imperative::OpBase>);
ops::LgammaGradMaker<paddle::imperative::OpBase>,
LgammaInferShapeFunctor);
REGISTER_OPERATOR(lgamma_grad, ops::LgammaGradOp);
REGISTER_OP_CPU_KERNEL(
lgamma, ops::LgammaKernel<paddle::platform::CPUDeviceContext, float>,
ops::LgammaKernel<paddle::platform::CPUDeviceContext, double>)
REGISTER_OP_CPU_KERNEL(
lgamma_grad,
ops::LgammaGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LgammaGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
struct LgammaFunctor {
LgammaFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = Eigen::numext::lgamma(input_[idx]);
}
private:
const T* input_;
T* output_;
int64_t numel_;
};
template <typename T>
struct LgammaGradFunctor {
LgammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
}
private:
const T* dout_;
const T* x_;
T* output_;
int64_t numel_;
};
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class LgammaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace(),
size_t(x->numel() * sizeof(T)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
LgammaFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class LgammaGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<T>();
auto* x_data = x->data<T>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
LgammaGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
......@@ -34,10 +34,10 @@ namespace paddle {
namespace operators {
namespace math {
template <typename InputIterator, typename OutputIterator, typename BinaryOp>
template <typename InputIterator, typename OutputIterator, typename BinaryOp,
typename Context>
static void CubInclusiveScan(InputIterator x_iter, OutputIterator y_iter,
size_t n, BinaryOp op,
const platform::CUDADeviceContext &dev_ctx) {
size_t n, BinaryOp op, const Context &dev_ctx) {
memory::AllocationPtr allocation;
void *temp_storage = nullptr;
size_t temp_storage_bytes = 0;
......@@ -185,11 +185,10 @@ static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y,
}
}
template <typename T, typename BinaryOp>
template <typename T, typename BinaryOp, typename Context>
static void InclusiveScanInnerDim(const T *x, T *y, size_t outer_dim,
size_t inner_dim, T init, BinaryOp op,
bool reverse,
const platform::CUDADeviceContext &dev_ctx) {
bool reverse, const Context &dev_ctx) {
constexpr size_t kThreadNumX = 16;
constexpr size_t kThreadNumY = 32;
......@@ -209,10 +208,10 @@ static void InclusiveScanInnerDim(const T *x, T *y, size_t outer_dim,
}
}
template <typename T, typename BinaryOp>
template <typename T, typename BinaryOp, typename Context>
void InclusiveScan(const T *x, T *y, size_t outer_dim, size_t mid_dim,
size_t inner_dim, T init, BinaryOp op, bool reverse,
const platform::CUDADeviceContext &dev_ctx) {
const Context &dev_ctx) {
if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
if (outer_dim == 1 && inner_dim == 1) {
......@@ -224,8 +223,7 @@ void InclusiveScan(const T *x, T *y, size_t outer_dim, size_t mid_dim,
CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
}
} else if (inner_dim != 1) {
platform::ForRange<platform::CUDADeviceContext> for_range(
dev_ctx, outer_dim * inner_dim);
platform::ForRange<Context> for_range(dev_ctx, outer_dim * inner_dim);
if (reverse) {
for_range(
InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/true>(
......
cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits)
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits scalar)
cc_library(phi_place SRCS place.cc)
cc_library(scalar SRCS scalar.cc)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kldiv_loss_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
kldiv_loss,
ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
kldiv_loss_grad,
ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, double>);
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace experimental {
// NOTE(xiongkun): Why do we put the definition here?
// test_custom_op can't include enforce.h, because enforce.h includes gflags,
// so we decouple the include dependency on enforce.h by linking.
void ThrowTensorConvertError(int num) {
PADDLE_ENFORCE_EQ(num,
1,
phi::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, but "
"now Tensor has `%d` elements",
num));
}
} // namespace experimental
} // namespace paddle
......@@ -19,9 +19,12 @@ limitations under the License. */
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace experimental {
void ThrowTensorConvertError(int);
template <typename T>
class ScalarBase {
public:
......@@ -104,11 +107,7 @@ class ScalarBase {
// The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
is_from_tensor_ = true;
PD_CHECK(
tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
tensor.numel(),
"` element.");
ThrowTensorConvertError(tensor.numel());
switch (dtype_) {
case DataType::FLOAT32:
data_.f32 = tensor.template data<float>()[0];
......@@ -156,6 +155,8 @@ class ScalarBase {
CopyScalar(other, this);
}
// NOTE(xiongkun): some ops need to judge the dtype of the Scalar, so we expose
// an interface.
bool FromTensor() const { return is_from_tensor_; }
void SetFromTensor(bool from_tensor) { is_from_tensor_ = from_tensor; }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumprod_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// NOTE(@xiongkun): use of IsComplex<>
#include "paddle/fluid/framework/data_type.h"
namespace phi {
template <typename T, typename Context>
void CumprodGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& d_out,
int dim,
DenseTensor* d_x) {
DDim shape = x.dims();
auto* d_out_data = d_out.data<T>();
auto* x_data = x.data<T>();
auto* out_data = out.data<T>();
auto* d_x_data = dev_ctx.template Alloc<T>(d_x);
size_t outer_dim = 1;
size_t mid_dim = 1;
size_t inner_dim = 1;
GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
size_t numel = outer_dim * mid_dim * inner_dim;
// deal with complex
const T* x_data_deal;
const T* out_data_deal;
Allocator::AllocationPtr x_conj;
Allocator::AllocationPtr out_conj;
if (paddle::framework::IsComplex<T>::value) {
x_conj = const_cast<Allocator&>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr());
out_conj = const_cast<Allocator&>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto* out_data_conj = reinterpret_cast<T*>(out_conj->ptr());
phi::funcs::ForRange<Context> for_range_x(dev_ctx, numel);
phi::funcs::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
for_range_x(functor_x);
phi::funcs::ForRange<Context> for_range_out(dev_ctx, numel);
phi::funcs::ConjFunctor<T> functor_out(out_data, numel, out_data_conj);
for_range_out(functor_out);
x_data_deal = x_data_conj;
out_data_deal = out_data_conj;
} else {
x_data_deal = x_data;
out_data_deal = out_data;
}
for (size_t i = 0; i < outer_dim; i++) {
for (size_t k = 0; k < inner_dim; k++) {
for (size_t j = 0; j < mid_dim; j++) {
size_t index = i * mid_dim * inner_dim + j * inner_dim + k;
d_x_data[index] = 0;
for (size_t n = 0; n < mid_dim; n++) {
size_t pos = i * mid_dim * inner_dim + n * inner_dim + k;
T elem;
if (j == 0) {
elem = d_out_data[pos];
} else {
elem = d_out_data[pos] * out_data_deal[index - inner_dim];
}
if (pos > index) {
for (size_t m = index + inner_dim; m <= pos; m += inner_dim) {
elem *= x_data_deal[m];
}
} else if (pos < index) {
elem = static_cast<T>(0);
}
d_x_data[index] += elem;
}
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(cumprod_grad,
CPU,
ALL_LAYOUT,
phi::CumprodGradKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumprod_kernel.h"
#include <cstdint>
#include <type_traits>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
namespace phi {
template <typename T, typename Context>
void CumprodKernel(const Context& dev_ctx,
const DenseTensor& input,
int dim,
DenseTensor* out) {
const DenseTensor* x = &input;
auto* x_data = x->data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
DDim shape = x->dims();
size_t outer_dim = 1;
size_t mid_dim = 1;
size_t inner_dim = 1;
GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
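// Illustrative example (not from this commit): for x = [[1, 2], [3, 4]] and dim = 0
// (outer_dim = 1, mid_dim = 2, inner_dim = 2), the loops below produce
// out = [[1, 2], [1 * 3, 2 * 4]] = [[1, 2], [3, 8]].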
for (size_t i = 0; i < outer_dim; i++) {
for (size_t j = 0; j < mid_dim; j++) {
for (size_t k = 0; k < inner_dim; k++) {
size_t pos = i * mid_dim * inner_dim + j * inner_dim + k;
if (j == 0) {
out_data[pos] = x_data[pos];
} else {
out_data[pos] = out_data[pos - inner_dim] * x_data[pos];
}
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(cumprod,
CPU,
ALL_LAYOUT,
phi::CumprodKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/isclose_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
PD_REGISTER_KERNEL(
isclose, CPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kldiv_loss_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
kldiv_loss_grad, CPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kldiv_loss_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h"
namespace phi {} // namespace phi
PD_REGISTER_KERNEL(
kldiv_loss, CPU, ALL_LAYOUT, phi::KLDivLossKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/lgamma_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
lgamma_grad, CPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/lgamma_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T>
struct LgammaFunctor {
LgammaFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = Eigen::numext::lgamma(input_[idx]);
}
private:
const T* input_;
T* output_;
int64_t numel_;
};
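// Illustrative value (not from this commit): lgamma(5.0) = log(4!) = log(24) ~ 3.1781.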
template <typename T, typename Context>
void LgammaKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
LgammaFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(lgamma, CPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CumprodGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
int dim,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CumprodKernel(const Context& dev_ctx,
const DenseTensor& x,
int dim,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
static void GetCumprodDimInfo(const DDim& dim,
int cumprod_dim,
size_t* outer_dim,
size_t* mid_dim,
size_t* inner_dim) {
PADDLE_ENFORCE_GE(
cumprod_dim,
-dim.size(),
phi::errors::InvalidArgument(
"The input dim of CumprodOp should be larger than the opposite "
"rank of input x which is %d.But received dim=%d",
-dim.size(),
cumprod_dim));
PADDLE_ENFORCE_LT(cumprod_dim,
dim.size(),
phi::errors::InvalidArgument(
"The input dim of CumprodOp should be smaller than the "
"rank of input x which is %d.But received dim=%d",
dim.size(),
cumprod_dim));
if (cumprod_dim < 0) cumprod_dim += dim.size();
*outer_dim = 1;
for (int i = 0; i < cumprod_dim; ++i) {
*outer_dim *= dim[i];
}
*mid_dim = dim[cumprod_dim];
*inner_dim = 1;
for (int i = cumprod_dim + 1; i < dim.size(); ++i) {
*inner_dim *= dim[i];
}
}
} // namespace phi
......@@ -67,6 +67,11 @@ struct InverseMultiplyFunctor<bool> {
}
};
template <typename T>
struct IsZeroFunctor {
HOSTDEVICE bool operator()(T x) const { return x == static_cast<T>(0); }
};
// Divide
#define DIV_ERROR_INFO \
"InvalidArgumentError: Integer division by zero encountered in " \
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumprod_grad_kernel.h"
#include <thrust/transform.h>
#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// NOTE(@xiongkun): use of IsComplex<>
#include "paddle/fluid/framework/data_type.h"
namespace phi {
template <typename T>
struct CumprodGradFunctorExceptFirstZero {
HOSTDEVICE CumprodGradFunctorExceptFirstZero(
const T *x,
const T *y,
const T *dy_mul_y_reversed_cumsum,
const uint8_t *zero_mask,
size_t mid_dim,
size_t inner_dim,
T *dx,
int64_t *first_zero_idx,
T *x_filled_one)
: x_(x),
y_(y),
dy_mul_y_reversed_cumsum_(dy_mul_y_reversed_cumsum),
zero_mask_(zero_mask),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
dx_(dx),
first_zero_idx_(first_zero_idx),
x_filled_one_(x_filled_one) {}
HOSTDEVICE void operator()(size_t idx) const {
auto inner_idx = idx % inner_dim_;
auto outer_idx = idx / (mid_dim_ * inner_dim_);
auto mid_idx = (idx - inner_idx) / inner_dim_ % mid_dim_;
auto mask = zero_mask_[idx];
bool should_fill_one = true;
if (mask == 0) {
dx_[idx] = dy_mul_y_reversed_cumsum_[idx] / x_[idx];
if (mid_idx == mid_dim_ - 1) {
// record first zero position as -1, i.e., no zero
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = -1;
}
} else if (mid_idx > 0) { // mask > 0
if (zero_mask_[idx - inner_dim_] > 0) { // not first zero
dx_[idx] = 0;
should_fill_one = false;
} else {
// idx is the first zero position, it should be recorded
dx_[idx] = y_[idx - inner_dim_];
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = mid_idx;
}
} else { // the first zero position is index 0
dx_[idx] = 1;
first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
}
x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
}
private:
const T *x_;
const T *y_;
const T *dy_mul_y_reversed_cumsum_;
const uint8_t *zero_mask_;
size_t mid_dim_;
size_t inner_dim_;
T *dx_;
int64_t *first_zero_idx_;
T *x_filled_one_;
};
template <typename T>
struct FillFirstZeroPositionGradFunctor {
HOSTDEVICE FillFirstZeroPositionGradFunctor(const int64_t *first_zero_idx,
const T *grad_value,
size_t mid_dim,
size_t inner_dim,
T *dx)
: first_zero_idx_(first_zero_idx),
grad_value_(grad_value),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
dx_(dx) {}
HOSTDEVICE void operator()(size_t idx) const {
auto outer_idx = idx / inner_dim_;
auto inner_idx = idx % inner_dim_;
auto mid_idx = first_zero_idx_[idx];
if (mid_idx >= 0) {
auto full_idx =
outer_idx * mid_dim_ * inner_dim_ + mid_idx * inner_dim_ + inner_idx;
dx_[full_idx] *= grad_value_[full_idx];
}
}
private:
const int64_t *first_zero_idx_;
const T *grad_value_;
size_t mid_dim_;
size_t inner_dim_;
T *dx_;
};
template <typename T, typename Context>
void CumprodGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &out,
const DenseTensor &dout,
int dim,
DenseTensor *dx) {
const auto *y = &out;
const auto *dy = &dout;
size_t outer_dim, mid_dim, inner_dim;
GetCumprodDimInfo(x.dims(), dim, &outer_dim, &mid_dim, &inner_dim);
if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
size_t numel = outer_dim * mid_dim * inner_dim;
const auto *x_data = x.data<T>();
const auto *y_data = y->data<T>();
const auto *dy_data = dy->data<T>();
auto place = dev_ctx.GetPlace();
auto *dx_data = dev_ctx.template Alloc<T>(dx);
// deal with complex
const T *x_data_deal;
const T *y_data_deal;
Allocator::AllocationPtr x_conj;
Allocator::AllocationPtr y_conj;
if (paddle::framework::IsComplex<T>::value) {
x_conj = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr());
y_conj = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto *y_data_conj = reinterpret_cast<T *>(y_conj->ptr());
phi::funcs::ForRange<Context> for_range_x(dev_ctx, numel);
phi::funcs::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
for_range_x(functor_x);
phi::funcs::ForRange<Context> for_range_y(dev_ctx, numel);
phi::funcs::ConjFunctor<T> functor_y(y_data, numel, y_data_conj);
for_range_y(functor_y);
x_data_deal = x_data_conj;
y_data_deal = y_data_conj;
} else {
x_data_deal = x_data;
y_data_deal = y_data;
}
// Step 1: find cummax-ed zero mask of x
#ifdef PADDLE_WITH_CUDA
const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
auto zero_mask_without_cummax =
const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(uint8_t));
auto *zero_mask_without_cummax_data =
reinterpret_cast<uint8_t *>(zero_mask_without_cummax->ptr());
thrust::transform(exec_policy,
thrust::device_pointer_cast(x_data_deal),
thrust::device_pointer_cast(x_data_deal) + numel,
thrust::device_pointer_cast(zero_mask_without_cummax_data),
funcs::IsZeroFunctor<T>());
auto zero_mask = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(uint8_t));
auto *zero_mask_data = reinterpret_cast<uint8_t *>(zero_mask->ptr());
paddle::operators::math::InclusiveScan<uint8_t, cub::Max>(
zero_mask_without_cummax_data,
zero_mask_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<uint8_t>(0),
cub::Max(),
/*reverse=*/false,
dev_ctx);
zero_mask_without_cummax = nullptr;
// Step 2: calculate reversed cumsum(dy * y)
auto dy_mul_y = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto *dy_mul_y_data = reinterpret_cast<T *>(dy_mul_y->ptr());
thrust::transform(exec_policy,
thrust::device_pointer_cast(dy_data),
thrust::device_pointer_cast(dy_data) + numel,
thrust::device_pointer_cast(y_data_deal),
thrust::device_pointer_cast(dy_mul_y_data),
funcs::MultiplyFunctor<T>());
auto dy_mul_y_reversed_cumsum =
const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
paddle::operators::math::InclusiveScan<T, cub::Sum>(
dy_mul_y_data,
dy_mul_y_reversed_cumsum_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
// Step 3: calculate the gradient values except at the first zero position.
// The gradient at the first zero position is filled with out[idx - 1]
// (its value is completed later in Step 6), while the gradients at the
// other positions are computed completely. This functor also:
// (1) finds the first zero index, i.e., first_zero_idx_data.
// (2) fills x_filled_one, which satisfies
//     x_filled_one[i] = x[i], i > pos
//     x_filled_one[i] = 1,    i <= pos
auto first_zero_idx = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(int64_t));
auto *first_zero_idx_data =
reinterpret_cast<int64_t *>(first_zero_idx->ptr());
auto *x_filled_one_data = dy_mul_y_data; // reuse former allocated memory
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
CumprodGradFunctorExceptFirstZero<T> functor_except_first_zero(
x_data_deal,
y_data_deal,
dy_mul_y_reversed_cumsum_data,
zero_mask_data,
mid_dim,
inner_dim,
dx_data,
first_zero_idx_data,
x_filled_one_data);
for_range(functor_except_first_zero);
// Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
paddle::operators::math::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
x_filled_one_data,
x_filled_one_cumprod_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
/*reverse=*/false,
dev_ctx);
// Step 5: calculate reversed cumsum(dy * x_filled_one_cumprod)
auto *dy_mul_x_filled_one_cumprod =
dy_mul_y_data; // reuse former allocated memory
thrust::transform(exec_policy,
thrust::device_pointer_cast(dy_data),
thrust::device_pointer_cast(dy_data) + numel,
thrust::device_pointer_cast(x_filled_one_cumprod_data),
thrust::device_pointer_cast(dy_mul_x_filled_one_cumprod),
funcs::MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
paddle::operators::math::InclusiveScan<T, cub::Sum>(
dy_mul_x_filled_one_cumprod,
dy_mul_x_filled_one_cumprod_reversed_cumsum,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
// Step 6: fill zero pos gradient value
phi::funcs::ForRange<Context> for_range_fill_zero_pos_grad(
dev_ctx, outer_dim * inner_dim);
FillFirstZeroPositionGradFunctor<T> fill_first_zero_pos_grad_functor(
first_zero_idx_data,
dy_mul_x_filled_one_cumprod_reversed_cumsum,
mid_dim,
inner_dim,
dx_data);
for_range_fill_zero_pos_grad(fill_first_zero_pos_grad_functor);
}
} // namespace phi
PD_REGISTER_KERNEL(cumprod_grad,
GPU,
ALL_LAYOUT,
phi::CumprodGradKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
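For reference, Steps 1-6 above implement the usual cumprod backward rule: with y[i] = x[0] * ... * x[i] along the scan axis, dx[i] = sum_{k >= i} dy[k] * y[k] / x[i] wherever x[i] != 0, and only the first zero along the axis receives a non-zero gradient, obtained by re-running the product scan on a copy of x whose entries up to that zero are replaced with ones (the division by x[i] is undefined there). Below is a minimal serial sketch of the same rule for a single 1-D axis; the function name and the plain std::vector interface are illustrative only and not part of this commit.

#include <cstdint>
#include <vector>

// Hypothetical serial reference for the 1-D case (illustrative only).
// `y` must already hold the forward cumprod of `x`.
std::vector<double> CumprodGrad1D(const std::vector<double>& x,
                                  const std::vector<double>& y,
                                  const std::vector<double>& dy) {
  const int64_t n = static_cast<int64_t>(x.size());
  std::vector<double> dx(n, 0.0);
  // Find the first zero along the axis (n if there is none).
  int64_t first_zero = n;
  for (int64_t i = 0; i < n; ++i) {
    if (x[i] == 0.0) { first_zero = i; break; }
  }
  // Positions before the first zero: dx[i] = (sum_{k >= i} dy[k] * y[k]) / x[i].
  double suffix = 0.0;
  for (int64_t i = n - 1; i >= 0; --i) {
    suffix += dy[i] * y[i];
    if (i < first_zero) dx[i] = suffix / x[i];
  }
  // First zero position p: dx[p] = sum_{k >= p} dy[k] * prod_{j <= k, j != p} x[j].
  if (first_zero < n) {
    double prod_excl = (first_zero > 0) ? y[first_zero - 1] : 1.0;
    double acc = 0.0;
    for (int64_t k = first_zero; k < n; ++k) {
      if (k > first_zero) prod_excl *= x[k];
      acc += dy[k] * prod_excl;
    }
    dx[first_zero] = acc;
  }
  // Positions after the first zero keep dx = 0: their partial products
  // always contain x[first_zero] == 0.
  return dx;
}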
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumprod_kernel.h"
#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace phi {
template <typename T, typename Context>
void CumprodKernel(const Context &dev_ctx,
const DenseTensor &input,
int dim,
DenseTensor *out) {
const auto *x = &input;
auto *y = out;
size_t outer_dim, mid_dim, inner_dim;
GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim);
const auto *x_data = x->data<T>();
auto *y_data = dev_ctx.template Alloc<T>(y);
paddle::operators::math::InclusiveScan(x_data,
y_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
/*reverse=*/false,
dev_ctx);
}
} // namespace phi
PD_REGISTER_KERNEL(cumprod,
GPU,
ALL_LAYOUT,
phi::CumprodKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
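The forward kernel above is just an inclusive multiply-scan over the (outer, mid, inner) layout that GetCumprodDimInfo derives from dims() and the dim attribute. A serial sketch of what that scan computes, with illustrative names only:

#include <cstddef>
#include <vector>

// Hypothetical serial reference: inclusive product scan along the "mid"
// dimension of a tensor viewed as [outer_dim, mid_dim, inner_dim].
void CumprodReference(const std::vector<double>& x, std::vector<double>* y,
                      size_t outer_dim, size_t mid_dim, size_t inner_dim) {
  y->assign(x.size(), 1.0);
  for (size_t o = 0; o < outer_dim; ++o) {
    for (size_t i = 0; i < inner_dim; ++i) {
      double running = 1.0;
      for (size_t m = 0; m < mid_dim; ++m) {
        const size_t idx = (o * mid_dim + m) * inner_dim + i;
        running *= x[idx];
        (*y)[idx] = running;
      }
    }
  }
}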
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/isclose_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
PD_REGISTER_KERNEL(
isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kldiv_loss_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
kldiv_loss_grad, GPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kldiv_loss_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h"
PD_REGISTER_KERNEL(
kldiv_loss, GPU, ALL_LAYOUT, phi::KLDivLossKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/lgamma_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,48 +12,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/lgamma_op.h"
#include "paddle/phi/kernels/lgamma_kernel.h"
namespace paddle {
namespace operators {
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace phi {
template <typename T>
struct CudaLgammaFunctor {
__device__ __forceinline__ T operator()(const T x) const {
return Eigen::numext::lgamma(x);
}
};
template <typename T>
class LgammaKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = CudaLgammaFunctor<T>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lgamma, ops::LgammaKernel<paddle::platform::CUDADeviceContext, float>,
ops::LgammaKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
lgamma_grad,
ops::LgammaGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LgammaGradKernel<paddle::platform::CUDADeviceContext, double>);
template <typename T, typename Context>
void LgammaKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
// XKTODO( add gpu kernel implementation. )
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaLgammaFunctor<T>();
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(lgamma, GPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {}
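The phi kernel keeps the same element-wise structure as the removed fluid kernel; only the launch helper changes from LaunchSameDimsElementwiseCudaKernel to phi::funcs::ElementwiseKernel. The functor itself is plain log-gamma; a host-side scalar equivalent (illustrative, standard library only) would be:

#include <cmath>

// Hypothetical host-side scalar equivalent of CudaLgammaFunctor.
double LgammaScalar(double x) { return std::lgamma(x); }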
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,30 +12,102 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/isclose_op.h"
#pragma once
#include <cmath>
#include <string>
namespace paddle {
namespace operators {
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
// TODO(xiongkun): remove this header once the memcpy function is decoupled in phi.
#include "paddle/fluid/memory/memcpy.h"
namespace phi {
using Tensor = DenseTensor;
template <typename DeviceContext, typename T>
struct GetTensorValue {
T operator()(const DeviceContext& ctx, const DenseTensor& tensor) const;
};
template <typename DeviceContext, typename T>
struct IscloseFunctor {
void operator()(const DeviceContext& ctx,
const DenseTensor& in,
const DenseTensor& other,
const float rtol,
const float atol,
bool equal_nan,
DenseTensor* output);
};
template <typename T>
struct GetTensorValue<phi::CPUContext, T> {
T operator()(const phi::CPUContext& dev_ctx,
const DenseTensor& tensor) const {
return *(tensor.data<T>());
}
};
template <typename T>
struct GetTensorValue<platform::CUDADeviceContext, T> {
T operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor& tensor) const {
struct GetTensorValue<phi::GPUContext, T> {
T operator()(const phi::GPUContext& dev_ctx,
const DenseTensor& tensor) const {
const T* data = tensor.data<T>();
T value;
const auto gpu_place = dev_ctx.GetPlace();
memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T),
dev_ctx.stream());
paddle::memory::Copy(
phi::CPUPlace(), &value, gpu_place, data, sizeof(T), dev_ctx.stream());
return value;
}
};
template <typename T>
__global__ void IscloseCUDAKernel(const T* in_data, const T* other_data,
const double rtol, const double atol,
bool equal_nan, int num, bool* out_data) {
struct IscloseFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& ctx,
const DenseTensor& in,
const DenseTensor& other,
const double rtol,
const double atol,
bool equal_nan,
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* in_b = other.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
// *out_data = true;
for (int i = 0; i < num; i++) {
out_data[i] = true;
}
for (int i = 0; i < num; i++) {
const T a = in_a[i], b = in_b[i];
bool val;
if (std::isnan(a) || std::isnan(b)) {
val = equal_nan && std::isnan(a) == std::isnan(b);
} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T diff = (left > right ? left - right : right - left);
val = a == b || left <= right || diff <= 1e-15;
}
// *out_data &= val;
out_data[i] = val;
}
}
};
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
__global__ void IscloseCUDAKernel(const T* in_data,
const T* other_data,
const double rtol,
const double atol,
bool equal_nan,
int num,
bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
......@@ -54,15 +126,18 @@ __global__ void IscloseCUDAKernel(const T* in_data, const T* other_data,
}
template <typename T>
struct IscloseFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor& in, const framework::Tensor& other,
const double rtol, const double atol, bool equal_nan,
framework::Tensor* output) {
struct IscloseFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
const DenseTensor& in,
const DenseTensor& other,
const double rtol,
const double atol,
bool equal_nan,
DenseTensor* output) {
int num = in.numel();
const T* in_data = in.data<T>();
const T* other_data = other.data<T>();
bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
bool* out_data = dev_ctx.template Alloc<bool>(output);
int block = 1024;
int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
......@@ -75,11 +150,27 @@ struct IscloseFunctor<platform::CUDADeviceContext, T> {
in_data, other_data, rtol, atol, equal_nan, num, out_data);
}
};
#endif
} // namespace operators
} // namespace paddle
template <typename T, typename Context>
void IscloseKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& rtol,
const Scalar& atol,
bool equal_nan,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
atol.dtype(),
DataType::FLOAT64,
phi::errors::InvalidArgument("Input(Atol) type must be double"));
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(isclose, ops::IscloseKernel<CUDA, float>,
ops::IscloseKernel<CUDA, double>);
PADDLE_ENFORCE_EQ(
rtol.dtype(),
DataType::FLOAT64,
phi::errors::InvalidArgument("Input(Rtol) type must be double"));
IscloseFunctor<Context, T>()(
dev_ctx, x, y, rtol.to<double>(), atol.to<double>(), equal_nan, out);
}
} // namespace phi
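Both specializations of IscloseFunctor above evaluate the same element-wise predicate, essentially |a - b| <= atol + rtol * |b|, with NaNs considered equal only when equal_nan is set and an extra absolute slack of 1e-15 on the comparison. A scalar sketch of that predicate, stripped of the Paddle tensor types (and omitting the 1e-15 slack), might look like:

#include <cmath>

// Hypothetical scalar version of the predicate used by IscloseFunctor.
bool IsCloseScalar(double a, double b, double rtol, double atol,
                   bool equal_nan) {
  if (std::isnan(a) || std::isnan(b)) {
    return equal_nan && std::isnan(a) && std::isnan(b);
  }
  return std::fabs(a - b) <= atol + rtol * std::fabs(b);
}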
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossBackward {
HOSTDEVICE KLDivLossBackward() {}
HOSTDEVICE T operator()(const T& target, const T& grad) const {
if (target <= 0) {
return 0;
} else {
return static_cast<T>(-1.) * grad;
}
}
};
template <typename T, typename Context>
void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
DenseTensor* d_x) {
auto& place = *dev_ctx.eigen_device();
auto* target = &label;
auto* input_grad = d_x;
auto* loss_grad = &d_out;
const int n = input_grad->dims()[0];
const int numel = input_grad->numel();
const int expand = numel / loss_grad->numel();
dev_ctx.template Alloc<T>(input_grad);
auto target_t = phi::EigenVector<T>::Flatten(*target);
auto input_grad_t = phi::EigenVector<T>::Flatten(*input_grad);
auto loss_grad_t = phi::EigenVector<T>::Flatten(*loss_grad);
auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
auto grad_t = target_t * loss_grad_expand;
input_grad_t.device(place) =
target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
if ("mean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
} else if ("batchmean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(n);
}
}
} // namespace phi
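The backward functor follows directly from the forward definition loss_i = target_i * (log(target_i) - input_i): the derivative with respect to input_i is -target_i where target_i > 0 and 0 elsewhere, scaled by the incoming loss gradient and divided by numel for "mean" or by the batch size for "batchmean". A serial sketch of that rule with illustrative names:

#include <cstddef>
#include <string>
#include <vector>

// Hypothetical serial reference for the gradient of
// loss_i = target_i * (log(target_i) - input_i) with respect to input_i.
// `loss_grad` has either one element (reduced loss) or `target.size()`
// elements (reduction == "none").
std::vector<double> KLDivLossGradReference(const std::vector<double>& target,
                                           const std::vector<double>& loss_grad,
                                           const std::string& reduction,
                                           std::size_t batch_size) {
  const std::size_t numel = target.size();
  std::vector<double> d_input(numel, 0.0);
  for (std::size_t i = 0; i < numel; ++i) {
    const double g = loss_grad[i % loss_grad.size()];
    // d loss_i / d input_i = -target_i where target_i > 0, else 0.
    d_input[i] = (target[i] > 0.0) ? -target[i] * g : 0.0;
    if (reduction == "mean") {
      d_input[i] /= static_cast<double>(numel);
    } else if (reduction == "batchmean") {
      d_input[i] /= static_cast<double>(batch_size);
    }
  }
  return d_input;
}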
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossForward {
HOSTDEVICE KLDivLossForward() {}
HOSTDEVICE T operator()(const T& target, const T& input) const {
if (target <= 0) {
return 0;
} else {
return target * (std::log(target) - input);
}
}
};
template <typename T, typename Context>
void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
DenseTensor* out) {
auto& place = *(dev_ctx.eigen_device());
auto* input = &x;
auto* target = &label;
auto* loss = out;
const int n = input->dims()[0];
dev_ctx.template Alloc<T>(loss);
auto input_t = phi::EigenVector<T>::Flatten(*input);
auto target_t = phi::EigenVector<T>::Flatten(*target);
auto loss_t = phi::EigenVector<T>::Flatten(*loss);
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
if ("none" == reduction) {
loss_t.device(place) = output;
} else if ("batchmean" == reduction) {
auto output_sum = output.sum();
if (n > 0) {
loss_t.device(place) = output_sum / output_sum.constant(n);
} else {
loss_t.device(place) = output_sum;
}
} else if ("mean" == reduction) {
loss_t.device(place) = output.mean();
} else if ("sum" == reduction) {
loss_t.device(place) = output.sum();
}
}
} // namespace phi
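For completeness, the forward functor computes the pointwise term target_i * (log(target_i) - input_i), where input is expected to hold log-probabilities; "none" keeps the element-wise values, "sum" and "mean" reduce over all elements, and "batchmean" divides the total by the batch size. A scalar sketch of the term (illustrative only):

#include <cmath>

// Hypothetical element-wise reference for the KL-divergence term above;
// `input` is assumed to hold log-probabilities, matching the kernel.
double KLDivTerm(double target, double input) {
  return target > 0.0 ? target * (std::log(target) - input) : 0.0;
}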
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T>
struct LgammaGradFunctor {
LgammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
}
private:
const T* dout_;
const T* x_;
T* output_;
int64_t numel_;
};
template <typename T, typename Context>
void LgammaGradKernel(const Context& dev_ctx,
const DenseTensor& d_out,
const DenseTensor& x,
DenseTensor* d_x) {
auto numel = d_out.numel();
auto* dout_data = d_out.data<T>();
auto* x_data = x.data<T>();
auto* dx_data =
dev_ctx.template Alloc<T>(d_x, static_cast<size_t>(numel * sizeof(T)));
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
LgammaGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
} // namespace phi
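The functor above uses the identity d/dx lgamma(x) = digamma(x), so dx = dout * digamma(x) element-wise. A scalar sketch reusing the same Eigen call the functor already depends on (illustrative only):

#include <Eigen/Core>
#include <unsupported/Eigen/SpecialFunctions>

// Hypothetical scalar form of LgammaGradFunctor: d/dx lgamma(x) = digamma(x).
double LgammaGradScalar(double dout, double x) {
  return dout * Eigen::numext::digamma(x);
}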
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IscloseKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& rtol,
const Scalar& atol,
bool equal_nan,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
// XKTODO (change name)
void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
DenseTensor* d_x);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LgammaGradKernel(const Context& dev_ctx,
const DenseTensor& d_out,
const DenseTensor& x,
DenseTensor* d_x);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LgammaKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CumprodGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cumprod_grad",
{"X", "Out", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(cumprod_grad, phi::CumprodGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IscloseOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Rtol")) {
if (ctx.HasInput("Atol")) {
return KernelSignature("isclose",
{"Input", "Other"},
{"Rtol", "Atol", "equal_nan"},
{"Out"});
} else {
return KernelSignature("isclose",
{"Input", "Other"},
{"Rtol", "atol", "equal_nan"},
{"Out"});
}
} else {
if (ctx.HasInput("Atol")) {
return KernelSignature("isclose",
{"Input", "Other"},
{"rtol", "Atol", "equal_nan"},
{"Out"});
} else {
return KernelSignature("isclose",
{"Input", "Other"},
{"rtol", "atol", "equal_nan"},
{"Out"});
}
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(isclose, phi::IscloseOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature KLDivLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("kldiv_loss_grad",
{"X", "Target", GradVarName("Loss")},
{"reduction"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(kldiv_loss_grad,
phi::KLDivLossGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"lgamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lgamma_grad, phi::LgammaGradOpArgumentMapping);