diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f8b0bce6815ff17a60ef64b0eec34a7cc9d16e72..e56895c63a426b782f7b46091bc86c367d49899d 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -88,10 +88,14 @@ add_subdirectory(math) set(DEPS_OPS recurrent_op - cond_op) + cond_op + cross_entropy_op + softmax_with_cross_entropy_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) +op_library(cross_entropy_op DEPS cross_entropy_function) +op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 18e44d77c9f62b296dc57952e546f844670c7d57..1cfeb7a53b047541322ac53c5b7249e660039d5c 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -12,62 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/op_registry.h" #include "paddle/operators/cross_entropy_op.h" -#include "paddle/platform/assert.h" -#include "paddle/platform/hostdevice.h" namespace paddle { namespace operators { -template -__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, - const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -TolerableValue()(log(X[i * D + label[i]])); - } -} - -template -__device__ __forceinline__ T sum_single_warp(T val) { - val += __shfl_down(val, 16); - val += __shfl_down(val, 8); - val += __shfl_down(val, 4); - val += __shfl_down(val, 2); - val += __shfl_down(val, 1); - return val; -} - -template -__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, - const int class_num) { - int tid = threadIdx.x; - extern __shared__ T d_sum[]; - d_sum[tid] = 0; - - int cur_idx = tid; - int next_idx = blockIdx.x * class_num + tid; - while (cur_idx < class_num) { - d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; - next_idx += blockDim.x; - cur_idx += blockDim.x; - } - __syncthreads(); - - for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { - if (tid < stride) d_sum[tid] += d_sum[tid + stride]; - __syncthreads(); - } - - T val = d_sum[tid]; - val = sum_single_warp(val); - if (tid == 0) Y[blockIdx.x] = -val; -} - +namespace { // TODO(qingqing): make zero setting a common function. template __global__ void Zero(T* X, const int N) { @@ -100,6 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, dX[ids] = -label[ids] * dY[row_ids] / X[ids]; } } +} // namespace template class CrossEntropyOpCUDAKernel : public framework::OpKernel { @@ -107,36 +58,13 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - const Tensor* x = ctx.Input("X"); const Tensor* label = ctx.Input("Label"); Tensor* y = ctx.Output("Y"); + y->mutable_data(ctx.GetPlace()); - const T* x_data = x->data(); - T* y_data = y->mutable_data(ctx.GetPlace()); - - int batch_size = x->dims()[0]; - int class_num = x->dims()[1]; - - if (ctx.Attr("softLabel")) { - auto* label_data = ctx.Input("Label")->data(); - int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); - - SoftCrossEntropyKernel< - T><<( - ctx.device_context()) - .stream()>>>(y_data, x_data, label_data, class_num); - } else { - auto* label_data = ctx.Input("Label")->data(); - int block = 512; - int grid = (batch_size + block - 1) / block; - CrossEntropyKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(y_data, x_data, label_data, - batch_size, class_num); - } + math::CrossEntropyFunctor()( + ctx, y, x, label, ctx.Attr("softLabel")); } }; @@ -150,6 +78,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { const Tensor* x = ctx.Input("X"); const Tensor* label = ctx.Input("Label"); Tensor* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); const T* dy_data = ctx.Input(framework::GradVarName("Y"))->data(); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 255b2e9f5ea7566cca7fd3914e38da804b7c7006..1f67461d3fadb1a979832ad049d4e0098256b834 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/platform/hostdevice.h" +#include "paddle/operators/math/cross_entropy.h" namespace paddle { namespace operators { @@ -25,18 +25,6 @@ template using EigenMatrix = framework::EigenMatrix; -template -struct TolerableValue { - HOSTDEVICE T operator()(const T& x) const { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - - if (x == INFINITY) return kApproInf; - if (x == -INFINITY) return -kApproInf; - return x; - } -}; - template class CrossEntropyOpKernel : public framework::OpKernel { public: @@ -46,28 +34,10 @@ class CrossEntropyOpKernel : public framework::OpKernel { const Tensor* x = ctx.Input("X"); const Tensor* labels = ctx.Input("Label"); Tensor* y = ctx.Output("Y"); - T* y_data = y->mutable_data(ctx.GetPlace()); - - const int batch_size = x->dims()[0]; - if (ctx.Attr("softLabel")) { - auto prob = EigenMatrix::From(*x); - auto lbl_mat = EigenMatrix::From(*labels); - auto loss = EigenMatrix::From(*y); + y->mutable_data(ctx.GetPlace()); - loss.device(ctx.GetEigenDevice()) = - -((lbl_mat * prob.log().unaryExpr(TolerableValue())) - .sum(Eigen::DSizes(1)) - .reshape(Eigen::DSizes(batch_size, 1))); - } else { - const int class_num = x->dims()[1]; - const T* x_data = x->data(); - - const int* label_data = labels->data(); - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - y_data[i] = -TolerableValue()(std::log(x_data[index])); - } - } + math::CrossEntropyFunctor()( + ctx, y, x, labels, ctx.Attr("softLabel")); } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index f8333f34f7b4c7b0f9a0f14a7a33f9d98e1d331c..91ae3d49f1df51d9524547f7765285bff9dbb5c5 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,9 +1,15 @@ - if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc - im2col.cu DEPS cblas device_context) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc + im2col.cu DEPS cblas device_context operator) + nv_library(softmax_function SRCS softmax.cc softmax.cu + DEPS operator) + nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu + DEPS operator) else() - cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context) + cc_library(math_function SRCS math_function.cc im2col.cc + DEPS cblas device_context operator) + cc_library(softmax_function SRCS softmax.cc DEPS operator) + cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cross_entropy.cc b/paddle/operators/math/cross_entropy.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5a426bc7b16852e67afd790df7a91d89a458c8a --- /dev/null +++ b/paddle/operators/math/cross_entropy.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/math/cross_entropy.h" + +namespace paddle { +namespace operators { +namespace math { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class CrossEntropyFunctor { + public: + void operator()(const framework::ExecutionContext& ctx, + framework::Tensor* out, const framework::Tensor* prob, + const framework::Tensor* labels, const bool softLabel) { + const int batch_size = prob->dims()[0]; + if (softLabel) { + auto in = EigenMatrix::From(*prob); + auto lbl = EigenMatrix::From(*labels); + auto loss = EigenMatrix::From(*out); + + loss.device(ctx.GetEigenDevice()) = + -((lbl * in.log().unaryExpr(math::TolerableValue())) + .sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(batch_size, 1))); + } else { + const int class_num = prob->dims()[1]; + const T* prob_data = prob->data(); + T* loss_data = out->data(); + + const int* label_data = labels->data(); + for (int i = 0; i < batch_size; ++i) { + int index = i * class_num + label_data[i]; + loss_data[i] = -math::TolerableValue()(std::log(prob_data[index])); + } + } + } +}; + +template class CrossEntropyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu new file mode 100644 index 0000000000000000000000000000000000000000..d14a75a30c01deb86937a3ced43005aed4066d86 --- /dev/null +++ b/paddle/operators/math/cross_entropy.cu @@ -0,0 +1,111 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/math/cross_entropy.h" + +namespace paddle { +namespace operators { +namespace math { + +namespace { +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -math::TolerableValue()(log(X[i * D + label[i]])); + } +} + +template +__device__ __forceinline__ T sum_single_warp(T val) { + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + return val; +} + +template +__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, + const int class_num) { + int tid = threadIdx.x; + extern __shared__ T d_sum[]; + d_sum[tid] = 0; + + int cur_idx = tid; + int next_idx = blockIdx.x * class_num + tid; + while (cur_idx < class_num) { + d_sum[tid] += + math::TolerableValue()(std::log(X[next_idx])) * label[next_idx]; + next_idx += blockDim.x; + cur_idx += blockDim.x; + } + __syncthreads(); + + for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { + if (tid < stride) d_sum[tid] += d_sum[tid + stride]; + __syncthreads(); + } + + T val = d_sum[tid]; + val = sum_single_warp(val); + if (tid == 0) Y[blockIdx.x] = -val; +} +} // namespace + +using Tensor = framework::Tensor; + +template +class CrossEntropyFunctor { + public: + void operator()(const framework::ExecutionContext& ctx, + framework::Tensor* out, const framework::Tensor* prob, + const framework::Tensor* labels, bool softLabel) { + const T* prob_data = prob->data(); + T* loss_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = prob->dims()[0]; + int class_num = prob->dims()[1]; + + if (softLabel) { + const T* label_data = labels->data(); + int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); + + SoftCrossEntropyKernel< + T><<( + ctx.device_context()) + .stream()>>>(loss_data, prob_data, label_data, class_num); + } else { + const int* label_data = labels->data(); + int block = 512; + int grid = (batch_size + block - 1) / block; + CrossEntropyKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(loss_data, prob_data, label_data, + batch_size, class_num); + } + } +}; + +template class CrossEntropyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cross_entropy.h b/paddle/operators/math/cross_entropy.h new file mode 100644 index 0000000000000000000000000000000000000000..18e637cf9186b5dc21e94f1ab15b3d858ec93c67 --- /dev/null +++ b/paddle/operators/math/cross_entropy.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; + } +}; + +template +class CrossEntropyFunctor { + public: + // (TODO caoying) it is much better to use DeviceContext as the first + // parameter. + void operator()(const framework::ExecutionContext& context, + framework::Tensor* out, const framework::Tensor* prob, + const framework::Tensor* labels, const bool softLabel); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc new file mode 100644 index 0000000000000000000000000000000000000000..1224c05810543226773b02b5aa1b26cea15d9f54 --- /dev/null +++ b/paddle/operators/math/softmax.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { +namespace math { + +template class SoftmaxFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu new file mode 100644 index 0000000000000000000000000000000000000000..4c3df0550e7ca6f4310db1d35cc34d5c73a2dd16 --- /dev/null +++ b/paddle/operators/math/softmax.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { +namespace math { + +template class SoftmaxFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2f0d0aecffcd0fe51166c3d863aa8b91bba196 --- /dev/null +++ b/paddle/operators/math/softmax.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenMatrix = framework::EigenMatrix; + +template +struct ValueClip { + HOSTDEVICE T operator()(const T& x) const { + const T kThreshold = -64.; + return x < kThreshold ? kThreshold : x; + } +}; + +template +class SoftmaxFunctor { + public: + void operator()(const framework::ExecutionContext& context, + const framework::Tensor* X, framework::Tensor* Y) { + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto shifted_logits = (logits - + logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)) + .unaryExpr(ValueClip()); + + softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(context.GetEigenDevice()) = + (softmax * + softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + } +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 8a3a5ab927c0e2937936fcc973f000d4d95c3dbc..7220f486be055e1b841a06b15f519717c54f575c 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/softmax.h" namespace paddle { namespace operators { @@ -30,36 +31,11 @@ class SoftmaxKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto X = context.Input("X"); auto Y = context.Output("Y"); - Y->mutable_data(context.GetPlace()); - - auto logits = EigenMatrix::From(*X); - auto softmax = EigenMatrix::From(*Y); - - const int kBatchDim = 0; - const int kClassDim = 1; - const int batch_size = logits.dimension(kBatchDim); - const int num_classes = logits.dimension(kClassDim); - - Eigen::DSizes along_class(kClassDim); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, num_classes); - - auto shifted_logits = (logits - - logits.maximum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); - - softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); + // allocate memory on device. + Y->mutable_data(context.GetPlace()); - softmax.device(context.GetEigenDevice()) = - (softmax * - softmax.sum(along_class) - .inverse() - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); + math::SoftmaxFunctor()(context, X, Y); } }; @@ -67,8 +43,6 @@ template class SoftmaxGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - std::shared_ptr scale_ = std::make_shared(); - auto Y = context.Input("Y"); auto dY = context.Input(framework::GradVarName("Y")); auto dX = context.Output(framework::GradVarName("X")); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b6f33ad9e030bc69184de2fb5b88c9dcb07e71a9 --- /dev/null +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -0,0 +1,169 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/softmax_with_cross_entropy_op.h" + +namespace paddle { +namespace operators { + +class SoftmaxWithCrossEntropyOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Logits", + "(Tensor, default: Tensor), The unscaled log probabilities " + "which is a 2-D tensor with shape [N x K]. N is the batch_size, " + "and K is the class number.") + .NotInGradient(); + AddInput( + "Label", + "(Tensor, default: Tensor), The ground truth which is a 2-D " + "tensor. " + "If softLable is set to 0, Label is a Tensor with shape [N x 1]. " + "If softLable is set to 1, Label is a Tensor " + "with shape [N x K]."); + AddOutput( + "Softmax", + "(Tensor, default: Tensor), A 2-D tensor with shape [N x K]. " + "The outputs value of softmax activation by given the input batch, " + "which will be used in backward calculation.") + .AsIntermediate(); + AddOutput("Loss", + "(Tensor, default: Tensor), A 2-D tensor. The cross " + "entropy loss with shape [N x 1]."); + AddAttr( + "softLabel", + "(bool, default: false), A flag to indicate whether to interpretate " + "the given labels as soft labels.") + .SetDefault(false); + AddComment(R"DOC( +Cross entropy loss with softmax are used as the output layer extensively. This +operator computes the softmax normalized values for each row of the input +tensor, after which cross-entropy loss is then computed. This provides a more +numerically stable gradient. + +Because this operators performs a softmax on logits internally, it expects +unscaled logits. Please do not call this op with the output of softmax operator, +which will produce incorrect results. + +This operators expects mutually exclusive hard labels, each sample in a batch +is in exactly one class with probabilities 1. Each sample in the batch with one +and only one label. + +Equation: + +1) hard label (one-hot label) + +Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K + +2) soft label (a distribution over all classes) + +Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K + +)DOC"); + } +}; + +class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"), + "Input(Logits) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) should be not null."); + + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"), + "Output(Softmax) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"), + "Output(Loss) should be not null."); + + const Tensor* logits = ctx.Input("Logits"); + const Tensor* labels = ctx.Input("Label"); + PADDLE_ENFORCE_EQ( + logits->dims().size(), 2UL, + "The input of softmax_with_cross_entropy should be a 2-D tensor."); + PADDLE_ENFORCE_EQ(ctx.Input("Label")->dims().size(), 2UL, + "The labels should be a 2-D tensor."); + + if (ctx.Attr("softLabel")) { + PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1], + "If Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); + } else { + PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL, + "If Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); + } + + ctx.Output("Softmax")->Resize(logits->dims()); + ctx.Output("Loss")->Resize({logits->dims()[0], 1}); + + ctx.ShareLoD("Logits", /*->*/ "Softmax"); + ctx.ShareLoD("Logits", /*->*/ "Loss"); + } +}; + +class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")), + "Input(Loss@Grad) should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"), + "Input(Softmax) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")), + "Output(Logits@Grad) should be not null."); + + const Tensor* softmax = ctx.Input("Softmax"); + const Tensor* labels = ctx.Input("Label"); + PADDLE_ENFORCE_EQ(ctx.Input("Label")->dims().size(), 2UL, + "The labels should be a 2-D tensor."); + + if (ctx.Attr("softLabel")) { + PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1], + "When Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); + } else { + PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL, + "When Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); + } + + ctx.Output(framework::GradVarName("Logits")) + ->Resize(ctx.Input("Softmax")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, + ops::SoftmaxWithCrossEntropyOpMaker, + softmax_with_cross_entropy_grad, + ops::SoftmaxWithCrossEntropyOpGrad); +REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, + ops::SoftmaxWithCrossEntropyKernel); +REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, + ops::SoftmaxWithCrossEntropyGradKernel); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..1cf4296dccf68aece6fdfb7910a9c68449633b76 --- /dev/null +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/softmax_with_cross_entropy_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +namespace { +template +__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad, + const int* labels, const int batch_size, + const int class_num) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int sample_idx = tid / class_num; + + if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx]; + __syncthreads(); + + if (tid < batch_size) { + PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num); + out_grad[tid * class_num + labels[tid]] -= 1.; + } +} + +template +__global__ void SoftCrossEntropyGradientKernel(T* logit_grad, + const T* loss_grad, + const T* labels, + const int batch_size, + const int class_num) { + int ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < batch_size * class_num) { + int row_ids = ids / class_num; + logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids]; + } +} +} // namespace + +template +class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), + "This kernel only runs on GPU device."); + const Tensor* logits = context.Input("Logits"); + const Tensor* labels = context.Input("Label"); + Tensor* softmax = context.Output("Softmax"); + + Tensor* loss = context.Output("Loss"); + softmax->mutable_data(context.GetPlace()); + loss->mutable_data(context.GetPlace()); + + math::SoftmaxFunctor()(context, logits, softmax); + math::CrossEntropyFunctor()( + context, loss, softmax, labels, context.Attr("softLabel")); + } +}; + +template +class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), + "This kernel only runs on GPU device."); + const Tensor* labels = context.Input("Label"); + const T* loss_grad_data = + context.Input(framework::GradVarName("Loss"))->data(); + Tensor* logit_grad = + context.Output(framework::GradVarName("Logits")); + logit_grad->ShareDataWith(*context.Input("Softmax")); + T* logit_grad_data = logit_grad->data(); + + const int batch_size = logit_grad->dims()[0]; + const int class_num = logit_grad->dims()[1]; + int block = 512; + int grid = (batch_size * class_num + block - 1) / block; + + if (context.Attr("softLabel")) { + const T* label_data = labels->data(); + SoftCrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(logit_grad_data, loss_grad_data, + label_data, batch_size, class_num); + } else { + const int* label_data = labels->data(); + CrossEntropyGrad<<< + grid, block, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(logit_grad_data, loss_grad_data, + label_data, batch_size, class_num); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy, + ops::SoftmaxWithCrossEntropyCUDAKernel); +REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad, + ops::SoftmaxWithCrossEntropyGradCUDAKernel); diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bf792c1f59e2e43a98c93bddbc2aa63d646dee6f --- /dev/null +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/cross_entropy.h" +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), + "This kernel only runs on CPU."); + const Tensor* logits = context.Input("Logits"); + const Tensor* labels = context.Input("Label"); + Tensor* softmax = context.Output("Softmax"); + Tensor* loss = context.Output("Loss"); + + softmax->mutable_data(context.GetPlace()); + loss->mutable_data(context.GetPlace()); + + math::SoftmaxFunctor()(context, logits, softmax); + math::CrossEntropyFunctor()( + context, loss, softmax, labels, context.Attr("softLabel")); + } +}; + +template +class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* out_grad = + context.Input(framework::GradVarName("Loss")); + const Tensor* labels = context.Input("Label"); + Tensor* logit_grad = + context.Output(framework::GradVarName("Logits")); + logit_grad->ShareDataWith(*context.Input("Softmax")); + + const int class_num = logit_grad->dims()[1]; + if (context.Attr("softLabel")) { + auto out_grad_mat = EigenMatrix::From(*out_grad); + auto logit_grad_mat = EigenMatrix::From(*logit_grad); + auto lbl_mat = EigenMatrix::From(*labels); + + logit_grad_mat.device(context.GetEigenDevice()) = + logit_grad_mat * + out_grad_mat.broadcast(Eigen::DSizes(1, class_num)) - + lbl_mat; + } else { + const int batch_size = logit_grad->dims()[0]; + const int* label_data = labels->data(); + const T* out_grad_data = out_grad->data(); + T* logit_grad_data = logit_grad->data(); + + for (int i = 0; i < batch_size; ++i) { + int index = i * class_num + label_data[i]; + logit_grad_data[index] = + (out_grad_data[i] * logit_grad_data[index] - 1.); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 0a5673868c547d9e184e8ce05346c3ebabe06892..579ad7b40738f45bf055f740e66d2238f4db22fc 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -177,7 +177,7 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, class OpTest(unittest.TestCase): - def check_output_with_place(self, place): + def check_output_with_place(self, place, atol): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict() @@ -206,22 +206,23 @@ class OpTest(unittest.TestCase): self.scope.find_var(sub_out_name).get_tensor()) self.assertTrue( np.allclose( - actual, expect, atol=1e-05), - "output name: " + out_name + " has diff") + actual, expect, atol=atol), + "output name: " + out_name + " has diff.") else: actual = np.array(self.scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] + self.assertTrue( np.allclose( - actual, expect, atol=1e-05), - "output name: " + out_name + " has diff") + actual, expect, atol=atol), + "output name: " + out_name + " has diff.") - def check_output(self): + def check_output(self, atol=1e-5): places = [core.CPUPlace()] if core.is_compile_gpu(): places.append(core.GPUPlace(0)) for place in places: - self.check_output_with_place(place) + self.check_output_with_place(place, atol) def __assert_is_close(self, numeric_grads, analytic_grads, names, max_relative_error, msg_prefix): @@ -235,9 +236,10 @@ class OpTest(unittest.TestCase): def err_msg(): offset = np.argmax(diff_mat > max_relative_error) - return "%s Variable %s max gradient diff %f over limit %f, the first " \ - "error element is %d" % ( - msg_prefix, name, max_diff, max_relative_error, offset) + return ("%s Variable %s max gradient diff %f over limit %f, " + "the first error element is %d") % ( + msg_prefix, name, max_diff, max_relative_error, + offset) self.assertLessEqual(max_diff, max_relative_error, err_msg()) diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py index 1b948f252fa631e9886840b377de2996e110dc91..b41c810d9a6269c934a434b085748a86deccb475 100644 --- a/python/paddle/v2/framework/tests/test_softmax_op.py +++ b/python/paddle/v2/framework/tests/test_softmax_op.py @@ -5,7 +5,7 @@ from op_test import OpTest def stable_softmax(x): """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - np.max(x) + shiftx = x - np.max(x).clip(-64.) exps = np.exp(shiftx) return exps / np.sum(exps) diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py new file mode 100644 index 0000000000000000000000000000000000000000..428395b76c8fbcbc07b19ee1979419f0e64aca85 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py @@ -0,0 +1,70 @@ +import unittest +import numpy as np + +from op_test import OpTest +from test_softmax_op import stable_softmax + + +class TestSoftmaxWithCrossEntropyOp(OpTest): + """ + Test softmax with cross entropy operator with discreate one-hot labels. + """ + + def setUp(self): + self.op_type = "softmax_with_cross_entropy" + batch_size = 3 + class_num = 37 + + logits = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int32") + + cross_entropy = np.asmatrix( + [[-np.log(softmax[i][labels[i][0]])] + for i in range(softmax.shape[0])], + dtype="float32") + + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = {"Softmax": softmax, "Loss": cross_entropy} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["Logits"], "Loss", max_relative_error=0.05) + + +class TestSoftmaxWithCrossEntropyOp2(OpTest): + """ + Test softmax with cross entropy operator with soft labels. + """ + + def setUp(self): + self.op_type = "softmax_with_cross_entropy" + batch_size = 2 + class_num = 17 + + logits = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + labels = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + labels /= np.sum(labels, axis=1, keepdims=True) + + cross_entropy = (-labels * np.log(softmax)).sum( + axis=1, keepdims=True).astype("float32") + + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = {"Softmax": softmax, "Loss": cross_entropy} + self.attrs = {"softLabel": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["Logits"], "Loss", max_relative_error=0.05) + + +if __name__ == "__main__": + unittest.main()