提交 97509b68 编写于 作者: C caoying03

Implement cross entropy as a functor to avoid duplicated code.

上级 40aee48a
......@@ -88,10 +88,14 @@ add_subdirectory(math)
set(DEPS_OPS
recurrent_op
cond_op)
cond_op
cross_entropy_op
softmax_with_cross_entropy_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy_function)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -12,62 +12,12 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
// Hard-label cross entropy. For every row index r in [0, N) this writes
//   Y[r] = -log(X[r * D + label[r]])
// where an infinite logarithm is replaced by a large finite value through
// TolerableValue. Each thread starts at its global 1-D index and advances by
// the total number of launched threads until it passes N.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(row, N) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  while (row < N) {
    // The label selects a column of row `row`, so it must lie in [0, D).
    PADDLE_ASSERT(label[row] >= 0 && label[row] < D);
    const T picked = X[row * D + label[row]];
    Y[row] = -TolerableValue<T>()(log(picked));
    row += blockDim.x * gridDim.x;
  }
}
// Adds up `val` across the lanes of one warp with shuffle-down steps of
// 16, 8, 4, 2 and 1. After the call lane 0 holds the sum over all 32 lanes;
// the other lanes hold partial sums only.
//
// Preconditions: must be executed by all 32 lanes of the warp (the full
// participation mask is used). Warp shuffles require compute capability 3.0+.
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
  // The mask-less __shfl_down is deprecated since CUDA 9 and unavailable on
  // Volta and newer, where warp lanes are no longer implicitly synchronous.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffffu, val, offset);
  }
#else
  // Toolkits older than CUDA 9 only provide the legacy intrinsic.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset);
  }
