提交 79def5e6 编写于 作者: Q qijun

refine CrossEntropyFunctor

上级 c634a848
7 合并请求!11636[IMPORTANT] MKLDNN layout: Support for sum operator,!8482Release/0.11.0,!8190Release/0.11.0,!8189Release/0.11.0,!6633给线性回归的get-started代码加上了预测的示例~~,!4615Feature/tensor array add python binding,!4492Refine some functors
......@@ -18,14 +18,6 @@ namespace paddle {
namespace operators {
namespace {
// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
}
}
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
......@@ -99,11 +91,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
.stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num);
} else {
Zero<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, batch_size * class_num);
math::SetConstant<platform::GPUPlace, T>(ctx.device_context(), dx, 0);
auto* label_data = label->data<int>();
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cross_entropy.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......@@ -37,7 +38,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
ctx, y, x, labels, ctx.Attr<bool>("softLabel"));
ctx.device_context(), y, x, labels, ctx.Attr<bool>("softLabel"));
}
};
......@@ -69,8 +70,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
const T* x_data = x->data<T>();
const int* label_data = label->data<int>();
// TODO(qingqing): make zero setting a common function.
memset(dx_data, 0, sizeof(T) * batch_size * class_num);
math::SetConstant<platform::CPUPlace, T>(ctx.device_context(), dx, 0);
for (int i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
......
......@@ -26,8 +26,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class CrossEntropyFunctor<platform::CPUPlace, T> {
public:
void operator()(const framework::ExecutionContext& ctx,
framework::Tensor* out, const framework::Tensor* prob,
void operator()(const platform::DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel) {
const int batch_size = prob->dims()[0];
if (softLabel) {
......@@ -35,7 +35,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
auto lbl = EigenMatrix<T>::From(*labels);
auto loss = EigenMatrix<T>::From(*out);
loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
loss.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
......
......@@ -74,8 +74,8 @@ using Tensor = framework::Tensor;
template <typename T>
class CrossEntropyFunctor<platform::GPUPlace, T> {
public:
void operator()(const framework::ExecutionContext& ctx,
framework::Tensor* out, const framework::Tensor* prob,
void operator()(const framework::DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, bool softLabel) {
const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace());
......@@ -87,20 +87,18 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
const T* label_data = labels->data<T>();
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel<
T><<<batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(loss_data, prob_data, label_data, class_num);
SoftCrossEntropyKernel<T><<<
batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
loss_data, prob_data, label_data, class_num);
} else {
const int* label_data = labels->data<int>();
int block = 512;
int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(loss_data, prob_data, label_data,
batch_size, class_num);
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
loss_data, prob_data, label_data, batch_size, class_num);
}
}
};
......
......@@ -37,9 +37,7 @@ struct TolerableValue {
template <typename Place, typename T>
class CrossEntropyFunctor {
public:
// (TODO caoying) it is much better to use DeviceContext as the first
// parameter.
void operator()(const framework::ExecutionContext& context,
void operator()(const platform::DeviceContext& context,
framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel);
};
......
......@@ -42,7 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
context.device_context(), loss, softmax, labels,
context.Attr<bool>("softLabel"));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册