diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 4b38af7dc3236aeea6696f0a9710a8f6181f4a12..a21bc7335126ba8cff0de70aa8b8cc1fc36c1976 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -28,27 +28,27 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
   }
 }
 
-template <typename T, int blockSize>
+template <typename T, int BlockSize>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                        const int N, const int D) {
   int tid = threadIdx.x;
-  __shared__ T d_sum[blockSize];
+  __shared__ T d_sum[BlockSize];
   int next_idx = blockIdx.x * D + tid;
 
   d_sum[tid] = 0;
   int cur_idx = tid;
   while (cur_idx < D) {
-    d_sum[tid] += tolerable_value(std::log(X[next_idx])) * label[next_idx];
-    next_idx += blockSize;
-    cur_idx += blockSize;
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += BlockSize;
+    cur_idx += BlockSize;
   }
   __syncthreads();
 
-  for (int stride = blockSize >> 1; stride > 0; stride >>= 1) {
+  for (int stride = BlockSize >> 1; stride > 0; stride >>= 1) {
     __syncthreads();
     if (tid < stride) {
       next_idx = tid + stride;
@@ -88,13 +88,12 @@ template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                                const T* label, const int N,
                                                const int D) {
-  // TOOD(qingqing): optimize for this kernel
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    for (int j = 0; j < D; ++j) {
-      int idx = i * D + j;
-      dX[idx] = -label[idx] * dY[i] / X[idx];
-    }
+  int row_ids = blockIdx.x * blockDim.x + threadIdx.x;
+  int col_ids = blockIdx.y * blockDim.y + threadIdx.y;
+  int ids = row_ids * D + col_ids;
+
+  if (ids < N * D) {
+    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
 
@@ -103,7 +102,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
 
     auto x = ctx.Input<Tensor>("X");
     auto y = ctx.Output<Tensor>("Y");
@@ -136,7 +135,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
 
     auto x = ctx.Input<Tensor>("X");
     auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -156,6 +155,11 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     // TODO(qingqing): launch kernel on specified stream
     // base on ExecutionContext.
     if (ctx.Attr<int>("soft_label") == 1) {
+      int block_x = 32;
+      int block_y = 32;
+      dim3 block(block_x, block_y);
+      dim3 grid((n + block_x - 1) / block_x, (d + block_y - 1) / block_y);
+
       auto* label_data = label->data<T>();
       SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
           dx_data, dy_data, x_data, label_data, n, d);
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index 1b4b23ac2029138afadef0168262203ac2e20430..4bbd05a1bb677b3fd4f9dac9a19c0d2bcdb3185c 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/platform/hostdevice.h"
 
@@ -20,19 +21,25 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename T>
-HOSTDEVICE T tolerable_value(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
-  const T kApproInf = 1e20;
-  if (x == INFINITY) {
-    return kApproInf;
-  }
-  if (x == -INFINITY) {
-    return -kApproInf;
+struct TolerableValue {
+  HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+
+    if (x == INFINITY) {
+      return kApproInf;
+    }
+    if (x == -INFINITY) {
+      return -kApproInf;
+    }
+    return x;
   }
-  return x;
-}
+};
 
 template <typename T>
 class CrossEntropyOpKernel : public framework::OpKernel {
@@ -40,33 +47,34 @@ class CrossEntropyOpKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-
-    auto* x_data = x->data<T>();
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
     y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
-
-    int batch_size = x->dims()[0];
-    int class_num = x->dims()[1];
+    const int batch_size = x->dims()[0];
 
     if (ctx.Attr<int>("soft_label") == 1) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      int index = 0;
-      for (int i = 0; i < batch_size; ++i) {
-        T sum = static_cast<T>(0);
-        for (int j = 0; j < class_num; ++j) {
-          sum += label_data[index] * tolerable_value(std::log(x_data[index]));
-          y_data[i] = -sum;
-          index++;
-        }
-      }
+      auto prob = EigenMatrix<T>::From(*x);
+      auto lbl_mat = EigenMatrix<T>::From(*labels);
+      auto loss = EigenMatrix<T>::From(*y);
+
+      // loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+      //     prob.log().unaryExpr(TolerableValue<T>());
+
+      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+          -((lbl_mat * prob.log())
+                .sum(Eigen::DSizes<int, 1>(1))
+                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
     } else {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+      const int class_num = x->dims()[1];
+
+      const T* x_data = x->data<T>();
+      T* y_data = y->data<T>();
+
+      const int* label_data = labels->data<int>();
       for (int i = 0; i < batch_size; ++i) {
         int index = i * class_num + label_data[i];
-        y_data[i] = -tolerable_value(std::log(x_data[index]));
+        y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
       }
     }
   }
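
Note: the soft-label branch above computes, per row of the batch, loss[i] = -sum_j label[i][j] * log(prob[i][j]). The standalone C++ sketch below (not part of the patch) spells out that reduction with plain loops; the infinity clamp mirrors the TolerableValue functor from the header (the Eigen path in the patch keeps that clamp commented out for now), and the helper names here (Tolerable, SoftLabelLoss) are illustrative only.

// Standalone sketch: what the soft-label cross-entropy branch computes,
// written with explicit loops over row-major batch_size x class_num data
// stored in std::vector instead of paddle::framework::Tensor.
#include <cmath>
#include <cstdio>
#include <vector>

// Clamp +/-infinity the same way the TolerableValue<T> functor does.
template <typename T>
T Tolerable(T x) {
  const T kApproInf = 1e20;
  if (x == INFINITY) return kApproInf;
  if (x == -INFINITY) return -kApproInf;
  return x;
}

// loss[i] = -sum_j label[i * class_num + j] * log(prob[i * class_num + j])
template <typename T>
std::vector<T> SoftLabelLoss(const std::vector<T>& prob,
                             const std::vector<T>& label, int batch_size,
                             int class_num) {
  std::vector<T> loss(batch_size, T(0));
  for (int i = 0; i < batch_size; ++i) {
    T sum = T(0);
    for (int j = 0; j < class_num; ++j) {
      const int idx = i * class_num + j;
      sum += label[idx] * Tolerable(std::log(prob[idx]));
    }
    loss[i] = -sum;
  }
  return loss;
}

int main() {
  // One sample, three classes: prob is a softmax output, label a soft target.
  std::vector<float> prob = {0.2f, 0.3f, 0.5f};
  std::vector<float> label = {0.0f, 0.0f, 1.0f};
  std::vector<float> loss = SoftLabelLoss(prob, label, /*batch_size=*/1,
                                          /*class_num=*/3);
  std::printf("loss[0] = %f\n", loss[0]);  // -log(0.5) ~= 0.693147
  return 0;
}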