#endif
  return val;
}
// Soft-label cross entropy, one thread block per row:
//   Y[b] = -sum_j label[b * class_num + j] * log(X[b * class_num + j])
// with b = blockIdx.x and infinite logarithms clamped by TolerableValue.
//
// Launch requirements:
//   * blockDim.x is a power of two (the halving reduction below drops
//     elements otherwise);
//   * blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  int tid = threadIdx.x;
  extern __shared__ T d_sum[];
  d_sum[tid] = 0;

  // Each thread accumulates the columns tid, tid + blockDim.x, ... of its row.
  int cur_idx = tid;
  int next_idx = blockIdx.x * class_num + tid;
  while (cur_idx < class_num) {
    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
    next_idx += blockDim.x;
    cur_idx += blockDim.x;
  }
  __syncthreads();

  // Halve the number of partial sums in shared memory until one warp's worth
  // (32 values) is left.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
    __syncthreads();
  }

  // blockDim.x is uniform across the block, so this branch never diverges.
  if (blockDim.x >= 32) {
    // Every warp is full, so the shuffle reduction is well defined.
    T val = d_sum[tid];
    val = sum_single_warp<T>(val);
    if (tid == 0) Y[blockIdx.x] = -val;
  } else {
    // Fewer than 32 threads were launched (class_num < 32). A warp shuffle
    // would read lanes that do not exist, which yields undefined values, so
    // finish the reduction in shared memory instead.
    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
      if (tid < stride) d_sum[tid] += d_sum[tid + stride];
      __syncthreads();
    }
    if (tid == 0) Y[blockIdx.x] = -d_sum[0];
  }
}
namespace {
// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void Zero(T* X, const int N) {
......@@ -100,6 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
dX[ids] = -label[ids] * dY[row_ids] / X[ids];
}
}
} // namespace
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
......@@ -107,36 +58,13 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* label = ctx.Input<Tensor>("Label");
Tensor* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();
T* y_data = y->mutable_data<T>(ctx.GetPlace());
int batch_size = x->dims()[0];
int class_num = x->dims()[1];
if (ctx.Attr<bool>("softLabel")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel<
T><<<batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data, class_num);
} else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
int block = 512;
int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data,
batch_size, class_num);
}
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
ctx, y, x, label, ctx.Attr<bool>("softLabel"));
}
};
......@@ -150,6 +78,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* label = ctx.Input<Tensor>("Label");
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
const T* dy_data =
ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/operators/math/cross_entropy.h"
namespace paddle {
namespace operators {
......@@ -25,18 +25,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// Functor that makes infinite values tolerable: +infinity maps to 1e20 and
// -infinity to -1e20. Every other input, NaN included, is returned unchanged.
// Only floating-point T is supported.
template <typename T>
struct TolerableValue {
  HOSTDEVICE T operator()(const T& x) const {
    // T is fixed at compile time, so an unsupported type is rejected when the
    // functor is instantiated instead of being asserted on every call.
    static_assert(std::is_floating_point<T>::value,
                  "TolerableValue only supports floating-point types.");
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};
template <typename T>
class CrossEntropyOpKernel : public framework::OpKernel {
public:
......@@ -46,28 +34,10 @@ class CrossEntropyOpKernel : public framework::OpKernel {
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* y = ctx.Output<Tensor>("Y");
T* y_data = y->mutable_data<T>(ctx.GetPlace());
const int batch_size = x->dims()[0];
if (ctx.Attr<bool>("softLabel")) {
auto prob = EigenMatrix<T>::From(*x);
auto lbl_mat = EigenMatrix<T>::From(*labels);
auto loss = EigenMatrix<T>::From(*y);
y->mutable_data<T>(ctx.GetPlace());
loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
} else {
const int class_num = x->dims()[1];
const T* x_data = x->data<T>();
const int* label_data = labels->data<int>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
}
}
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
ctx, y, x, labels, ctx.Attr<bool>("softLabel"));
}
};
......
......@@ -3,10 +3,13 @@ if(WITH_GPU)
im2col.cu DEPS cblas device_context operator)
nv_library(softmax_function SRCS softmax.cc softmax.cu
DEPS operator)
nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
DEPS operator)
else()
cc_library(math_function SRCS math_function.cc im2col.cc
DEPS cblas device_context operator)
cc_library(softmax_function SRCS softmax.cc DEPS operator)
cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cross_entropy.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// CPU implementation of the cross entropy functor.
//
//   out    - receives one loss value per row of `prob`. It is read through
//            out->data<T>(), so presumably the caller has already allocated
//            it (TODO confirm against callers).
//   prob   - 2-D tensor; dims()[0] is the batch size, dims()[1] the number
//            of classes.
//   labels - when softLabel is true, a tensor of T with one weight per class
//            and row; otherwise one int class index per row.
template <typename T>
class CrossEntropyFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const framework::ExecutionContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, const bool softLabel) {
    const int batch_size = prob->dims()[0];
    if (softLabel) {
      // loss[i] = -sum_j lbl[i, j] * log(in[i, j]), evaluated with Eigen.
      auto in = EigenMatrix<T>::From(*prob);
      auto lbl = EigenMatrix<T>::From(*labels);
      auto loss = EigenMatrix<T>::From(*out);

      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
                .sum(Eigen::DSizes<int, 1>(1))
                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
    } else {
      const int class_num = prob->dims()[1];
      const T* prob_data = prob->data<T>();
      T* loss_data = out->data<T>();

      const int* label_data = labels->data<int>();
      for (int i = 0; i < batch_size; ++i) {
        const int lbl = label_data[i];
        // Same range check as the GPU kernel: a label outside [0, class_num)
        // would index past row i of prob_data.
        PADDLE_ASSERT(lbl >= 0 && lbl < class_num);
        int index = i * class_num + lbl;
        loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
      }
    }
  }
};
template class CrossEntropyFunctor<platform::CPUPlace, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cross_entropy.h"
namespace paddle {
namespace operators {
namespace math {
namespace {
// Hard-label cross entropy. For every row index r in [0, N) this writes
//   Y[r] = -log(X[r * D + label[r]])
// where an infinite logarithm is replaced by a large finite value through
// math::TolerableValue. Each thread starts at its global 1-D index and
// advances by the total number of launched threads until it passes N.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(row, N) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  while (row < N) {
    // The label selects a column of row `row`, so it must lie in [0, D).
    PADDLE_ASSERT(label[row] >= 0 && label[row] < D);
    const T picked = X[row * D + label[row]];
    Y[row] = -math::TolerableValue<T>()(log(picked));
    row += blockDim.x * gridDim.x;
  }
}
// Adds up `val` across the lanes of one warp with shuffle-down steps of
// 16, 8, 4, 2 and 1. After the call lane 0 holds the sum over all 32 lanes;
// the other lanes hold partial sums only.
//
// Preconditions: must be executed by all 32 lanes of the warp (the full
// participation mask is used). Warp shuffles require compute capability 3.0+.
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
  // The mask-less __shfl_down is deprecated since CUDA 9 and unavailable on
  // Volta and newer, where warp lanes are no longer implicitly synchronous.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffffu, val, offset);
  }
#else
  // Toolkits older than CUDA 9 only provide the legacy intrinsic.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset);
  }
