Unverified · Commit 3501ff7d authored by huangjiyi and committed by GitHub

[PHI decoupling] move cross_entropy from fluid to phi (#48160)

* move cross_entropy from fluid to phi

* replace mutable_data with Alloc

* use .template
Parent 88410225
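The "use .template" item in the commit message refers to the allocation change visible later in this diff, where `out->mutable_data<T>(ctx.GetPlace())` becomes `ctx.template Alloc<T>(out)`. Below is a minimal standalone sketch of why the `template` keyword is required when calling a member template through a dependent context object; the `Context` type is a hypothetical stand-in, not the phi API:

```cpp
// Minimal stand-in for a device context exposing a member template Alloc,
// loosely mirroring phi's ctx.template Alloc<T>(out). Hypothetical types,
// for illustration only.
struct Context {
  template <typename T>
  T* Alloc(void* /*tensor*/) {
    return nullptr;  // a real context would allocate device memory here
  }
};

template <typename DeviceContext, typename T>
void Kernel(DeviceContext& ctx, void* out) {
  // `ctx` has a dependent type inside this template, so `.template` is
  // needed for the compiler to parse Alloc<T> as a member template call
  // rather than a less-than comparison.
  T* data = ctx.template Alloc<T>(out);
  (void)data;
}

int main() {
  Context ctx;
  Kernel<Context, float>(ctx, nullptr);
  return 0;
}
```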
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h"
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 
 namespace paddle {
 namespace operators {
@@ -237,9 +237,9 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
   auto eigen_predicted_logits = math::EigenMatrix<T>::From(predicted_logits);
 
   eigen_loss.device(*dev_ctx.eigen_device()) =
-      (eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue<T>()) -
+      (eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue<T>()) -
        eigen_predicted_logits)
-          .unaryExpr(math::TolerableValue<T>());
+          .unaryExpr(phi::funcs::TolerableValue<T>());
 
   eigen_softmax.device(*dev_ctx.eigen_device()) =
       (eigen_softmax *
@@ -372,9 +372,9 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
   auto eigen_predicted_logits = math::EigenMatrix<T>::From(predicted_logits);
 
   eigen_loss.device(*dev_ctx.eigen_device()) =
-      (eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue<T>()) -
+      (eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue<T>()) -
        eigen_predicted_logits)
-          .unaryExpr(math::TolerableValue<T>());
+          .unaryExpr(phi::funcs::TolerableValue<T>());
 
   eigen_softmax.device(*dev_ctx.eigen_device()) =
       (eigen_softmax *
...
@@ -22,9 +22,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 
 namespace paddle {
 namespace operators {
...
@@ -15,8 +15,8 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 #include "paddle/phi/kernels/funcs/math.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -51,7 +51,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
     }
 
     int axis_dim = x->dims()[rank - 1];
-    math::CrossEntropyFunctor<DeviceContext, T>()(
+    phi::funcs::CrossEntropyFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(),
        &y_2d,
        &x_2d,
@@ -190,7 +190,7 @@ struct HardLabelCrossEntropyForwardFunctor {
           label);
       auto match_x = x_[idx * feature_size_ + label];
-      y_[idx] = -math::TolerableValue<T>()(phi::funcs::real_log(match_x));
+      y_[idx] = -phi::funcs::TolerableValue<T>()(phi::funcs::real_log(match_x));
       match_x_[idx] = match_x;
     } else {
       y_[idx] = 0;
...
@@ -21,7 +21,6 @@ else()
   math_library(concat_and_split DEPS concat_and_split_functor)
 endif()
 math_library(context_project DEPS im2col math_function)
-math_library(cross_entropy)
 math_library(cos_sim_functor)
 math_library(depthwise_conv)
 math_library(im2col)
...
@@ -16,10 +16,10 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 
 namespace paddle {
 namespace operators {
...
@@ -14,11 +14,11 @@ limitations under the License. */
 
 #include "paddle/phi/kernels/cross_entropy_kernel.h"
 
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/softmax_kernel.h"
@@ -64,7 +64,7 @@ void CrossEntropy(const CPUContext& dev_ctx,
   DenseTensor out_2d(*out);
   out_2d.Resize({n, d / axis_dim});
 
-  paddle::operators::math::CrossEntropyFunctor<CPUContext, T>()(
+  phi::funcs::CrossEntropyFunctor<CPUContext, T>()(
       dev_ctx, &out_2d, &x_2d, &label_2d, soft_label, ignore_index, axis_dim);
 }
...
@@ -16,6 +16,7 @@ math_library(pooling DEPS dense_tensor)
 math_library(segment_pooling)
 math_library(sequence2batch)
 math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function)
+math_library(cross_entropy)
 
 cc_library(
   phi_data_layout_transform
...
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,20 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/cross_entropy.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/utils/data_type.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 using Tensor = phi::DenseTensor;
 
 template <typename T,
           int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+using EigenMatrix = phi::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename T>
 struct HardLabelCrossEntropyCPUFunctorImpl {
@@ -54,17 +53,17 @@ struct HardLabelCrossEntropyCPUFunctorImpl {
       for (int j = 0; j < num_remain; j++) {
         int lbl = static_cast<int>(label_data[i * num_remain + j]);
         if (lbl != ignore_index_) {
-          PADDLE_ENFORCE_GE(lbl,
-                            0,
-                            platform::errors::OutOfRange(
-                                "label value should >= 0 when label "
-                                "value(%f) not equal to ignore_index(%f)",
-                                lbl,
-                                ignore_index_));
+          PADDLE_ENFORCE_GE(
+              lbl,
+              0,
+              phi::errors::OutOfRange("label value should >= 0 when label "
+                                      "value(%f) not equal to ignore_index(%f)",
+                                      lbl,
+                                      ignore_index_));
           PADDLE_ENFORCE_LT(
               lbl,
               axis_dim_,
-              platform::errors::OutOfRange(
+              phi::errors::OutOfRange(
                   "label value should less than the shape of axis dimension "
                   "when label value(%f) not equal to ignore_index(%f), But "
                   "received label value as %ld and shape of axis dimension "
@@ -79,7 +78,7 @@ struct HardLabelCrossEntropyCPUFunctorImpl {
           loss_data[loss_idx] =
               lbl == ignore_index_
                   ? 0
-                  : -math::TolerableValue<T>()(std::log(prob_data[index]));
+                  : -phi::funcs::TolerableValue<T>()(std::log(prob_data[index]));
         }
       }
     }
@@ -112,19 +111,18 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
     auto loss = EigenMatrix<T>::From(*out);
 
     loss.device(*ctx.eigen_device()) =
-        -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
+        -((lbl * in.log().unaryExpr(phi::funcs::TolerableValue<T>()))
              .reshape(batch_axis_remain)
             .sum(Eigen::DSizes<int, 1>(1)));
   } else {
     HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(
         out, prob, labels, ignore_index, axis_dim);
-    framework::VisitIntDataType(framework::TransToProtoVarType(labels->dtype()),
-                                functor_impl);
+    phi::VisitDataType(labels->dtype(), functor_impl);
   }
 }
 
 template class CrossEntropyFunctor<phi::CPUContext, float>;
 template class CrossEntropyFunctor<phi::CPUContext, double>;
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
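The soft-label branch above computes, per sample, loss = -Σ_j label[j] * log(prob[j]) with Eigen. Below is a plain-C++ sketch of the same computation, using hypothetical names and a flat row-major layout rather than the phi tensor API, with the TolerableValue clipping omitted:

```cpp
#include <cmath>
#include <cstdio>

// Plain-C++ sketch of the soft-label branch: for each row i,
// loss[i] = -sum_j label[i][j] * log(prob[i][j]).
void SoftLabelCrossEntropy(const float* prob, const float* label,
                           float* loss, int batch, int class_num) {
  for (int i = 0; i < batch; ++i) {
    float val = 0.f;
    for (int j = 0; j < class_num; ++j) {
      val += label[i * class_num + j] * std::log(prob[i * class_num + j]);
    }
    loss[i] = -val;
  }
}

int main() {
  const float prob[] = {0.7f, 0.3f};
  const float label[] = {1.0f, 0.0f};  // a one-hot vector as a soft label
  float loss[1];
  SoftLabelCrossEntropy(prob, label, loss, 1, 2);
  std::printf("%f\n", loss[0]);  // -log(0.7) ~ 0.3567
  return 0;
}
```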
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/cross_entropy.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/math.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 template <typename T, typename LabelT>
 __global__ void CrossEntropyKernel(T* Y,
@@ -38,10 +39,9 @@ __global__ void CrossEntropyKernel(T* Y,
                            D,
                            ignore_index,
                            lbl);
-    Y[i] =
-        ignore_index == lbl
-            ? static_cast<T>(0)
-            : -math::TolerableValue<T>()(phi::funcs::real_log(X[i * D + lbl]));
+    Y[i] = ignore_index == lbl ? static_cast<T>(0)
+                               : -phi::funcs::TolerableValue<T>()(
+                                     phi::funcs::real_log(X[i * D + lbl]));
  }
 }
@@ -56,10 +56,11 @@ __global__ void SoftCrossEntropyKernel(T* Y,
   int idx = blockIdx.x * class_num + tid;
   int end = blockIdx.x * class_num + class_num;
   for (; idx < end; idx += blockDim.x) {
-    val += math::TolerableValue<T>()(phi::funcs::real_log(X[idx])) * label[idx];
+    val += phi::funcs::TolerableValue<T>()(phi::funcs::real_log(X[idx])) *
+           label[idx];
   }
 
-  val = paddle::platform::reduceSum(val, tid, blockDim.x);
+  val = phi::backends::gpu::reduceSum(val, tid, blockDim.x);
   if (threadIdx.x == 0) {
     Y[blockIdx.x] = -val;
   }
@@ -117,8 +118,8 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
     const bool softLabel,
     const int ignore_index,
     const int axis_dim) {
+  T* loss_data = ctx.template Alloc<T>(out);
   const T* prob_data = prob->data<T>();
-  T* loss_data = out->mutable_data<T>(ctx.GetPlace());
 
   int batch_size = prob->dims()[0];
   int class_num = prob->dims()[1];
@@ -145,8 +146,7 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
         ignore_index,
         kMaxBlockDim,
         ctx.stream());
-    framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()),
-                             functor);
+    phi::VisitDataType(labels->dtype(), functor);
   }
 }
@@ -154,6 +154,5 @@ template class CrossEntropyFunctor<phi::GPUContext, float>;
 template class CrossEntropyFunctor<phi::GPUContext, double>;
 template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::float16>;
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
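The hard-label `CrossEntropyKernel` above maps one CUDA thread per sample and writes zero loss when the label equals `ignore_index`. As a serial CPU sketch of that logic, with illustrative names and the TolerableValue clipping omitted:

```cpp
#include <cmath>
#include <cstdio>

// Serial sketch of the hard-label kernel: each sample carries one class
// index; entries equal to ignore_index contribute zero loss.
void HardLabelCrossEntropy(const float* prob, const long* label, float* loss,
                           int batch, int class_num, long ignore_index) {
  for (int i = 0; i < batch; ++i) {
    long lbl = label[i];
    loss[i] = (lbl == ignore_index)
                  ? 0.f
                  : -std::log(prob[i * class_num + lbl]);
  }
}

int main() {
  const float prob[] = {0.7f, 0.3f, 0.2f, 0.8f};
  const long label[] = {0, -100};  // second sample is ignored
  float loss[2];
  HardLabelCrossEntropy(prob, label, loss, 2, 2, -100);
  std::printf("%f %f\n", loss[0], loss[1]);  // -log(0.7), 0
  return 0;
}
```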
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,14 +15,13 @@ limitations under the License. */
 #pragma once
 #include <limits>
 
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 template <typename T>
 struct TolerableValue {
@@ -46,14 +45,15 @@ struct TolerableValue {
 // Also. In standard implementation of cross entropy, other
 // framework not has the ValueClipping.
 template <>
-struct TolerableValue<platform::float16> {
-  HOSTDEVICE platform::float16 operator()(const platform::float16& x) const {
-    if (platform::isfinite(x))
+struct TolerableValue<phi::dtype::float16> {
+  HOSTDEVICE phi::dtype::float16 operator()(
+      const phi::dtype::float16& x) const {
+    if (phi::dtype::isfinite(x))
       return x;
-    else if (x > static_cast<platform::float16>(0))
-      return std::numeric_limits<platform::float16>::max();
+    else if (x > static_cast<phi::dtype::float16>(0))
+      return std::numeric_limits<phi::dtype::float16>::max();
     else
-      return std::numeric_limits<platform::float16>::min();
+      return std::numeric_limits<phi::dtype::float16>::min();
  }
 };
@@ -68,6 +68,5 @@ class CrossEntropyFunctor {
                   const int ignore_index,
                   const int axis_dim);
 };
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
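The float16 specialization shown here passes finite values through and clamps everything else to `std::numeric_limits` bounds. The generic template body is elided in this hunk, so the sketch below simply mirrors the float16 rule for `float`; the clipping constants of the real generic implementation may differ:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

// TolerableValue-style clipping, mirroring the float16 rule above: pass
// finite values through, clamp +inf to max(); -inf and NaN fall through to
// min() (the smallest positive normal, matching the original code).
template <typename T>
struct TolerableValue {
  T operator()(const T& x) const {
    if (std::isfinite(x)) return x;
    if (x > static_cast<T>(0)) return std::numeric_limits<T>::max();
    return std::numeric_limits<T>::min();
  }
};

int main() {
  TolerableValue<float> clip;
  // log(0) is -inf; clipping keeps the loss finite.
  std::printf("%g\n", clip(std::log(0.0f)));  // smallest positive float
  std::printf("%g\n", clip(std::log(2.0f)));  // finite, passes through
  return 0;
}
```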
@@ -22,7 +22,6 @@ limitations under the License. */
 namespace cub = hipcub;
 #endif
 
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
...
@@ -22,7 +22,6 @@ limitations under the License. */
 namespace cub = hipcub;
 #endif
 
-#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
@@ -31,6 +30,7 @@ namespace cub = hipcub;
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/cross_entropy.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
@@ -46,7 +46,7 @@ template <typename T>
 static __device__ __forceinline__ T Log(T x) {
   using AccT = typename dtype::MPTypeTrait<T>::Type;
   AccT logx = std::log(static_cast<AccT>(x));
-  return paddle::operators::math::TolerableValue<T>()(static_cast<T>(logx));
+  return phi::funcs::TolerableValue<T>()(static_cast<T>(logx));
 }
 
 // Wrapper of exp function. Use exp(float32) for float16
@@ -54,7 +54,7 @@ template <typename T>
 static __device__ __forceinline__ T Exp(T x) {
   using AccT = typename dtype::MPTypeTrait<T>::Type;
   AccT expx = std::exp(static_cast<AccT>(x));
-  return paddle::operators::math::TolerableValue<T>()(static_cast<T>(expx));
+  return phi::funcs::TolerableValue<T>()(static_cast<T>(expx));
 }
 
 template <typename Tx, typename Ty = Tx>
@@ -1285,16 +1285,15 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
   DenseTensor softmax_out_2d(*softmax_out);
   softmax_out_2d.Resize({n, d});
 
-  // math::CrossEntropyFunctor support axis is the last
+  // phi::funcs::CrossEntropyFunctor support axis is the last
   if (axis_v == -1) {
-    paddle::operators::math::CrossEntropyFunctor<GPUContext, T>()(
-        dev_ctx,
-        &loss_2d,
-        &softmax_2d,
-        &labels_2d,
-        soft_label,
-        ignore_index,
-        axis_dim);
+    phi::funcs::CrossEntropyFunctor<GPUContext, T>()(dev_ctx,
+                                                     &loss_2d,
+                                                     &softmax_2d,
+                                                     &labels_2d,
+                                                     soft_label,
+                                                     ignore_index,
+                                                     axis_dim);
     return;
   }
@@ -1389,14 +1388,13 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
     loss_2d.Resize({n, 1});
 
     paddle::operators::math::SoftmaxCUDNNFunctor<T, GPUContext>()(
         dev_ctx, &logits_2d, &softmax_2d);
-    paddle::operators::math::CrossEntropyFunctor<GPUContext, T>()(
-        dev_ctx,
-        &loss_2d,
-        &softmax_2d,
-        &labels_2d,
-        false,
-        ignore_index,
-        axis_dim);
+    phi::funcs::CrossEntropyFunctor<GPUContext, T>()(dev_ctx,
+                                                     &loss_2d,
+                                                     &softmax_2d,
+                                                     &labels_2d,
+                                                     false,
+                                                     ignore_index,
+                                                     axis_dim);
   } else {
     auto* logits_data = logits.data<T>();
     auto* labels_data = label.data<LabelT>();
...
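The `Log` and `Exp` wrappers in the last file promote a low-precision value to an accumulation type before calling `std::log`/`std::exp`, then cast back and clip. Below is a standalone sketch of that promotion pattern; `Half` and `MPTypeTrait` here are toy stand-ins for phi::dtype::float16 and phi::dtype::MPTypeTrait, not the real phi definitions:

```cpp
#include <cmath>
#include <cstdio>

// Toy half-precision type; a real float16 packs 16 bits, this stores a
// float purely for illustration.
struct Half {
  float v;
  explicit Half(float f) : v(f) {}
  explicit operator float() const { return v; }
};

// Maps each type to its accumulation type: identity by default, float for
// the low-precision Half.
template <typename T>
struct MPTypeTrait {
  using Type = T;
};
template <>
struct MPTypeTrait<Half> {
  using Type = float;
};

template <typename T>
T Log(T x) {
  using AccT = typename MPTypeTrait<T>::Type;
  AccT logx = std::log(static_cast<AccT>(x));  // computed in float for Half
  return static_cast<T>(logx);                 // TolerableValue clipping omitted
}

int main() {
  Half h(0.5f);
  std::printf("%f\n", static_cast<float>(Log(h)));  // ~ -0.693147
  return 0;
}
```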