Unverified commit b7bbe39c authored by Linjie Chen, committed by GitHub

[phi] move sigmoid_cross_entropy_with_logits log_loss cumsum auc kernels to phi (#39976)

* move sigmoid cross entropy with logits to phi

* fix ci

* move log_loss to phi

* move cumsum to phi

* revert infershape

* fix xpu ci

* move auc to phi

* remove comment

* update sigmoid_cross_entropy_with_logits_op.cu

* update sigmoid_cross_entropy_with_logits_op

* Update log_loss
Parent 0bfba16b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <array>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename Functor>
class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto& X = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
"X", "Cum");
auto& Out = GET_DATA_SAFELY(context.Output<framework::Tensor>("Out"),
"Output", "Out", "Cum");
int axis = context.Attr<int>("axis");
bool exclusive = context.Attr<bool>("exclusive");
bool reverse = context.Attr<bool>("reverse");
auto out_dims = Out.dims();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()), true,
platform::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(), out_dims.size() - 1, axis));
if (axis < 0) {
axis += out_dims.size();
}
Out.template mutable_data<T>(context.GetPlace());
int pre = 1;
int post = 1;
int mid = out_dims[axis];
for (int i = 0; i < axis; ++i) {
pre *= out_dims[i];
}
for (int i = axis + 1; i < out_dims.size(); ++i) {
post *= out_dims[i];
}
auto x = framework::EigenVector<T>::Flatten(X);
auto out = framework::EigenVector<T>::Flatten(Out);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
using IndexT = Eigen::DenseIndex;
if (pre == 1) {
if (post == 1) {
ComputeImp(*place, Eigen::DSizes<IndexT, 1>(mid), x, out,
/* axis= */ 0, reverse, exclusive);
} else {
ComputeImp(*place, Eigen::DSizes<IndexT, 2>(mid, post), x, out,
/* axis= */ 0, reverse, exclusive);
}
} else {
if (post == 1) {
ComputeImp(*place, Eigen::DSizes<IndexT, 2>(pre, mid), x, out,
/* axis= */ 1, reverse, exclusive);
} else {
ComputeImp(*place, Eigen::DSizes<IndexT, 3>(pre, mid, post), x, out,
/* axis= */ 1, reverse, exclusive);
}
}
}
private:
template <typename Device, typename Dim, typename X, typename Out>
void ComputeImp(Device d, const Dim& dims, X x, Out out, int axis,
bool reverse, bool exclusive) const {
if (!reverse) {
out.reshape(dims).device(d) = Functor()(x.reshape(dims), axis, exclusive);
} else {
std::array<bool, Dim::count> rev;
rev.fill(false);
rev[axis] = reverse;
out.reshape(dims).device(d) =
Functor()(x.reshape(dims).reverse(rev), axis, exclusive).reverse(rev);
}
}
};
template <typename T>
struct CumsumFunctor {
using ELEMENT_TYPE = T;
template <typename X>
const typename X::TensorScanSumOp operator()(X x, int axis,
bool exclusive) const {
return x.cumsum(axis, exclusive);
}
};
} // namespace operators
} // namespace paddle
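For reference, the pre/mid/post flattening that CumKernel relies on maps a cumsum along an arbitrary axis onto a contiguous middle dimension. The sketch below is illustrative only and not part of this commit (CumsumRef is a hypothetical name); it computes an inclusive cumsum along `axis` of a row-major tensor in plain C++, without Eigen:
// Illustrative sketch, not part of this commit: inclusive cumsum along `axis`
// of a row-major tensor using the same pre/mid/post decomposition as CumKernel.
#include <cstdio>
#include <vector>
void CumsumRef(const std::vector<float>& x, std::vector<float>* out,
               const std::vector<int>& dims, int axis) {
  int pre = 1, post = 1, mid = dims[axis];
  for (int i = 0; i < axis; ++i) pre *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) post *= dims[i];
  for (int p = 0; p < pre; ++p) {
    for (int q = 0; q < post; ++q) {
      float acc = 0.f;
      for (int m = 0; m < mid; ++m) {
        int idx = (p * mid + m) * post + q;  // row-major offset
        acc += x[idx];
        (*out)[idx] = acc;
      }
    }
  }
}
int main() {
  std::vector<int> dims = {2, 3};  // shape [2, 3]
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  std::vector<float> out(x.size());
  CumsumRef(x, &out, dims, /*axis=*/1);
  for (float v : out) std::printf("%g ", v);  // prints: 1 3 6 4 9 15
  std::printf("\n");
  return 0;
}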
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/cum_op.h"
namespace paddle {
namespace operators {
@@ -91,11 +91,6 @@ using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(cumsum, ops::CumKernel<CPU, ops::CumsumFunctor<float>>,
ops::CumKernel<CPU, ops::CumsumFunctor<double>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int16_t>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int64_t>>);
REGISTER_OP_VERSION(cumsum)
.AddCheckpoint(
......
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_loss_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -149,13 +149,3 @@ REGISTER_OPERATOR(log_loss, ops::LogLossOp, ops::LogLossOpMaker<float>,
ops::LogLossGradMaker<paddle::framework::OpDesc>,
ops::LogLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(log_loss_grad, ops::LogLossGradOp);
REGISTER_OP_CPU_KERNEL(
log_loss, ops::LogLossKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
log_loss_grad,
ops::LogLossGradKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
log_loss, ops::LogLossKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
log_loss_grad,
ops::LogLossGradKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T, typename AttrType = T>
class LogLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* loss_out = ctx.Output<Tensor>("Loss");
loss_out->mutable_data<T>(ctx.GetPlace());
auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
auto prediction = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Predicted"));
auto label = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Labels"));
auto loss = EigenVector<T>::Flatten(*loss_out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
EigenLogLoss<std::decay_t<decltype(place)>, T>::Eval(
place, loss, prediction, label, epsilon);
}
};
template <typename DeviceContext, typename T, typename AttrType = T>
class LogLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
auto prediction = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Predicted"));
auto label = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Labels"));
auto* dloss = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* dpred = ctx.Output<Tensor>(framework::GradVarName("Predicted"));
auto dl = EigenVector<T>::Flatten(*dloss);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
if (dpred) {
dpred->mutable_data<T>(ctx.GetPlace());
auto dx = framework::EigenVector<T>::Flatten(*dpred);
EigenLogLossGrad<std::decay_t<decltype(place)>, T>::Eval(
place, dx, dl, prediction, label, epsilon);
}
}
};
} // namespace operators
} // namespace paddle
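The element-wise math that EigenLogLoss / EigenLogLossGrad evaluate above is the standard log loss, Out = -label * log(pred + eps) - (1 - label) * log(1 - pred + eps). The sketch below is illustrative only and not part of this commit (LogLossFwd / LogLossGrad are hypothetical names); it spells out the forward value and the gradient with respect to the prediction:
// Illustrative sketch, not part of this commit: scalar log loss and its
// gradient with respect to the prediction.
#include <cmath>
#include <cstdio>
float LogLossFwd(float pred, float label, float eps) {
  // loss = -label * log(pred + eps) - (1 - label) * log(1 - pred + eps)
  return -label * std::log(pred + eps) -
         (1.0f - label) * std::log(1.0f - pred + eps);
}
float LogLossGrad(float pred, float label, float dloss, float eps) {
  // d(loss)/d(pred) = -label / (pred + eps) + (1 - label) / (1 - pred + eps)
  return dloss *
         (-label / (pred + eps) + (1.0f - label) / (1.0f - pred + eps));
}
int main() {
  const float eps = 1e-4f;
  std::printf("loss = %f, dpred = %f\n", LogLossFwd(0.9f, 1.0f, eps),
              LogLossGrad(0.9f, 1.0f, /*dloss=*/1.0f, eps));
  return 0;
}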
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_loss_op.h"
#include <cmath>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
@@ -10,11 +10,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/log_loss_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T, typename AttrType = T>
class LogLossXPUKernel : public framework::OpKernel<T> {
public:
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/metrics/auc_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -146,4 +146,3 @@ There are two types of possible curves:
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker);
REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/metrics/auc_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
__global__ void ClearObsoleteDataKernel(int64_t *pos, int64_t *neg,
const int bucket_length,
const int slide_steps) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
CUDA_KERNEL_LOOP(i, bucket_length) {
pos[sum_step_begin + i] -= pos[cur_step_begin + i];
neg[sum_step_begin + i] -= neg[cur_step_begin + i];
pos[cur_step_begin + i] = neg[cur_step_begin + i] = 0;
}
}
__global__ void UpdateSumDataKernel(int64_t *pos, int64_t *neg,
const int bucket_length,
const int slide_steps) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
CUDA_KERNEL_LOOP(i, bucket_length) {
pos[sum_step_begin + i] += pos[cur_step_begin + i];
neg[sum_step_begin + i] += neg[cur_step_begin + i];
}
}
template <typename T>
__global__ void AddDataKernel(const int64_t *label_data, const T *pred_data,
const int inference_width,
const int num_thresholds, int64_t *pos,
int64_t *neg, const int numel,
const int slide_steps) {
int cur_step_begin = 0;
if (slide_steps > 0) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * (1 + num_thresholds)]) %
slide_steps;
cur_step_begin = cur_step_index * (1 + num_thresholds);
}
CUDA_KERNEL_LOOP(i, numel) {
auto predict_data = pred_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE(predict_data <= 1,
               "The predict data must be less than or equal to 1.");
PADDLE_ENFORCE(predict_data >= 0,
               "The predict data must be greater than or equal to 0.");
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
paddle::platform::CudaAtomicAdd(pos + cur_step_begin + binIdx, 1);
} else {
paddle::platform::CudaAtomicAdd(neg + cur_step_begin + binIdx, 1);
}
}
}
__global__ void CalcAucKernel(int64_t *stat_pos, int64_t *stat_neg,
int num_thresholds, double *auc,
bool need_add_batch_num) {
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
int idx = num_thresholds;
while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += (totNeg - totNegPrev) * (totPos + totPosPrev) / 2.0;
--idx;
}
if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
}
if (need_add_batch_num) {
stat_pos[num_thresholds + 1] += 1;
stat_neg[num_thresholds + 1] += 1;
}
}
template <typename DeviceContext, typename T>
class AucCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps");
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto *auc_tensor = ctx.Output<Tensor>("AUC");
auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());
auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();
#ifdef PADDLE_WITH_CUDA
if (stat_pos_in_tensor != stat_pos) {
cudaMemcpy(origin_stat_pos, pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
cudaMemcpyDeviceToDevice);
}
if (stat_neg_in_tensor != stat_neg) {
cudaMemcpy(origin_stat_neg, neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
cudaMemcpyDeviceToDevice);
}
#else
if (stat_pos_in_tensor != stat_pos) {
hipMemcpy(origin_stat_pos, pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
hipMemcpyDeviceToDevice);
}
if (stat_neg_in_tensor != stat_neg) {
hipMemcpy(origin_stat_neg, neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
hipMemcpyDeviceToDevice);
}
#endif
statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
CalcAucKernel<<<1, 1, 0, stream>>>(
origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
num_thresholds, auc_value, slide_steps > 0);
}
private:
inline static double trapezoidArea(double X1, double X2, double Y1,
double Y2) {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
}
inline static void statAuc(const framework::ExecutionContext &ctx,
const framework::Tensor *label,
const framework::Tensor *predict,
const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>();
const int bucket_length = num_thresholds + 1;
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
if (slide_steps == 0) {
AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
label_data, inference_data, inference_width, num_thresholds,
origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
return;
}
// the last element of origin_stat_pos stores the index of the current step
int cur_step_index =
static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
ClearObsoleteDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
label_data, inference_data, inference_width, num_thresholds,
origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(auc,
ops::AucCUDAKernel<paddle::platform::CUDAPlace, float>);
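When slide_steps > 0, the StatPos / StatNeg buffers used by the CUDA kernel above and the CPU kernel that follows share one layout: slide_steps per-step histograms, one running-sum histogram, and a trailing step counter, each histogram holding bucket_length = num_thresholds + 1 entries. The sketch below is illustrative only and not part of this commit (SlideWindowUpdate is a hypothetical name); it shows how one batch is rotated into the window on the host:
// Illustrative sketch, not part of this commit: sliding-window update over a
// buffer laid out as [ step 0 | ... | step slide_steps-1 | sum | counter ].
#include <cstdint>
#include <cstdio>
#include <vector>
void SlideWindowUpdate(std::vector<int64_t>* stat,
                       const std::vector<int64_t>& batch_hist,
                       int bucket_length, int slide_steps) {
  int64_t& counter = (*stat)[(slide_steps + 1) * bucket_length];
  int cur = static_cast<int>(counter % slide_steps) * bucket_length;
  int sum = slide_steps * bucket_length;
  for (int i = 0; i < bucket_length; ++i) {
    (*stat)[sum + i] -= (*stat)[cur + i];  // drop the obsolete step
    (*stat)[cur + i] = batch_hist[i];      // store the new step
    (*stat)[sum + i] += (*stat)[cur + i];  // refresh the running sum
  }
  counter += 1;
}
int main() {
  const int bucket_length = 4, slide_steps = 2;
  std::vector<int64_t> stat((slide_steps + 1) * bucket_length + 1, 0);
  SlideWindowUpdate(&stat, {1, 0, 2, 0}, bucket_length, slide_steps);
  SlideWindowUpdate(&stat, {0, 3, 0, 1}, bucket_length, slide_steps);
  SlideWindowUpdate(&stat, {5, 0, 0, 0}, bucket_length, slide_steps);
  for (int i = 0; i < bucket_length; ++i)  // sum covers the last two batches
    std::printf("%lld ",
                static_cast<long long>(stat[slide_steps * bucket_length + i]));
  std::printf("\n");  // prints: 5 3 0 1
  return 0;
}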
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class AucKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps");
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto *auc_tensor = ctx.Output<Tensor>("AUC");
auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());
// Just for passing UT, since UT's input & output cannot be set to the same var
auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();
if (stat_pos_in_tensor != stat_pos) {
memcpy(origin_stat_pos, pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
if (stat_neg_in_tensor != stat_neg) {
memcpy(origin_stat_neg, neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
statAuc(label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
calcAuc(origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
num_thresholds, auc_value);
if (slide_steps) {
origin_stat_pos[(slide_steps + 1) * (num_thresholds + 1)] += 1;
origin_stat_neg[(slide_steps + 1) * (num_thresholds + 1)] += 1;
}
}
private:
inline static double trapezoidArea(double X1, double X2, double Y1,
double Y2) {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
}
inline static void statAuc(const framework::Tensor *label,
const framework::Tensor *predict,
const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>();
const int bucket_length = num_thresholds + 1;
if (slide_steps == 0) {
for (size_t i = 0; i < batch_size; i++) {
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
// if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data, 1,
                  platform::errors::PreconditionNotMet(
                      "The predict data must be less than or equal to 1."));
PADDLE_ENFORCE_GE(predict_data, 0,
                  platform::errors::PreconditionNotMet(
                      "The predict data must be greater than or equal to 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i] > 0) {
origin_stat_pos[binIdx] += 1;
} else if (label_data[i] == 0) {
origin_stat_neg[binIdx] += 1;
}
}
return;
}
// the last element of origin_stat_pos stores the index of the current step
int cur_step_index =
static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] -=
origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] -=
origin_stat_neg[cur_step_begin + i];
}
std::memset(origin_stat_pos + cur_step_begin, 0,
bucket_length * sizeof(int64_t));
std::memset(origin_stat_neg + cur_step_begin, 0,
bucket_length * sizeof(int64_t));
for (size_t i = 0; i < batch_size; i++) {
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
// if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data, 1,
                  platform::errors::PreconditionNotMet(
                      "The predict data must be less than or equal to 1."));
PADDLE_ENFORCE_GE(predict_data, 0,
                  platform::errors::PreconditionNotMet(
                      "The predict data must be greater than or equal to 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i] > 0) {
origin_stat_pos[cur_step_begin + binIdx] += 1;
} else if (label_data[i] == 0) {
origin_stat_neg[cur_step_begin + binIdx] += 1;
}
}
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] +=
origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] +=
origin_stat_neg[cur_step_begin + i];
}
}
inline static void calcAuc(const int64_t *stat_pos, const int64_t *stat_neg,
int num_thresholds, double *auc) {
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
int idx = num_thresholds;
while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
--idx;
}
if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
}
}
};
} // namespace operators
} // namespace paddle
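For the slide_steps == 0 path, the AUC itself comes from bucketing predictions into num_thresholds + 1 bins and integrating the ROC curve with the trapezoid rule, as statAuc and calcAuc do above. The sketch below is illustrative only and not part of this commit (BucketedAuc is a hypothetical name):
// Illustrative sketch, not part of this commit: single-batch bucketed AUC.
#include <cstdint>
#include <cstdio>
#include <vector>
double BucketedAuc(const std::vector<float>& pred,
                   const std::vector<int64_t>& label, int num_thresholds) {
  std::vector<int64_t> stat_pos(num_thresholds + 1, 0);
  std::vector<int64_t> stat_neg(num_thresholds + 1, 0);
  for (size_t i = 0; i < pred.size(); ++i) {
    uint32_t bin = static_cast<uint32_t>(pred[i] * num_thresholds);
    (label[i] > 0 ? stat_pos : stat_neg)[bin] += 1;
  }
  double auc = 0.0, tot_pos = 0.0, tot_neg = 0.0;
  for (int idx = num_thresholds; idx >= 0; --idx) {  // sweep high to low
    double pos_prev = tot_pos, neg_prev = tot_neg;
    tot_pos += stat_pos[idx];
    tot_neg += stat_neg[idx];
    auc += (tot_neg - neg_prev) * (tot_pos + pos_prev) / 2.0;  // trapezoid
  }
  return (tot_pos > 0.0 && tot_neg > 0.0) ? auc / tot_pos / tot_neg : 0.0;
}
int main() {
  std::vector<float> pred = {0.1f, 0.4f, 0.35f, 0.8f};
  std::vector<int64_t> label = {0, 0, 1, 1};
  std::printf("AUC = %f\n", BucketedAuc(pred, label, /*num_thresholds=*/200));
  return 0;
}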
@@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
const int kIgnoreIndex = -100;
class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
public:
@@ -209,14 +210,3 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradOp,
ops::SigmoidCrossEntropyWithLogitsGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext,
float>,
ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CPUDeviceContext, float>,
ops::SigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#ifdef __HIPCC__
static constexpr int kNumCUDAThreads = 256;
#else
static constexpr int kNumCUDAThreads = 512;
#endif
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}
template <typename T>
struct NonzeroFunctor {
HOSTDEVICE explicit inline NonzeroFunctor() {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(static_cast<double>(x) != 0);
}
};
template <typename T>
struct SigmoidFwdFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);
HOSTDEVICE inline SigmoidFwdFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}
HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label) {
T counts;
T out_data;
T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
out_data = static_cast<T>(0.);
counts = 0;
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = real_log(static_cast<T>(1) + real_exp(static_cast<T>(-abs(x))));
out_data = term1 - term2 + term3;
counts = 1;
}
phi::Array<T, 2> outs;
outs[0] = out_data;
outs[1] = counts;
return outs;
}
};
template <typename T>
struct SigmoidBwdFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);
HOSTDEVICE inline SigmoidBwdFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}
HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label,
const T dout) {
T counts;
T dx_data;
T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
dx_data = static_cast<T>(0.);
counts = 0;
} else {
T sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + real_exp(-x));
T diff = sigmoid_x - label;
dx_data = dout * diff;
counts = 1;
}
phi::Array<T, 2> outs;
outs[0] = dx_data;
outs[1] = counts;
return outs;
}
};
template <typename T>
struct DivFunctor {
const T norm_;
HOSTDEVICE inline DivFunctor(const T norm) : norm_(norm) {}
HOSTDEVICE inline T operator()(T loss) {
loss /= norm_;
return loss;
}
};
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename DeviceContext, typename T>
class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
Tensor *Out = context.Output<Tensor>("Out");
int ignore_index = context.Attr<int>("ignore_index");
auto out_data = Out->mutable_data<T>(context.GetPlace());
auto &dev_ctx = context.cuda_device_context();
bool normalize = context.Attr<bool>("normalize");
// Temporary memory
Tensor *counts_tensor = new Tensor();
counts_tensor->mutable_data<T>(context.GetPlace(),
Labels->numel() * sizeof(T));
counts_tensor->Resize(Out->dims());
int limit = Out->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
std::vector<const framework::Tensor *> ins = {X, Labels};
std::vector<framework::Tensor *> outs = {Out, counts_tensor};
auto functor = SigmoidFwdFunctor<T>(ignore_index);
constexpr int Size = 2;
phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
&outs, functor);
if (normalize) {
T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
Tensor *norm_tensor = new Tensor();
norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
auto dims = phi::vectorize(counts_tensor->dims());
std::vector<int> reduce_dim = {};
for (int i = 0; i < dims.size(); i++) {
reduce_dim.push_back(i);
}
TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
context.cuda_device_context(), *counts_tensor, norm_tensor,
NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
sizeof(T), dev_ctx.stream());
auto eps = static_cast<T>(1e-5);
*norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
std::vector<const framework::Tensor *> div_ins = {Out};
std::vector<framework::Tensor *> div_outs = {Out};
auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
div_functor);
delete norm_tensor;
delete counts_tensor;
}
}
};
// dX = sigmoid(X) - labels
template <typename DeviceContext, typename T>
class GPUSigmoidCrossEntropyWithLogitsGradKernel
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
int ignore_index = context.Attr<int>("ignore_index");
auto &dev_ctx = context.cuda_device_context();
// Temporary memory
Tensor *counts_tensor = new Tensor();
counts_tensor->mutable_data<T>(context.GetPlace(),
Labels->numel() * sizeof(T));
counts_tensor->Resize(dX->dims());
int limit = dX->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
std::vector<const framework::Tensor *> ins = {X, Labels, dOut};
std::vector<framework::Tensor *> outs = {dX, counts_tensor};
auto functor = SigmoidBwdFunctor<T>(ignore_index);
constexpr int Size = 2;
phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
&outs, functor);
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
Tensor *norm_tensor = new Tensor();
norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
auto dims = phi::vectorize(counts_tensor->dims());
std::vector<int> reduce_dim = {};
for (int i = 0; i < dims.size(); i++) {
reduce_dim.push_back(i);
}
TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
context.cuda_device_context(), *counts_tensor, norm_tensor,
NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
sizeof(T), dev_ctx.stream());
auto eps = static_cast<T>(1e-5);
*norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
std::vector<const framework::Tensor *> div_ins = {dX};
std::vector<framework::Tensor *> div_outs = {dX};
auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
div_functor);
delete norm_tensor;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sigmoid_cross_entropy_with_logits,
ops::GPUSigmoidCrossEntropyWithLogitsKernel<
paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidCrossEntropyWithLogitsKernel<
paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(sigmoid_cross_entropy_with_logits_grad,
ops::GPUSigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CUDADeviceContext, double>);
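The `normalize` attribute handled by the CUDA kernels above (and by the CPU kernels that follow) divides every output element by the number of non-ignored labels, floored at a small eps so an all-ignored batch does not divide by zero. The sketch below is illustrative only and not part of this commit (NormalizeByValidCount is a hypothetical name):
// Illustrative sketch, not part of this commit: normalization by the count of
// labels that are not equal to ignore_index, with an eps floor.
#include <algorithm>
#include <cstdio>
#include <vector>
void NormalizeByValidCount(std::vector<float>* out,
                           const std::vector<float>& label, int ignore_index) {
  float norm = 0.0f;
  for (float l : label) {
    if (static_cast<int>(l) != ignore_index) norm += 1.0f;  // valid element
  }
  norm = std::max(norm, 1e-5f);  // eps floor, mirroring the kernels
  for (float& v : *out) v /= norm;
}
int main() {
  std::vector<float> loss = {0.3f, 0.0f, 0.9f};
  std::vector<float> label = {1.0f, -100.0f, 0.0f};  // -100 == kIgnoreIndex
  NormalizeByValidCount(&loss, label, /*ignore_index=*/-100);
  std::printf("%f %f %f\n", loss[0], loss[1], loss[2]);  // each divided by 2
  return 0;
}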
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
const int kIgnoreIndex = -100;
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename DeviceContext, typename T>
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
Tensor *Out = context.Output<Tensor>("Out");
int ignore_index = context.Attr<int>("ignore_index");
auto out_data = Out->mutable_data<T>(context.GetPlace());
int limit = Out->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
if (static_cast<int>(label) == ignore_index) {
out_data[idx] = static_cast<T>(0.);
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
out_data[idx] = term1 - term2 + term3;
}
}
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
int norm = 0;
T eps = static_cast<T>(1e-6);
for (int idx = 0; idx < limit; ++idx) {
T diff = label_data[idx] - static_cast<T>(ignore_index);
if ((diff < -eps) || (diff > eps)) {
norm += 1;
}
}
eps = static_cast<T>(1e-5);
norm = norm > eps ? norm : eps;
std::for_each(out_data, out_data + limit, [norm](T &v) { v = v / norm; });
}
}
};
// dX = sigmoid(X) - labels
template <typename DeviceContext, typename T>
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
int ignore_index = context.Attr<int>("ignore_index");
int limit = dX->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<T>();
auto dout_data = dOut->data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
T dout = dout_data[idx];
if (static_cast<int>(label) == ignore_index) {
dx_data[idx] = static_cast<T>(0.);
} else {
T sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
T diff = sigmoid_x - label;
dx_data[idx] = dout * diff;
}
}
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
int norm = 0;
T eps = static_cast<T>(1e-6);
for (int idx = 0; idx < limit; ++idx) {
T diff = label_data[idx] - static_cast<T>(ignore_index);
if ((diff < -eps) || (diff > eps)) {
norm += 1;
}
}
eps = static_cast<T>(1e-5);
norm = norm > eps ? norm : eps;
std::for_each(dx_data, dx_data + limit, [norm](T &v) { v = v / norm; });
}
}
};
} // namespace operators
} // namespace paddle
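The comments "Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))" and "dX = sigmoid(X) - labels" above are the numerically stable rewrite of binary cross entropy with logits: starting from -y*x + log(1 + exp(x)) and applying the identity log(1 + exp(x)) = max(x, 0) + log(1 + exp(-|x|)). The sketch below is illustrative only and not part of this commit (names are hypothetical):
// Illustrative sketch, not part of this commit: numerically stable
// sigmoid-cross-entropy-with-logits for a single element.
#include <cmath>
#include <cstdio>
float StableSigmoidCE(float x, float y) {
  // max(x, 0) - x*y + log(1 + exp(-|x|)) == -y*x + log(1 + exp(x))
  return std::fmax(x, 0.0f) - x * y + std::log1p(std::exp(-std::fabs(x)));
}
float StableSigmoidCEGrad(float x, float y, float dout) {
  return dout * (1.0f / (1.0f + std::exp(-x)) - y);  // sigmoid(x) - y
}
int main() {
  // Large |x| does not overflow, unlike the naive log(1 + exp(x)) form.
  std::printf("loss = %f, grad = %f\n", StableSigmoidCE(1000.0f, 1.0f),
              StableSigmoidCEGrad(1000.0f, 1.0f, /*dout=*/1.0f));
  return 0;
}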
@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
const int kIgnoreIndex = -100;
void CheckAttrs(const framework::ExecutionContext& ctx) {
// This check is added because the Ascend SigmoidCrossEntropyWithLogits
......
@@ -17,13 +17,15 @@
#include <memory>
#include <vector>
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SigmoidCrossEntropyWithLogitsXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AucKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& stat_pos,
const DenseTensor& stat_neg,
const std::string& curve,
int num_thresholds,
int slide_steps,
DenseTensor* auc,
DenseTensor* stat_pos_out,
DenseTensor* stat_neg_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/auc_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
inline static double trapezoidArea(double X1, double X2, double Y1, double Y2) {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
}
template <typename T>
void statAuc(const DenseTensor &label,
const DenseTensor &predict,
const int num_thresholds,
const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
size_t batch_size = predict.dims()[0];
size_t inference_width = predict.dims()[1];
const T *inference_data = predict.data<T>();
const auto *label_data = label.data<int64_t>();
const int bucket_length = num_thresholds + 1;
if (slide_steps == 0) {
for (size_t i = 0; i < batch_size; i++) {
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
// if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data,
                  1,
                  phi::errors::PreconditionNotMet(
                      "The predict data must be less than or equal to 1."));
PADDLE_ENFORCE_GE(predict_data,
                  0,
                  phi::errors::PreconditionNotMet(
                      "The predict data must be greater than or equal to 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i] > 0) {
origin_stat_pos[binIdx] += 1;
} else if (label_data[i] == 0) {
origin_stat_neg[binIdx] += 1;
}
}
return;
}
// the last element of origin_stat_pos stores the index of the current step
int cur_step_index =
static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] -= origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] -= origin_stat_neg[cur_step_begin + i];
}
std::memset(
origin_stat_pos + cur_step_begin, 0, bucket_length * sizeof(int64_t));
std::memset(
origin_stat_neg + cur_step_begin, 0, bucket_length * sizeof(int64_t));
for (size_t i = 0; i < batch_size; i++) {
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
// if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data,
                  1,
                  phi::errors::PreconditionNotMet(
                      "The predict data must be less than or equal to 1."));
PADDLE_ENFORCE_GE(predict_data,
                  0,
                  phi::errors::PreconditionNotMet(
                      "The predict data must be greater than or equal to 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i] > 0) {
origin_stat_pos[cur_step_begin + binIdx] += 1;
} else if (label_data[i] == 0) {
origin_stat_neg[cur_step_begin + binIdx] += 1;
}
}
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] += origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] += origin_stat_neg[cur_step_begin + i];
}
}
inline static void calcAuc(const int64_t *stat_pos,
const int64_t *stat_neg,
int num_thresholds,
double *auc) {
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
int idx = num_thresholds;
while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
--idx;
}
if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
}
}
template <typename T, typename Context>
void AucKernel(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &label,
const DenseTensor &stat_pos,
const DenseTensor &stat_neg,
const std::string &curve,
int num_thresholds,
int slide_steps,
DenseTensor *auc,
DenseTensor *stat_pos_out,
DenseTensor *stat_neg_out) {
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto *origin_stat_pos = dev_ctx.template Alloc<int64_t>(stat_pos_out);
auto *origin_stat_neg = dev_ctx.template Alloc<int64_t>(stat_neg_out);
auto *auc_value = dev_ctx.template Alloc<double>(auc);
// Just for passing UT, since UT's input & output cannot be set to the same var
auto *stat_pos_in_tensor = &stat_pos;
auto *stat_neg_in_tensor = &stat_neg;
auto *pos_in_data = stat_pos.data<int64_t>();
auto *neg_in_data = stat_neg.data<int64_t>();
if (stat_pos_in_tensor != stat_pos_out) {
memcpy(
origin_stat_pos,
pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
if (stat_neg_in_tensor != stat_neg_out) {
memcpy(
origin_stat_neg,
neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
statAuc<T>(label,
input,
num_thresholds,
slide_steps,
origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
calcAuc(origin_stat_pos + sum_offset,
origin_stat_neg + sum_offset,
num_thresholds,
auc_value);
if (slide_steps) {
origin_stat_pos[(slide_steps + 1) * (num_thresholds + 1)] += 1;
origin_stat_neg[(slide_steps + 1) * (num_thresholds + 1)] += 1;
}
}
} // namespace phi
PD_REGISTER_KERNEL(auc, CPU, ALL_LAYOUT, phi::AucKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumsum_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
struct CumsumFunctor {
template <typename X>
const typename X::TensorScanSumOp operator()(X x,
int axis,
bool exclusive) const {
return x.cumsum(axis, exclusive);
}
};
template <typename Device, typename Dim, typename X, typename Out>
void ComputeImp(Device d,
const Dim& dims,
X x,
Out out,
int axis,
bool reverse,
bool exclusive) {
if (!reverse) {
out.reshape(dims).device(d) =
CumsumFunctor()(x.reshape(dims), axis, exclusive);
} else {
std::array<bool, Dim::count> rev;
rev.fill(false);
rev[axis] = reverse;
out.reshape(dims).device(d) =
CumsumFunctor()(x.reshape(dims).reverse(rev), axis, exclusive)
.reverse(rev);
}
}
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
auto out_dims = out->dims();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(),
out_dims.size() - 1,
axis));
if (axis < 0) {
axis += out_dims.size();
}
dev_ctx.template Alloc<T>(out);
int pre = 1;
int post = 1;
int mid = out_dims[axis];
for (int i = 0; i < axis; ++i) {
pre *= out_dims[i];
}
for (int i = axis + 1; i < out_dims.size(); ++i) {
post *= out_dims[i];
}
auto x0 = EigenVector<T>::Flatten(x);
auto out0 = EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
using IndexT = Eigen::DenseIndex;
if (pre == 1) {
if (post == 1) {
ComputeImp(place,
Eigen::DSizes<IndexT, 1>(mid),
x0,
out0,
/* axis= */ 0,
reverse,
exclusive);
} else {
ComputeImp(place,
Eigen::DSizes<IndexT, 2>(mid, post),
x0,
out0,
/* axis= */ 0,
reverse,
exclusive);
}
} else {
if (post == 1) {
ComputeImp(place,
Eigen::DSizes<IndexT, 2>(pre, mid),
x0,
out0,
/* axis= */ 1,
reverse,
exclusive);
} else {
ComputeImp(place,
Eigen::DSizes<IndexT, 3>(pre, mid, post),
x0,
out0,
/* axis= */ 1,
reverse,
exclusive);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(cumsum,
CPU,
ALL_LAYOUT,
phi::CumsumKernel,
float,
double,
int16_t,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_loss_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/log_loss_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
log_loss_grad, CPU, ALL_LAYOUT, phi::LogLossGradKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_loss_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/log_loss_kernel_impl.h"
PD_REGISTER_KERNEL(log_loss, CPU, ALL_LAYOUT, phi::LogLossKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sigmoid_cross_entropy_with_logits_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
DenseTensor* in_grad) {
auto dx_data = dev_ctx.template Alloc<T>(in_grad);
int limit = in_grad->numel();
auto x_data = x.data<T>();
auto label_data = label.data<T>();
auto dout_data = out_grad.data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
T dout = dout_data[idx];
if (static_cast<int>(label) == ignore_index) {
dx_data[idx] = static_cast<T>(0.);
} else {
T sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
T diff = sigmoid_x - label;
dx_data[idx] = dout * diff;
}
}
if (normalize) {
int norm = 0;
T eps = static_cast<T>(1e-6);
for (int idx = 0; idx < limit; ++idx) {
T diff = label_data[idx] - static_cast<T>(ignore_index);
if ((diff < -eps) || (diff > eps)) {
norm += 1;
}
}
eps = static_cast<T>(1e-5);
norm = norm > eps ? norm : eps;
std::for_each(dx_data, dx_data + limit, [norm](T& v) { v = v / norm; });
}
}
} // namespace phi
PD_REGISTER_KERNEL(sigmoid_cross_entropy_with_logits_grad,
CPU,
ALL_LAYOUT,
phi::SigmoidCrossEntropyWithLogitsGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sigmoid_cross_entropy_with_logits_kernel.h"
#include <algorithm>
#include <limits>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
bool normalize,
int ignore_index,
DenseTensor* out) {
auto out_data = dev_ctx.template Alloc<T>(out);
int limit = out->numel();
auto x_data = x.data<T>();
auto label_data = label.data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
if (static_cast<int>(label) == ignore_index) {
out_data[idx] = static_cast<T>(0.);
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
out_data[idx] = term1 - term2 + term3;
}
}
if (normalize) {
int norm = 0;
T eps = static_cast<T>(1e-6);
for (int idx = 0; idx < limit; ++idx) {
T diff = label_data[idx] - static_cast<T>(ignore_index);
if ((diff < -eps) || (diff > eps)) {
norm += 1;
}
}
eps = static_cast<T>(1e-5);
norm = norm > eps ? norm : eps;
std::for_each(out_data, out_data + limit, [norm](T& v) { v = v / norm; });
}
}
} // namespace phi
PD_REGISTER_KERNEL(sigmoid_cross_entropy_with_logits,
CPU,
ALL_LAYOUT,
phi::SigmoidCrossEntropyWithLogitsKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/auc_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
__global__ void ClearObsoleteDataKernel(int64_t *pos,
int64_t *neg,
const int bucket_length,
const int slide_steps) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
CUDA_KERNEL_LOOP(i, bucket_length) {
pos[sum_step_begin + i] -= pos[cur_step_begin + i];
neg[sum_step_begin + i] -= neg[cur_step_begin + i];
pos[cur_step_begin + i] = neg[cur_step_begin + i] = 0;
}
}
__global__ void UpdateSumDataKernel(int64_t *pos,
int64_t *neg,
const int bucket_length,
const int slide_steps) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
CUDA_KERNEL_LOOP(i, bucket_length) {
pos[sum_step_begin + i] += pos[cur_step_begin + i];
neg[sum_step_begin + i] += neg[cur_step_begin + i];
}
}
template <typename T>
__global__ void AddDataKernel(const int64_t *label_data,
const T *pred_data,
const int inference_width,
const int num_thresholds,
int64_t *pos,
int64_t *neg,
const int numel,
const int slide_steps) {
int cur_step_begin = 0;
if (slide_steps > 0) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * (1 + num_thresholds)]) %
slide_steps;
cur_step_begin = cur_step_index * (1 + num_thresholds);
}
CUDA_KERNEL_LOOP(i, numel) {
auto predict_data = pred_data[i * inference_width + (inference_width - 1)];
    PADDLE_ENFORCE(predict_data <= 1,
                   "The predict data must be less than or equal to 1.");
    PADDLE_ENFORCE(predict_data >= 0,
                   "The predict data must be greater than or equal to 0.");
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
paddle::platform::CudaAtomicAdd(pos + cur_step_begin + binIdx, 1);
} else {
paddle::platform::CudaAtomicAdd(neg + cur_step_begin + binIdx, 1);
}
}
}
__global__ void CalcAucKernel(int64_t *stat_pos,
int64_t *stat_neg,
int num_thresholds,
double *auc,
bool need_add_batch_num) {
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
int idx = num_thresholds;
while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += (totNeg - totNegPrev) * (totPos + totPosPrev) / 2.0;
--idx;
}
if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
}
if (need_add_batch_num) {
stat_pos[num_thresholds + 1] += 1;
stat_neg[num_thresholds + 1] += 1;
}
}
inline static double trapezoidArea(double X1, double X2, double Y1, double Y2) {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
}
template <typename T, typename Context>
void statAuc(const Context &dev_ctx,
const DenseTensor &label,
const DenseTensor &predict,
const int num_thresholds,
const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
size_t batch_size = predict.dims()[0];
size_t inference_width = predict.dims()[1];
const T *inference_data = predict.data<T>();
const auto *label_data = label.data<int64_t>();
const int bucket_length = num_thresholds + 1;
if (slide_steps == 0) {
AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
dev_ctx.stream()>>>(label_data,
inference_data,
inference_width,
num_thresholds,
origin_stat_pos,
origin_stat_neg,
batch_size,
slide_steps);
return;
}
  // The last element of origin_stat_pos stores the bucket index that should be
  // used in the current step.
int cur_step_index =
static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
ClearObsoleteDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
dev_ctx.stream()>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
dev_ctx.stream()>>>(label_data,
inference_data,
inference_width,
num_thresholds,
origin_stat_pos,
origin_stat_neg,
batch_size,
slide_steps);
UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
dev_ctx.stream()>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
}
template <typename T, typename Context>
void AucKernel(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &label,
const DenseTensor &stat_pos,
const DenseTensor &stat_neg,
const std::string &curve,
int num_thresholds,
int slide_steps,
DenseTensor *auc,
DenseTensor *stat_pos_out,
DenseTensor *stat_neg_out) {
  // Only use the output vars for now; make sure they are persistable and
  // not cleaned up after each batch.
auto *origin_stat_pos = dev_ctx.template Alloc<int64_t>(stat_pos_out);
auto *origin_stat_neg = dev_ctx.template Alloc<int64_t>(stat_neg_out);
auto *auc_value = dev_ctx.template Alloc<double>(auc);
auto *stat_pos_in_tensor = &stat_pos;
auto *stat_neg_in_tensor = &stat_neg;
auto *pos_in_data = stat_pos.data<int64_t>();
auto *neg_in_data = stat_neg.data<int64_t>();
#ifdef PADDLE_WITH_CUDA
if (stat_pos_in_tensor != stat_pos_out) {
cudaMemcpy(
origin_stat_pos,
pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
cudaMemcpyDeviceToDevice);
}
if (stat_neg_in_tensor != stat_neg_out) {
cudaMemcpy(
origin_stat_neg,
neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
cudaMemcpyDeviceToDevice);
}
#else
if (stat_pos_in_tensor != stat_pos_out) {
hipMemcpy(
origin_stat_pos,
pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
hipMemcpyDeviceToDevice);
}
if (stat_neg_in_tensor != stat_neg_out) {
hipMemcpy(
origin_stat_neg,
neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
hipMemcpyDeviceToDevice);
}
#endif
statAuc<T, Context>(dev_ctx,
label,
input,
num_thresholds,
slide_steps,
origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
CalcAucKernel<<<1, 1, 0, dev_ctx.stream()>>>(origin_stat_pos + sum_offset,
origin_stat_neg + sum_offset,
num_thresholds,
auc_value,
slide_steps > 0);
}
} // namespace phi
PD_REGISTER_KERNEL(auc, GPU, ALL_LAYOUT, phi::AucKernel, float) {}
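// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): CalcAucKernel above walks the
// positive/negative histograms from the highest bucket downwards and sums
// trapezoid areas on the (false-positive, true-positive) curve. The same
// computation on the host, assuming both vectors hold num_thresholds + 1
// bucket counts; the function name is illustrative only.
#include <cstdint>
#include <vector>
inline double AucFromBuckets(const std::vector<int64_t>& stat_pos,
                             const std::vector<int64_t>& stat_neg) {
  double tot_pos = 0.0, tot_neg = 0.0, area = 0.0;
  for (int idx = static_cast<int>(stat_pos.size()) - 1; idx >= 0; --idx) {
    double prev_pos = tot_pos;
    double prev_neg = tot_neg;
    tot_pos += static_cast<double>(stat_pos[idx]);
    tot_neg += static_cast<double>(stat_neg[idx]);
    // Trapezoid between two consecutive operating points of the ROC curve.
    area += (tot_neg - prev_neg) * (tot_pos + prev_pos) / 2.0;
  }
  return (tot_pos > 0.0 && tot_neg > 0.0) ? area / tot_pos / tot_neg : 0.0;
}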
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumsum_kernel.h"
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
using Tensor = paddle::framework::Tensor;
using LoDTensor = paddle::framework::LoDTensor;
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace operators {
namespace phi {
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(
    const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
__shared__ T sh_mem[BLOCK_SIZE];
int tx = threadIdx.x;
}
template <typename T>
__global__ void MatrixRowReverse(const T* matrix_data,
                                 T* reverse_data,
                                 int reverse_size,
                                 int outer_size,
                                 int inner_size) {
int bx = blockIdx.x;
int by = blockIdx.y;
valid_item = reverse_size;
}
    BlockReverse<T, 1024>(
        matrix_data, reverse_data, src_offset, dst_offset, valid_item);
}
}
// No bank-conflict transpose
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata,
                                const T* idata,
                                size_t height,
                                size_t width) {
__shared__ T tile[TILE_DIM][TILE_DIM + 1];
}
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void BlockScanKernel(T* d_out,
                                const T* d_in,
                                int inner_size,
                                int outer_size,
                                int scan_size,
                                bool exclusive) {
  // Specialize BlockLoad, BlockStore, and BlockScan collective types
  typedef cub::
      BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoadT;
  typedef cub::
      BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_TRANSPOSE>
          BlockStoreT;
typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
// Allocate type-safe, repurposable shared memory for collectives
}
}
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  int axis,
                  bool flatten,
                  bool exclusive,
                  bool reverse,
                  DenseTensor* out) {
auto out_dims = out->dims();
  auto size = x.numel();
PADDLE_ENFORCE_EQ(
      axis < out_dims.size() && axis >= (0 - out_dims.size()),
      true,
      phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
          out_dims.size(),
          out_dims.size() - 1,
          axis));
if (axis < 0) {
axis += out_dims.size();
}
  T* out_data = dev_ctx.template Alloc<T>(out);
  const T* in_data = x.data<T>();
// Use thrust for parallel acceleration when the input size is equal to the
// length of the ‘axis’ dimension.
thrust::device_pointer_cast(in_data);
thrust::device_vector<T> vec(dev_ptr, dev_ptr + size);
if (exclusive) {
        thrust::exclusive_scan(
            thrust::device, vec.rbegin(), vec.rend(), out_data);
} else {
        thrust::inclusive_scan(
            thrust::device, vec.rbegin(), vec.rend(), out_data);
}
thrust::reverse(thrust::device, out_data, out_data + size);
} else {
if (exclusive) {
        thrust::exclusive_scan(
            thrust::device, in_data, in_data + size, out_data);
} else {
        thrust::inclusive_scan(
            thrust::device, in_data, in_data + size, out_data);
}
}
return;
dim3 blocks(32, 8);
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
  out->Resize(out_dims);
  auto* tmp_data = out->data<T>();
T* next_in_data = out_data;
T* next_out_data = tmp_data;
if (transpose) {
    MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
out_data, in_data, height, width);
next_in_data = out_data;
next_out_data = tmp_data;
} else {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
          next_out_data,
          next_in_data,
          outer_size,
          inner_size,
          scan_size,
exclusive);
}
swap_ptr(next_in_data, next_out_data);
if (transpose) {
transpose_grids.x = (height + tile_size - 1) / tile_size;
transpose_grids.y = (width + tile_size - 1) / tile_size;
    MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
next_out_data, next_in_data, width, height);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(cumsum,
GPU,
ALL_LAYOUT,
phi::CumsumKernel,
float,
double,
int16_t,
int,
int64_t) {}
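// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): when the scan covers the whole
// tensor (size == out_dims[axis]) the kernel above takes a Thrust fast path.
// A reduced standalone version of that path, assuming d_in/d_out are raw
// device pointers of length n; the function name is illustrative only.
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
template <typename T>
void ThrustCumsum1D(const T* d_in, T* d_out, int n, bool exclusive,
                    bool reverse) {
  if (reverse) {
    // Scan the elements back-to-front, then flip the result so the output
    // keeps the original ordering (i.e. a suffix scan).
    thrust::device_ptr<const T> dev_ptr = thrust::device_pointer_cast(d_in);
    thrust::device_vector<T> vec(dev_ptr, dev_ptr + n);
    if (exclusive) {
      thrust::exclusive_scan(thrust::device, vec.rbegin(), vec.rend(), d_out);
    } else {
      thrust::inclusive_scan(thrust::device, vec.rbegin(), vec.rend(), d_out);
    }
    thrust::reverse(thrust::device, d_out, d_out + n);
  } else {
    if (exclusive) {
      thrust::exclusive_scan(thrust::device, d_in, d_in + n, d_out);
    } else {
      thrust::inclusive_scan(thrust::device, d_in, d_in + n, d_out);
    }
  }
}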
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_loss_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/log_loss_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
log_loss_grad, GPU, ALL_LAYOUT, phi::LogLossGradKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_loss_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/log_loss_kernel_impl.h"
PD_REGISTER_KERNEL(log_loss, GPU, ALL_LAYOUT, phi::LogLossKernel, float) {}
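// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): the log_loss kernels registered
// above compute, per element, the epsilon-smoothed binary log loss
//   loss = -label * log(pred + eps) - (1 - label) * log(1 - pred + eps).
// A minimal host-side version; the function name is illustrative only.
#include <cmath>
inline float LogLoss(float pred, float label, float eps) {
  return -label * std::log(pred + eps) -
         (1.0f - label) * std::log(1.0f - pred + eps);
}
// Example: LogLoss(0.9f, 1.0f, 1e-7f) ~= 0.105, i.e. -log(0.9).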
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
namespace phi {
#ifdef __HIPCC__
static constexpr int kNumCUDAThreads = 256;
#else
static constexpr int kNumCUDAThreads = 512;
#endif
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
struct NonzeroFunctor {
HOSTDEVICE explicit inline NonzeroFunctor() {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(static_cast<double>(x) != 0);
}
};
template <typename T>
struct DivFunctor {
const T norm_;
HOSTDEVICE inline DivFunctor(const T norm) : norm_(norm) {}
HOSTDEVICE inline T operator()(T loss) {
loss /= norm_;
return loss;
}
};
} // namespace phi
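// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): NumBlocks above is the usual
// ceil-divide grid sizing, capped at 4096 blocks so a grid-stride loop
// (CUDA_KERNEL_LOOP) can cover any remainder. Equivalent host arithmetic,
// assuming the same constants; names are illustrative only.
#include <algorithm>
inline int GridSize(int n, int threads_per_block = 512, int max_blocks = 4096) {
  return std::min((n + threads_per_block - 1) / threads_per_block, max_blocks);
}
// GridSize(1000)    == 2
// GridSize(5000000) == 4096  (the cap; the remaining work is strided over)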
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sigmoid_cross_entropy_with_logits_grad_kernel.h"
#include "paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h"
namespace phi {
template <typename T>
struct SigmoidBwdFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);
HOSTDEVICE inline SigmoidBwdFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}
HOSTDEVICE inline phi::Array<T, 2> operator()(const T x,
const T label,
const T dout) {
T counts;
T dx_data;
T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
dx_data = static_cast<T>(0.);
counts = 0;
} else {
      T sigmoid_x = static_cast<T>(1) /
                    (static_cast<T>(1) + paddle::operators::real_exp(-x));
      T diff = sigmoid_x - label;
dx_data = dout * diff;
counts = 1;
}
phi::Array<T, 2> outs;
outs[0] = dx_data;
outs[1] = counts;
return outs;
}
};
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &label,
const DenseTensor &out_grad,
bool normalize,
int ignore_index,
DenseTensor *in_grad) {
auto dx_data = dev_ctx.template Alloc<T>(in_grad);
// Temporary memory
DenseTensor *counts_tensor = new DenseTensor();
int64_t out_dims = label.numel() * sizeof(T);
counts_tensor->Resize({out_dims});
dev_ctx.template Alloc<T>(counts_tensor);
counts_tensor->Resize(in_grad->dims());
int limit = in_grad->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
std::vector<DenseTensor *> outs = {in_grad, counts_tensor};
auto functor = SigmoidBwdFunctor<T>(ignore_index);
constexpr int Size = 2;
phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(
dev_ctx, ins, &outs, functor);
if (normalize) {
T *counts = dev_ctx.template Alloc<T>(counts_tensor);
DenseTensor *norm_tensor = new DenseTensor();
norm_tensor->Resize({sizeof(T)});
dev_ctx.template Alloc<T>(norm_tensor);
auto dims = phi::vectorize(counts_tensor->dims());
std::vector<int> reduce_dim = {};
for (int i = 0; i < dims.size(); i++) {
reduce_dim.push_back(i);
}
kernels::TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
dev_ctx,
*counts_tensor,
norm_tensor,
NonzeroFunctor<T>(),
reduce_dim,
dev_ctx.stream());
T *norm = dev_ctx.template Alloc<T>(norm_tensor);
auto norm_cpu_mem = paddle::memory::Alloc(phi::CPUPlace(), sizeof(T));
T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
paddle::memory::Copy(phi::CPUPlace(),
norm_cpu_ptr,
dev_ctx.GetPlace(),
norm,
sizeof(T),
dev_ctx.stream());
auto eps = static_cast<T>(1e-5);
*norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
std::vector<const DenseTensor *> div_ins = {in_grad};
std::vector<DenseTensor *> div_outs = {in_grad};
auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
    delete norm_tensor;
  }
  delete counts_tensor;
}
} // namespace phi
PD_REGISTER_KERNEL(sigmoid_cross_entropy_with_logits_grad,
GPU,
ALL_LAYOUT,
phi::SigmoidCrossEntropyWithLogitsGradKernel,
float,
double) {}
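// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): SigmoidBwdFunctor above applies
//   d(loss)/dx = sigmoid(x) - label,
// scaled by the upstream gradient, and zeroes positions whose label equals
// ignore_index. A host-side version of the non-ignored branch; the function
// name is illustrative only.
#include <cmath>
inline float SigmoidCrossEntropyGrad(float x, float label, float dout) {
  float sigmoid_x = 1.0f / (1.0f + std::exp(-x));
  return dout * (sigmoid_x - label);
}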
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sigmoid_cross_entropy_with_logits_kernel.h"
#include "paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h"
namespace phi {
template <typename T>
struct SigmoidFwdFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);
HOSTDEVICE inline SigmoidFwdFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}
HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label) {
T counts;
T out_data;
T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
out_data = static_cast<T>(0.);
counts = 0;
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = paddle::operators::real_log(
static_cast<T>(1) +
paddle::operators::real_exp(static_cast<T>(-abs(x))));
out_data = term1 - term2 + term3;
counts = 1;
}
phi::Array<T, 2> outs;
outs[0] = out_data;
outs[1] = counts;
return outs;
}
};
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &label,
bool normalize,
int ignore_index,
DenseTensor *out) {
auto out_data = dev_ctx.template Alloc<T>(out);
// Temporary memory
DenseTensor *counts_tensor = new DenseTensor();
int64_t out_dims = label.numel() * sizeof(T);
counts_tensor->Resize({out_dims});
dev_ctx.template Alloc<T>(counts_tensor);
counts_tensor->Resize(out->dims());
int limit = out->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
std::vector<const DenseTensor *> ins = {&x, &label};
std::vector<DenseTensor *> outs = {out, counts_tensor};
auto functor = SigmoidFwdFunctor<T>(ignore_index);
constexpr int Size = 2;
phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(
dev_ctx, ins, &outs, functor);
if (normalize) {
T *counts = dev_ctx.template Alloc<T>(counts_tensor);
DenseTensor *norm_tensor = new DenseTensor();
norm_tensor->Resize({sizeof(T)});
dev_ctx.template Alloc<T>(norm_tensor);
auto dims = phi::vectorize(counts_tensor->dims());
std::vector<int> reduce_dim = {};
for (int i = 0; i < dims.size(); i++) {
reduce_dim.push_back(i);
}
kernels::TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
dev_ctx,
*counts_tensor,
norm_tensor,
NonzeroFunctor<T>(),
reduce_dim,
dev_ctx.stream());
T *norm = dev_ctx.template Alloc<T>(norm_tensor);
auto norm_cpu_mem = paddle::memory::Alloc(phi::CPUPlace(), sizeof(T));
T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
paddle::memory::Copy(phi::CPUPlace(),
norm_cpu_ptr,
dev_ctx.GetPlace(),
norm,
sizeof(T),
dev_ctx.stream());
auto eps = static_cast<T>(1e-5);
*norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
std::vector<const DenseTensor *> div_ins = {out};
std::vector<DenseTensor *> div_outs = {out};
auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
    delete norm_tensor;
  }
  delete counts_tensor;
}
} // namespace phi
PD_REGISTER_KERNEL(sigmoid_cross_entropy_with_logits,
GPU,
ALL_LAYOUT,
phi::SigmoidCrossEntropyWithLogitsKernel,
float,
double) {}
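// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): with normalize=true the kernels
// above divide the loss by the number of elements whose label differs from
// ignore_index, clamped below by a small eps. A host-side version of that
// reduction, assuming float labels; names are illustrative only.
#include <algorithm>
#include <cmath>
#include <vector>
inline float NormalizationFactor(const std::vector<float>& labels,
                                 int ignore_index,
                                 float eps = 1e-5f) {
  float norm = 0.0f;
  for (float label : labels) {
    if (std::abs(label - static_cast<float>(ignore_index)) > 1e-6f) {
      norm += 1.0f;  // this position contributes to the loss
    }
  }
  return std::max(norm, eps);  // avoid dividing by zero
}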
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename Context>
void LogLossGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& out_grad,
float epsilon,
DenseTensor* in_grad) {
auto prediction = EigenVector<T>::Flatten(input);
auto label_out = EigenVector<T>::Flatten(label);
auto dl = EigenVector<T>::Flatten(out_grad);
auto& place = *dev_ctx.eigen_device();
if (in_grad) {
dev_ctx.template Alloc<T>(in_grad);
auto dx = EigenVector<T>::Flatten(*in_grad);
phi::funcs::EigenLogLossGrad<std::decay_t<decltype(place)>, T>::Eval(
place, dx, dl, prediction, label_out, epsilon);
}
}
} // namespace phi
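// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the PR): the gradient applied above is the
// derivative of the epsilon-smoothed log loss with respect to the prediction,
//   d(loss)/d(pred) = -label / (pred + eps) + (1 - label) / (1 - pred + eps),
// multiplied by the upstream gradient. A host-side version; the function name
// is illustrative only.
inline float LogLossGrad(float pred, float label, float eps, float dout) {
  return dout * (-label / (pred + eps) +
                 (1.0f - label) / (1.0f - pred + eps));
}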
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename Context>
void LogLossKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
float epsilon,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto prediction = EigenVector<T>::Flatten(input);
auto label_out = EigenVector<T>::Flatten(label);
auto loss = EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
phi::funcs::EigenLogLoss<std::decay_t<decltype(place)>, T>::Eval(
place, loss, prediction, label_out, epsilon);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogLossGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& out_grad,
float epsilon,
DenseTensor* in_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogLossKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
float epsilon,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
DenseTensor* in_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
bool normalize,
int ignore_index,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LogLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("log_loss_grad",
{"Predicted", "Labels", GradVarName("Loss")},
{"epsilon"},
{GradVarName("Predicted")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(log_loss_grad, phi::LogLossGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SigmoidCrossEntropyWithLogitsKernelGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("sigmoid_cross_entropy_with_logits_grad",
{"X", "Label", GradVarName("Out")},
{"normalize", "ignore_index"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(
sigmoid_cross_entropy_with_logits_grad,
phi::SigmoidCrossEntropyWithLogitsKernelGradOpArgumentMapping);