#endif
  return val;
}
// Soft-label cross entropy, one thread block per row:
//   Y[b] = -sum_j label[b * class_num + j] * log(X[b * class_num + j])
// with b = blockIdx.x and infinite logarithms clamped by math::TolerableValue.
//
// Launch requirements:
//   * blockDim.x is a power of two (the halving reduction below drops
//     elements otherwise);
//   * blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  int tid = threadIdx.x;
  extern __shared__ T d_sum[];
  d_sum[tid] = 0;

  // Each thread accumulates the columns tid, tid + blockDim.x, ... of its row.
  int cur_idx = tid;
  int next_idx = blockIdx.x * class_num + tid;
  while (cur_idx < class_num) {
    d_sum[tid] +=
        math::TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
    next_idx += blockDim.x;
    cur_idx += blockDim.x;
  }
  __syncthreads();

  // Halve the number of partial sums in shared memory until one warp's worth
  // (32 values) is left.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
    __syncthreads();
  }

  // blockDim.x is uniform across the block, so this branch never diverges.
  if (blockDim.x >= 32) {
    // Every warp is full, so the shuffle reduction is well defined.
    T val = d_sum[tid];
    val = sum_single_warp<T>(val);
    if (tid == 0) Y[blockIdx.x] = -val;
  } else {
    // Fewer than 32 threads were launched (class_num < 32). A warp shuffle
    // would read lanes that do not exist, which yields undefined values, so
    // finish the reduction in shared memory instead.
    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
      if (tid < stride) d_sum[tid] += d_sum[tid + stride];
      __syncthreads();
    }
    if (tid == 0) Y[blockIdx.x] = -d_sum[0];
  }
}
} // namespace
using Tensor = framework::Tensor;
// GPU implementation of the cross entropy functor.
//
//   out    - receives one loss value per row of `prob`; its buffer is
//            obtained here through mutable_data on ctx.GetPlace().
//   prob   - 2-D tensor; dims()[0] is the batch size, dims()[1] the number
//            of classes.
//   labels - when softLabel is true, a tensor of T with one weight per class
//            and row; otherwise one int class index per row.
//
// Kernels are enqueued on the stream of ctx.device_context(); nothing here
// waits for them to finish.
template <typename T>
class CrossEntropyFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const framework::ExecutionContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, bool softLabel) {
    const T* prob_data = prob->data<T>();
    T* loss_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = prob->dims()[0];
    int class_num = prob->dims()[1];

    // An empty batch has no rows to process, and a grid of zero blocks is
    // not a valid launch configuration.
    if (batch_size <= 0) return;

    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    if (softLabel) {
      const T* label_data = labels->data<T>();
      // One block per row. The block size is the largest power of two that
      // does not exceed class_num, capped at 512 and never below 1. Integer
      // arithmetic avoids the float round trip through pow/log2, which is
      // undefined for class_num <= 0.
      const int kMaxThreads = 512;
      int block = 1;
      while (block < kMaxThreads && (block << 1) <= class_num) block <<= 1;

      SoftCrossEntropyKernel<T><<<batch_size, block, block * sizeof(T),
                                  stream>>>(loss_data, prob_data, label_data,
                                            class_num);
    } else {
      const int* label_data = labels->data<int>();
      int block = 512;
      // Enough blocks for one thread per row (ceiling division).
      int grid = (batch_size + block - 1) / block;
      CrossEntropyKernel<T><<<grid, block, 0, stream>>>(
          loss_data, prob_data, label_data, batch_size, class_num);
    }
  }
};
template class CrossEntropyFunctor<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
// Functor that makes infinite values tolerable: +infinity maps to 1e20 and
// -infinity to -1e20. Every other input, NaN included, is returned unchanged.
// Only floating-point T is supported.
template <typename T>
struct TolerableValue {
  HOSTDEVICE T operator()(const T& x) const {
    // T is fixed at compile time, so an unsupported type is rejected when the
    // functor is instantiated instead of being asserted on every call.
    static_assert(std::is_floating_point<T>::value,
                  "TolerableValue only supports floating-point types.");
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};
// Primary template of the cross entropy functor. Only operator() is declared
// here; the implementations are the specializations for platform::CPUPlace
// and platform::GPUPlace.
//
// Template parameters:
//   Place - device tag selecting the implementation.
//   T     - element type of the probability and loss tensors.
template <typename Place, typename T>
class CrossEntropyFunctor {
public:
// (TODO caoying) it is much better to use DeviceContext as the first
// parameter.
//
// Computes the cross entropy of `prob` against `labels` and stores one loss
// value per row in `out`.
//   context   - execution context supplying the device to run on.
//   out       - output tensor holding the loss.
//   prob      - 2-D input tensor; dims()[0] is used as the batch size and
//               dims()[1] as the number of classes.
//   labels    - read as T values when softLabel is true, otherwise as int
//               class indices.
//   softLabel - selects between the soft-label and hard-label formulas.
void operator()(const framework::ExecutionContext& context,
framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -14,26 +14,14 @@
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/operators/math/softmax.h"
#include "paddle/operators/softmax_with_cross_entropy_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Hard-label cross entropy on softmax output, one thread per row:
//   out[r] = -log(softmax_out[r * class_num + labels[r]])
// with an infinite logarithm clamped by TolerableValue. Threads whose global
// index is not below batch_size do nothing.
template <typename T>
__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels,
                             const int batch_size, const int class_num) {
  const int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= batch_size) return;

