diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 0e0622e290f42811c83c354d749ef32a2d9dcadb..2b2a9dc8319f964875371214168ce04cb67fc818 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -164,11 +164,13 @@ or not. But the output only shares the LoD information with input X. } // namespace paddle namespace ops = paddle::operators; +using CPUCtx = paddle::platform::CPUDeviceContext; + REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, - ops::CrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, + ops::CrossEntropyOpKernel); REGISTER_OP_CPU_KERNEL(cross_entropy_grad, - ops::CrossEntropyGradientOpKernel, - ops::CrossEntropyGradientOpKernel); + ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu index 6449149d4b55962e84baafffc0c2c03f8108866f..30dbd5bd3d39dd2992c3dd91364003bb7715a2eb 100644 --- a/paddle/fluid/operators/cross_entropy_op.cu +++ b/paddle/fluid/operators/cross_entropy_op.cu @@ -14,98 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/cross_entropy_op.h" -namespace paddle { -namespace operators { - -namespace { - -template -__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, - const int64_t* label, const int N, - const int D) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - int idx = i * D + label[i]; - dX[idx] = -dY[i] / X[idx]; - } -} - -template -__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, - const T* label, const int N, - const int D) { - int ids = blockIdx.x * blockDim.x + threadIdx.x; - if (ids < N * D) { - int row_ids = ids / D; - dX[ids] = -label[ids] * dY[row_ids] / X[ids]; - } -} -} // namespace - -template -class CrossEntropyOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); - const Tensor* x = ctx.Input("X"); - const Tensor* label = ctx.Input("Label"); - Tensor* y = ctx.Output("Y"); - y->mutable_data(ctx.GetPlace()); - - math::CrossEntropyFunctor()( - ctx.template device_context(), y, x, label, - ctx.Attr("soft_label")); - } -}; - -template -class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); - - const Tensor* x = ctx.Input("X"); - const Tensor* label = ctx.Input("Label"); - Tensor* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); - - const T* dy_data = - ctx.Input(framework::GradVarName("Y"))->data(); - T* dx_data = dx->mutable_data(ctx.GetPlace()); - const T* x_data = x->data(); - - int64_t batch_size = x->dims()[0]; - int64_t class_num = x->dims()[1]; - - int block = 512; - int grid = (batch_size * class_num + block - 1) / block; - - auto& dev_ctx = ctx.template device_context(); - auto stream = dev_ctx.stream(); - - if (ctx.Attr("soft_label")) { - auto* label_data = label->data(); - SoftCrossEntropyGradientKernel<<>>( - dx_data, dy_data, x_data, label_data, batch_size, class_num); - } else { - math::SetConstant functor; - functor(dev_ctx, dx, 0); - auto* label_data = label->data(); - grid = (batch_size + block - 1) / block; - CrossEntropyGradientKernel<<>>( - dx_data, dy_data, x_data, label_data, batch_size, class_num); - } - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel, - ops::CrossEntropyOpCUDAKernel); +using CUDACtx = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(cross_entropy, + ops::CrossEntropyOpKernel, + ops::CrossEntropyOpKernel); REGISTER_OP_CUDA_KERNEL(cross_entropy_grad, - ops::CrossEntropyGradientOpCUDAKernel, - ops::CrossEntropyGradientOpCUDAKernel); + ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index 6da3a24dc89a85fe432b6350d3af7b0e84337c9d..19a2aec92b267ece94685ce34604b7d1cfa5d209 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -17,69 +17,106 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; -template +template class CrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "This kernel only runs on CPU."); - const Tensor* x = ctx.Input("X"); - const Tensor* labels = ctx.Input("Label"); - Tensor* y = ctx.Output("Y"); + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* y = ctx.Output("Y"); y->mutable_data(ctx.GetPlace()); - math::CrossEntropyFunctor()( - ctx.template device_context(), y, x, labels, + math::CrossEntropyFunctor()( + ctx.template device_context(), y, x, labels, ctx.Attr("soft_label")); } }; template +class XeSoftlabelGradFunctor { + public: + XeSoftlabelGradFunctor(T* dx, + const T* dy, // NOLINT + const T* x, // NOLINT + const T* label, // NOLINT + size_t num_classes) + : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {} + + HOSTDEVICE void operator()(size_t i) { + auto row_ids = i / num_classes_; + dx_[i] = -label_[i] * dy_[row_ids] / x_[i]; + } + + private: + T* dx_; + const T* dy_; + const T* x_; + const T* label_; + size_t num_classes_; +}; + +template +class XeGradFunctor { + public: + XeGradFunctor(T* dx, + const T* dy, // NOLINT + const T* x, // NOLINT + const int64_t* label, // NOLINT + size_t num_classes) + : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {} + + HOSTDEVICE void operator()(size_t sample_id) { + auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id]; + for (size_t x_offset = sample_id * num_classes_; + x_offset < (sample_id + 1) * num_classes_; ++x_offset) { + dx_[x_offset] = x_offset != x_is_true_offset + ? static_cast(0) + : -dy_[sample_id] / x_[x_offset]; + } + } + + private: + T* dx_; + const T* dy_; + const T* x_; + const int64_t* label_; + size_t num_classes_; +}; + +template class CrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "This kernel only runs on CPU."); - const Tensor* x = ctx.Input("X"); - const Tensor* dy = ctx.Input(framework::GradVarName("Y")); - const Tensor* label = ctx.Input("Label"); - Tensor* dx = ctx.Output(framework::GradVarName("X")); - T* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* x = ctx.Input("X"); + auto* dy = ctx.Input(framework::GradVarName("Y")); + auto* label = ctx.Input("Label"); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); int64_t class_num = x->dims()[1]; if (ctx.Attr("soft_label")) { - auto x_mat = EigenMatrix::From(*x); - auto dy_mat = EigenMatrix::From(*dy); - auto lbl_mat = EigenMatrix::From(*label); - auto dx_mat = EigenMatrix::From(*dx); - - dx_mat.device(*ctx.template device_context() - .eigen_device()) = - -(lbl_mat * - dy_mat.broadcast(Eigen::DSizes(1, class_num)) / x_mat); + XeSoftlabelGradFunctor functor(dx_data, dy->data(), x->data(), + label->data(), + static_cast(class_num)); + platform::ForRange for_range( + ctx.template device_context(), + static_cast(dx->numel())); + for_range(functor); } else { - int64_t batch_size = x->dims()[0]; - const T* dy_data = dy->data(); - const T* x_data = x->data(); - const int64_t* label_data = label->data(); - - math::SetConstant functor; - functor(ctx.template device_context(), dx, 0); - - for (int64_t i = 0; i < batch_size; ++i) { - PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); - int64_t index = i * class_num + label_data[i]; - dx_data[index] = math::TolerableValue()(-dy_data[i] / x_data[index]); - } + XeGradFunctor functor(dx_data, dy->data(), x->data(), + label->data(), + static_cast(class_num)); + platform::ForRange for_range( + ctx.template device_context(), + static_cast(dy->numel())); + for_range(functor); } } };