From 8f8ea005fecd911e913ff728ed37ecb990dfbbca Mon Sep 17 00:00:00 2001
From: caoying03
Date: Fri, 15 Sep 2017 14:51:04 +0800
Subject: [PATCH] fix implementations.

---
 paddle/operators/math/utils.h                 | 42 ++++++++
 paddle/operators/onehot_cross_entropy_op.cu   | 20 +---
 .../softmax_with_cross_entropy_op.cc          | 12 +--
 .../softmax_with_cross_entropy_op.cu          | 97 ++++++++++++++++++-
 .../operators/softmax_with_cross_entropy_op.h |  7 +-
 .../framework/tests/test_cross_entropy_op.py  |  1 -
 .../test_softmax_with_cross_entropy_op.py     |  7 +-
 7 files changed, 151 insertions(+), 35 deletions(-)
 create mode 100644 paddle/operators/math/utils.h

diff --git a/paddle/operators/math/utils.h b/paddle/operators/math/utils.h
new file mode 100644
index 0000000000..1e72c8e0ca
--- /dev/null
+++ b/paddle/operators/math/utils.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/platform/assert.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+T HOSTDEVICE tolerable_value(const T x) {
+  PADDLE_ASSERT(std::is_floating_point<T>::value);
+
+  const T kApproInf = 1e20;
+
+  if (x == INFINITY) {
+    return kApproInf;
+  }
+
+  if (x == -INFINITY) {
+    return -kApproInf;
+  }
+
+  return x;
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/onehot_cross_entropy_op.cu b/paddle/operators/onehot_cross_entropy_op.cu
index d999bfce58..f8ed9680e7 100644
--- a/paddle/operators/onehot_cross_entropy_op.cu
+++ b/paddle/operators/onehot_cross_entropy_op.cu
@@ -13,6 +13,7 @@ limitations under the License.
 */
 
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/utils.h"
 #include "paddle/platform/assert.h"
 
 namespace paddle {
@@ -20,20 +21,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename T>
-__host__ __device__ T clipping_log(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
-  const T kApproInf = 1e20;
-  T v = log(x);
-  if (v == INFINITY) {
-    return kApproInf;
-  }
-  if (v == -INFINITY) {
-    return -kApproInf;
-  }
-  return v;
-}
-
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                    const int N, const int D) {
@@ -42,7 +29,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -clipping_log(X[i * D + label[i]]);
+    Y[i] = -math::tolerable_value(log(X[i * D + label[i]]));
   }
 }
 
@@ -73,7 +60,7 @@ class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
 
     auto X = ctx.Input<Tensor>("X");
     const T* Xdata = X->data<T>();
@@ -86,6 +73,7 @@ class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
     int D = X->dims()[1];
     int block = 512;
     int grid = (N + block - 1) / block;
+    // TODO(qingqing) launch kernel on specified stream
     // base on ExecutionContext.
     CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
 
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index fd75494ff8..a0941bb624 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -32,7 +32,7 @@ class SoftmaxWithCrossEntropyOpMaker
               "Store the outputs of softmax function, "
               "which will be used in backward calculation.")
         .AsIntermediate();
-    AddOutput("Loss", "A 1-D tensor with shape N.");
+    AddOutput("Out", "A 1-D tensor with shape N.");
     AddComment(R"DOC(
 Cross entropy loss with softmax are used as the output layer extensively.
 This operator computes the softmax normalized values for each row of the input
@@ -56,14 +56,14 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
-                            "Input(Loss@Grad) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@Grad) should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
                             "Input(Softmax) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                             "Input(Lable) should be not null.");
 
-    ctx.Output(framework::GradVarName("Logits"))
+    ctx.Output(framework::GradVarName("Logits"))
         ->Resize(ctx.Input<Tensor>("Softmax")->dims());
   }
 };
@@ -81,8 +81,8 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
                    "The label should be a 1-d tensor.");
 
-    ctx.Output("Softmax")->Resize(logits->dims());
-    ctx.Output("Loss")->Resize({logits->dims()[0], 1});
+    ctx.Output("Softmax")->Resize(logits->dims());
+    ctx.Output("Out")->Resize({logits->dims()[0], 1});
   }
 };
 
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu
index 922bb19d4d..5af6a521a8 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/operators/softmax_with_cross_entropy_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
@@ -13,8 +13,97 @@ limitations under the License. */
 
 #define EIGEN_USE_GPU
 
-#include "softmax_with_cross_entropy_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/softmax_function.h"
+#include "paddle/operators/math/utils.h"
 
-namespace ops = paddle::operators;
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void CrossEntropyKernel(T* out, const T* softmax_out,
+                                   const int* label, const int batch_size,
+                                   const int class_num) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= batch_size) return;
+  PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
+  out[i] = -math::tolerable_value(log(softmax_out[i * class_num + label[i]]));
+}
+
+template <typename T>
+__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
+                                                  const int* label,
+                                                  const int batch_size,
+                                                  const int class_num) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= batch_size) return;
+
+  PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
+  softmax_out[i * class_num + label[i]] -= 1.;
+}
+
+template <typename T>
+class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+
+    // Calculate ths softmax outputs.
+    const Tensor* logits = context.Input<Tensor>("Logits");
+    Tensor* softmax = context.Output<Tensor>("Softmax");
+    softmax->mutable_data<T>(context.GetPlace());
+    math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context);
+    T* softmax_out = softmax->data<T>();
+
+    // Calculate the cross entropy loss based on hard labels.
+    const int* label_data = context.Input<Tensor>("Label")->data<int>();
+    Tensor* loss = context.Output<Tensor>("Out");
+    loss->mutable_data<T>(context.GetPlace());
+    T* loss_data = loss->data<T>();
+
+    const int batch_size = logits->dims()[0];
+    const int class_num = logits->dims()[1];
+    int block = 512;
+    int grid = (batch_size + block - 1) / block;
 
-// TODO(caoying) add GPU kernel
+    CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data,
+                                           batch_size, class_num);
+  }
+};
+
+template <typename T>
+class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+
+    Tensor* logit_grad =
+        context.Output<Tensor>(framework::GradVarName("Logits"));
+    logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
+    T* logit_grad_data = logit_grad->data<T>();
+
+    const int batch_size = logit_grad->dims()[0];
+    const int class_num = logit_grad->dims()[1];
+
+    const int* label_data = context.Input<Tensor>("Label")->data<int>();
+
+    const int block = 512;
+    const int grid = (batch_size + block - 1) / block;
+
+    CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>(
+        logit_grad_data, label_data, batch_size, class_num);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy,
+                       ops::SoftmaxWithCrossEntropyCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad,
+                       ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>);
diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h
index e147cdb815..38b92a0bcd 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.h
+++ b/paddle/operators/softmax_with_cross_entropy_op.h
@@ -30,8 +30,7 @@ template <typename T>
 class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto place = context.GetPlace();
-    PADDLE_ENFORCE(platform::is_cpu_place(place),
+    PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
                    "This kernel only runs on CPU.");
 
     // Calculate ths softmax outputs.
@@ -45,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
     T* softmax_out = softmax->data<T>();
 
     const int* label_data = context.Input<Tensor>("Label")->data<int>();
-    Tensor* loss = context.Output<Tensor>("Loss");
+    Tensor* loss = context.Output<Tensor>("Out");
     loss->mutable_data<T>(context.GetPlace());
     T* loss_data = loss->data<T>();
 
@@ -74,7 +73,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
     const int* label_data = context.Input<Tensor>("Label")->data<int>();
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
-      logit_grad_data[index] -= .1;
+      logit_grad_data[index] -= 1.;
     }
   }
 };
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 5e06525d61..253e7b8a24 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,7 +1,6 @@
 import unittest
 import numpy
 from op_test import OpTest
-import pdb
 
 
 class TestCrossEntropy(OpTest):
diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
index 4e35c063b9..e965dd0482 100644
--- a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
@@ -1,6 +1,5 @@
 import unittest
 import numpy as np
-import pdb
 from op_test import OpTest
 from test_softmax_op import stable_softmax
 
@@ -11,7 +10,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.op_type = "softmax_with_cross_entropy"
 
         MAX_BATCH_SIZE = 23
-        MAX_CLASS_NUM = 10
+        MAX_CLASS_NUM = 17
 
         batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0]
         class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0]
@@ -26,13 +25,13 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
             dtype="float32")
 
         self.inputs = {"Logits": logits, "Label": labels}
-        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
+        self.outputs = {"Softmax": softmax, "Out": cross_entropy}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(["Logits"], "Loss")
+        self.check_grad(["Logits"], "Out", max_relative_error=0.05)
 
 
 if __name__ == "__main__":
-- 
GitLab
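
For reference, separate from the patch itself: a minimal NumPy sketch of the values the renamed outputs are expected to hold, mirroring what the updated test builds with stable_softmax. "Softmax" is the row-wise softmax of "Logits", and "Out" is the negative log of the softmax probability at each row's hard label, shaped (batch_size, 1) as in the operator's Resize call. The function and variable names below are illustrative only and do not appear in the patch.

    import numpy as np

    def softmax_with_cross_entropy_reference(logits, labels):
        """Row-wise stable softmax plus hard-label cross entropy (illustrative sketch)."""
        # Accept labels of shape (N,) or (N, 1); indexing below needs a flat int array.
        labels = np.asarray(labels).reshape(-1).astype("int64")
        # Subtract each row's max before exponentiating, as stable_softmax does.
        shifted = logits - logits.max(axis=1, keepdims=True)
        exps = np.exp(shifted)
        softmax = exps / exps.sum(axis=1, keepdims=True)
        # "Out" holds -log of the probability picked at each row's true class,
        # reshaped to (batch_size, 1) to match the operator's output shape.
        batch_size = logits.shape[0]
        cross_entropy = -np.log(softmax[np.arange(batch_size), labels])
        return softmax, cross_entropy.reshape(batch_size, 1).astype("float32")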