From 97509b68e7e4a42ec128a821d8e1ca802231e359 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Tue, 26 Sep 2017 23:46:00 +0800 Subject: [PATCH] cross entropy as a functor to avoid duplicated codes. --- paddle/operators/CMakeLists.txt | 6 +- paddle/operators/cross_entropy_op.cu | 83 +------------ paddle/operators/cross_entropy_op.h | 38 +----- paddle/operators/math/CMakeLists.txt | 3 + paddle/operators/math/cross_entropy.cc | 59 ++++++++++ paddle/operators/math/cross_entropy.cu | 111 ++++++++++++++++++ paddle/operators/math/cross_entropy.h | 48 ++++++++ .../softmax_with_cross_entropy_op.cu | 86 ++------------ .../operators/softmax_with_cross_entropy_op.h | 29 +---- 9 files changed, 251 insertions(+), 212 deletions(-) create mode 100644 paddle/operators/math/cross_entropy.cc create mode 100644 paddle/operators/math/cross_entropy.cu create mode 100644 paddle/operators/math/cross_entropy.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f8b0bce68..e56895c63 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -88,10 +88,14 @@ add_subdirectory(math) set(DEPS_OPS recurrent_op - cond_op) + cond_op + cross_entropy_op + softmax_with_cross_entropy_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) +op_library(cross_entropy_op DEPS cross_entropy_function) +op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 18e44d77c..1cfeb7a53 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -12,62 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/op_registry.h" #include "paddle/operators/cross_entropy_op.h" -#include "paddle/platform/assert.h" -#include "paddle/platform/hostdevice.h" namespace paddle { namespace operators { -template -__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, - const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. 
- // CUDA_1D_KERNEL_LOOP(i, N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -TolerableValue()(log(X[i * D + label[i]])); - } -} - -template -__device__ __forceinline__ T sum_single_warp(T val) { - val += __shfl_down(val, 16); - val += __shfl_down(val, 8); - val += __shfl_down(val, 4); - val += __shfl_down(val, 2); - val += __shfl_down(val, 1); - return val; -} - -template -__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, - const int class_num) { - int tid = threadIdx.x; - extern __shared__ T d_sum[]; - d_sum[tid] = 0; - - int cur_idx = tid; - int next_idx = blockIdx.x * class_num + tid; - while (cur_idx < class_num) { - d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; - next_idx += blockDim.x; - cur_idx += blockDim.x; - } - __syncthreads(); - - for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { - if (tid < stride) d_sum[tid] += d_sum[tid + stride]; - __syncthreads(); - } - - T val = d_sum[tid]; - val = sum_single_warp(val); - if (tid == 0) Y[blockIdx.x] = -val; -} - +namespace { // TODO(qingqing): make zero setting a common function. template __global__ void Zero(T* X, const int N) { @@ -100,6 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, dX[ids] = -label[ids] * dY[row_ids] / X[ids]; } } +} // namespace template class CrossEntropyOpCUDAKernel : public framework::OpKernel { @@ -107,36 +58,13 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - const Tensor* x = ctx.Input("X"); const Tensor* label = ctx.Input("Label"); Tensor* y = ctx.Output("Y"); + y->mutable_data(ctx.GetPlace()); - const T* x_data = x->data(); - T* y_data = y->mutable_data(ctx.GetPlace()); - - int batch_size = x->dims()[0]; - int class_num = x->dims()[1]; - - if (ctx.Attr("softLabel")) { - auto* label_data = ctx.Input("Label")->data(); - int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); - - SoftCrossEntropyKernel< - T><<( - ctx.device_context()) - .stream()>>>(y_data, x_data, label_data, class_num); - } else { - auto* label_data = ctx.Input("Label")->data(); - int block = 512; - int grid = (batch_size + block - 1) / block; - CrossEntropyKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(y_data, x_data, label_data, - batch_size, class_num); - } + math::CrossEntropyFunctor()( + ctx, y, x, label, ctx.Attr("softLabel")); } }; @@ -150,6 +78,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { const Tensor* x = ctx.Input("X"); const Tensor* label = ctx.Input("Label"); Tensor* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); const T* dy_data = ctx.Input(framework::GradVarName("Y"))->data(); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 255b2e9f5..1f67461d3 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/platform/hostdevice.h" +#include "paddle/operators/math/cross_entropy.h" namespace paddle { namespace operators { @@ -25,18 +25,6 @@ template using EigenMatrix = framework::EigenMatrix; -template -struct TolerableValue { - HOSTDEVICE T operator()(const T& x) const { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - - if (x == INFINITY) return kApproInf; - if (x == -INFINITY) return -kApproInf; - return x; - } -}; - template class CrossEntropyOpKernel : public framework::OpKernel { public: @@ -46,28 +34,10 @@ class CrossEntropyOpKernel : public framework::OpKernel { const Tensor* x = ctx.Input("X"); const Tensor* labels = ctx.Input("Label"); Tensor* y = ctx.Output("Y"); - T* y_data = y->mutable_data(ctx.GetPlace()); - - const int batch_size = x->dims()[0]; - if (ctx.Attr("softLabel")) { - auto prob = EigenMatrix::From(*x); - auto lbl_mat = EigenMatrix::From(*labels); - auto loss = EigenMatrix::From(*y); + y->mutable_data(ctx.GetPlace()); - loss.device(ctx.GetEigenDevice()) = - -((lbl_mat * prob.log().unaryExpr(TolerableValue())) - .sum(Eigen::DSizes(1)) - .reshape(Eigen::DSizes(batch_size, 1))); - } else { - const int class_num = x->dims()[1]; - const T* x_data = x->data(); - - const int* label_data = labels->data(); - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - y_data[i] = -TolerableValue()(std::log(x_data[index])); - } - } + math::CrossEntropyFunctor()( + ctx, y, x, labels, ctx.Attr("softLabel")); } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 074ca47d7..91ae3d49f 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -3,10 +3,13 @@ if(WITH_GPU) im2col.cu DEPS cblas device_context operator) nv_library(softmax_function SRCS softmax.cc softmax.cu DEPS operator) + nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu + DEPS operator) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_library(softmax_function SRCS softmax.cc DEPS operator) + cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cross_entropy.cc b/paddle/operators/math/cross_entropy.cc new file mode 100644 index 000000000..a5a426bc7 --- /dev/null +++ b/paddle/operators/math/cross_entropy.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/math/cross_entropy.h" + +namespace paddle { +namespace operators { +namespace math { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class CrossEntropyFunctor { + public: + void operator()(const framework::ExecutionContext& ctx, + framework::Tensor* out, const framework::Tensor* prob, + const framework::Tensor* labels, const bool softLabel) { + const int batch_size = prob->dims()[0]; + if (softLabel) { + auto in = EigenMatrix::From(*prob); + auto lbl = EigenMatrix::From(*labels); + auto loss = EigenMatrix::From(*out); + + loss.device(ctx.GetEigenDevice()) = + -((lbl * in.log().unaryExpr(math::TolerableValue())) + .sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(batch_size, 1))); + } else { + const int class_num = prob->dims()[1]; + const T* prob_data = prob->data(); + T* loss_data = out->data(); + + const int* label_data = labels->data(); + for (int i = 0; i < batch_size; ++i) { + int index = i * class_num + label_data[i]; + loss_data[i] = -math::TolerableValue()(std::log(prob_data[index])); + } + } + } +}; + +template class CrossEntropyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu new file mode 100644 index 000000000..d14a75a30 --- /dev/null +++ b/paddle/operators/math/cross_entropy.cu @@ -0,0 +1,111 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/math/cross_entropy.h" + +namespace paddle { +namespace operators { +namespace math { + +namespace { +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. 
+ // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -math::TolerableValue()(log(X[i * D + label[i]])); + } +} + +template +__device__ __forceinline__ T sum_single_warp(T val) { + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + return val; +} + +template +__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, + const int class_num) { + int tid = threadIdx.x; + extern __shared__ T d_sum[]; + d_sum[tid] = 0; + + int cur_idx = tid; + int next_idx = blockIdx.x * class_num + tid; + while (cur_idx < class_num) { + d_sum[tid] += + math::TolerableValue()(std::log(X[next_idx])) * label[next_idx]; + next_idx += blockDim.x; + cur_idx += blockDim.x; + } + __syncthreads(); + + for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { + if (tid < stride) d_sum[tid] += d_sum[tid + stride]; + __syncthreads(); + } + + T val = d_sum[tid]; + val = sum_single_warp(val); + if (tid == 0) Y[blockIdx.x] = -val; +} +} // namespace + +using Tensor = framework::Tensor; + +template +class CrossEntropyFunctor { + public: + void operator()(const framework::ExecutionContext& ctx, + framework::Tensor* out, const framework::Tensor* prob, + const framework::Tensor* labels, bool softLabel) { + const T* prob_data = prob->data(); + T* loss_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = prob->dims()[0]; + int class_num = prob->dims()[1]; + + if (softLabel) { + const T* label_data = labels->data(); + int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); + + SoftCrossEntropyKernel< + T><<( + ctx.device_context()) + .stream()>>>(loss_data, prob_data, label_data, class_num); + } else { + const int* label_data = labels->data(); + int block = 512; + int grid = (batch_size + block - 1) / block; + CrossEntropyKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(loss_data, prob_data, label_data, + batch_size, class_num); + } + } +}; + +template class CrossEntropyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cross_entropy.h b/paddle/operators/math/cross_entropy.h new file mode 100644 index 000000000..18e637cf9 --- /dev/null +++ b/paddle/operators/math/cross_entropy.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; + } +}; + +template +class CrossEntropyFunctor { + public: + // (TODO caoying) it is much better to use DeviceContext as the first + // parameter. + void operator()(const framework::ExecutionContext& context, + framework::Tensor* out, const framework::Tensor* prob, + const framework::Tensor* labels, const bool softLabel); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index feae903da..1cf4296dc 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -14,26 +14,14 @@ #define EIGEN_USE_GPU -#include "paddle/framework/op_registry.h" -#include "paddle/operators/cross_entropy_op.h" -#include "paddle/operators/math/softmax.h" +#include "paddle/operators/softmax_with_cross_entropy_op.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels, - const int batch_size, const int class_num) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < batch_size) { - PADDLE_ASSERT(labels[i] >= 0 && labels[i] < class_num); - out[i] = - -TolerableValue()(std::log(softmax_out[i * class_num + labels[i]])); - } -} - +namespace { template __global__ void CrossEntropyGrad(T* out_grad, const T* in_grad, const int* labels, const int batch_size, @@ -50,42 +38,6 @@ __global__ void CrossEntropyGrad(T* out_grad, const T* in_grad, } } -template -__device__ __forceinline__ T sum_single_warp(T val) { - val += __shfl_down(val, 16); - val += __shfl_down(val, 8); - val += __shfl_down(val, 4); - val += __shfl_down(val, 2); - val += __shfl_down(val, 1); - return val; -} - -template -__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, - const int class_num) { - int tid = threadIdx.x; - extern __shared__ T d_sum[]; - d_sum[tid] = 0; - - int cur_idx = tid; - int next_idx = blockIdx.x * class_num + tid; - while (cur_idx < class_num) { - d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; - next_idx += blockDim.x; - cur_idx += blockDim.x; - } - __syncthreads(); - - for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { - if (tid < stride) d_sum[tid] += d_sum[tid + stride]; - __syncthreads(); - } - - T val = d_sum[tid]; - val = sum_single_warp(val); - if (tid == 0) Y[blockIdx.x] = -val; -} - template __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, const T* loss_grad, @@ -98,6 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids]; } } +} // namespace template class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { @@ -105,36 +58,17 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), "This kernel only runs on GPU 
device."); - T* loss_data = - context.Output("Loss")->mutable_data(context.GetPlace()); - const Tensor* logits = context.Input("Logits"); + const Tensor* labels = context.Input("Label"); Tensor* softmax = context.Output("Softmax"); - T* softmax_out = softmax->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); - const int batch_size = logits->dims()[0]; - const int class_num = logits->dims()[1]; - int block = 512; - int grid = (batch_size + block - 1) / block; + Tensor* loss = context.Output("Loss"); + softmax->mutable_data(context.GetPlace()); + loss->mutable_data(context.GetPlace()); - if (context.Attr("softLabel")) { - const T* label_data = context.Input("Label")->data(); - block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); - - SoftCrossEntropyKernel< - T><<( - context.device_context()) - .stream()>>>(loss_data, softmax_out, label_data, class_num); - } else { - const int* label_data = context.Input("Label")->data(); - CrossEntropy<<( - context.device_context()) - .stream()>>>(loss_data, softmax_out, label_data, - batch_size, class_num); - } + math::SoftmaxFunctor()(context, logits, softmax); + math::CrossEntropyFunctor()( + context, loss, softmax, labels, context.Attr("softLabel")); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 581c5145a..bf792c1f5 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -15,7 +15,7 @@ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/cross_entropy_op.h" +#include "paddle/operators/math/cross_entropy.h" #include "paddle/operators/math/softmax.h" namespace paddle { @@ -37,31 +37,12 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { Tensor* softmax = context.Output("Softmax"); Tensor* loss = context.Output("Loss"); - T* softmax_data = softmax->mutable_data(context.GetPlace()); - T* loss_data = loss->mutable_data(context.GetPlace()); + softmax->mutable_data(context.GetPlace()); + loss->mutable_data(context.GetPlace()); math::SoftmaxFunctor()(context, logits, softmax); - - const int batch_size = logits->dims()[0]; - if (context.Attr("softLabel")) { - // (TODO caoying) the forward implementation can be further optimized. - // Current implementation is exactly cross entropy after softmax. - auto prob = EigenMatrix::From(*softmax); - auto lbl_mat = EigenMatrix::From(*labels); - auto loss_mat = EigenMatrix::From(*loss); - - loss_mat.device(context.GetEigenDevice()) = - -((lbl_mat * prob.log().unaryExpr(TolerableValue())) - .sum(Eigen::DSizes(1)) - .reshape(Eigen::DSizes(batch_size, 1))); - } else { - const int* label_data = labels->data(); - const int class_num = logits->dims()[1]; - - for (int i = 0; i < batch_size; ++i) - loss_data[i] = -TolerableValue()( - std::log(softmax_data[i * class_num + label_data[i]])); - } + math::CrossEntropyFunctor()( + context, loss, softmax, labels, context.Attr("softLabel")); } }; -- GitLab