  // The label selects a column of row `row`, so it must lie in
  // [0, class_num).
  PADDLE_ASSERT(labels[row] >= 0 && labels[row] < class_num);
  const T picked = softmax_out[row * class_num + labels[row]];
  out[row] = -TolerableValue<T>()(std::log(picked));
}
namespace {
template <typename T>
__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
const int* labels, const int batch_size,
......@@ -50,42 +38,6 @@ __global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
}
}
// Adds up `val` across the lanes of one warp with shuffle-down steps of
// 16, 8, 4, 2 and 1. After the call lane 0 holds the sum over all 32 lanes;
// the other lanes hold partial sums only.
//
// Preconditions: must be executed by all 32 lanes of the warp (the full
// participation mask is used). Warp shuffles require compute capability 3.0+.
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
  // The mask-less __shfl_down is deprecated since CUDA 9 and unavailable on
  // Volta and newer, where warp lanes are no longer implicitly synchronous.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffffu, val, offset);
  }
#else
  // Toolkits older than CUDA 9 only provide the legacy intrinsic.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset);
  }
#endif
  return val;
}
// Soft-label cross entropy, one thread block per row:
//   Y[b] = -sum_j label[b * class_num + j] * log(X[b * class_num + j])
// with b = blockIdx.x and infinite logarithms clamped by TolerableValue.
//
// Launch requirements:
//   * blockDim.x is a power of two (the halving reduction below drops
//     elements otherwise);
//   * blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  int tid = threadIdx.x;
  extern __shared__ T d_sum[];
  d_sum[tid] = 0;

  // Each thread accumulates the columns tid, tid + blockDim.x, ... of its row.
  int cur_idx = tid;
  int next_idx = blockIdx.x * class_num + tid;
  while (cur_idx < class_num) {
    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
    next_idx += blockDim.x;
    cur_idx += blockDim.x;
  }
  __syncthreads();

  // Halve the number of partial sums in shared memory until one warp's worth
  // (32 values) is left.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
    __syncthreads();
  }

  // blockDim.x is uniform across the block, so this branch never diverges.
  if (blockDim.x >= 32) {
    // Every warp is full, so the shuffle reduction is well defined.
    T val = d_sum[tid];
    val = sum_single_warp<T>(val);
    if (tid == 0) Y[blockIdx.x] = -val;
  } else {
    // Fewer than 32 threads were launched (class_num < 32). A warp shuffle
    // would read lanes that do not exist, which yields undefined values, so
    // finish the reduction in shared memory instead.
    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
      if (tid < stride) d_sum[tid] += d_sum[tid + stride];
      __syncthreads();
    }
    if (tid == 0) Y[blockIdx.x] = -d_sum[0];
  }
}
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
......@@ -98,6 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
}
}
} // namespace
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
......@@ -105,36 +58,17 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
T* loss_data =
context.Output<Tensor>("Loss")->mutable_data<T>(context.GetPlace());
const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax = context.Output<Tensor>("Softmax");
T* softmax_out = softmax->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
const int batch_size = logits->dims()[0];
const int class_num = logits->dims()[1];
int block = 512;
int grid = (batch_size + block - 1) / block;
Tensor* loss = context.Output<Tensor>("Loss");
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
if (context.Attr<bool>("softLabel")) {
const T* label_data = context.Input<Tensor>("Label")->data<T>();
block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel<
T><<<batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(loss_data, softmax_out, label_data, class_num);
} else {
const int* label_data = context.Input<Tensor>("Label")->data<int>();
CrossEntropy<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(loss_data, softmax_out, label_data,
batch_size, class_num);
}
math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
}
};
......
......@@ -15,7 +15,7 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/operators/math/cross_entropy.h"
#include "paddle/operators/math/softmax.h"
namespace paddle {
......@@ -37,31 +37,12 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
T* softmax_data = softmax->mutable_data<T>(context.GetPlace());
T* loss_data = loss->mutable_data<T>(context.GetPlace());
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
const int batch_size = logits->dims()[0];
if (context.Attr<bool>("softLabel")) {
// (TODO caoying) the forward implementation can be further optimized.
// Current implementation is exactly cross entropy after softmax.
auto prob = EigenMatrix<T>::From(*softmax);
auto lbl_mat = EigenMatrix<T>::From(*labels);
auto loss_mat = EigenMatrix<T>::From(*loss);
loss_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
-((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
} else {
const int* label_data = labels->data<int>();
const int class_num = logits->dims()[1];
for (int i = 0; i < batch_size; ++i)
loss_data[i] = -TolerableValue<T>()(
std::log(softmax_data[i * class_num + label_data[i]]));
}
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册