Unverified · Commit 97cd7089 authored by chajchaj, committed by GitHub

cherry-pick:add softmax_switch for softmax_with_cross_entropy_op (#32105)

* cherry-pick:add softmax_switch for softmax_with_cross_entropy_op, test=develop

* add softmax_switch for softmax_with_cross_entropy_op, test=develop

* delete using EigenMatrix in softmax_with_cross_entropy_op.h, test=develop

* add REGISTER_OP_VERSION for softmax_switch attr of softmax_with_cross_entropy_op, test=develop

* cherry-pick:add softmax_switch for softmax_with_cross_entropy_op,test=develop

* change softmax_switch to use_softmax, test=develop

* fix code format for softmax_with_cross_entropy_op.cc, test=develop
Parent 1b3cd0fb
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -53,6 +54,10 @@ class SoftmaxWithCrossEntropyOpMaker
"(bool, default: false), A flag to indicate whether to interpretant "
"the given labels as soft labels.")
.SetDefault(false);
AddAttr<bool>(
"use_softmax",
"(bool, default: true), A flag to indicate whether to do softmax ")
.SetDefault(true);
AddAttr<bool>(
"numeric_stable_mode",
"(bool, default: true), A flag to indicate whether to use more "
......@@ -312,3 +317,10 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<float>,
ops::SoftmaxWithCrossEntropyGradKernel<double>);
REGISTER_OP_VERSION(softmax_with_cross_entropy)
.AddCheckpoint(
R"ROC(
Add a new attribute [use_softmax] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_softmax", "A flag to indicate whether to do softmax", true));
......@@ -66,6 +66,57 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
}
}
template <typename T>
__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int n, const int d,
const int remain) {
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int idx_n = ids / d;
int idx_remain = ids % remain;
int idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]);
}
}
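
A hedged NumPy reference of what SoftLabelCrossEntropyGradientKernel computes when use_softmax is false with soft labels: logit_grad initially holds the softmax probabilities, and the kernel overwrites it with dLoss * (-y / p), with the data viewed as [n, axis_dim, remain] and loss_grad as [n, remain]. The helper below is illustrative only.

```python
import numpy as np

def soft_label_grad_ref(probs, labels, loss_grad, axis_dim):
    # probs, labels: shape [n, d]; loss_grad: shape [n, remain]; d = axis_dim * remain
    n, d = probs.shape
    remain = d // axis_dim
    p = probs.reshape(n, axis_dim, remain)
    y = labels.reshape(n, axis_dim, remain)
    g = loss_grad.reshape(n, 1, remain)        # broadcast over the class axis
    return (g * (-y / p)).reshape(n, d)        # dL/dp = dLoss * (-y / p)
```
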
template <typename T>
__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
const int64_t* labels,
const int n, const int d,
const int remain,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, n * remain) {
int idx_n = index / remain;
int idx_remain = index % remain;
int tmp = labels[index];
int idx = idx_n * d + tmp * remain + idx_remain;
if (ignore_index != tmp) {
logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
}
}
}
template <typename T>
__global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
const int num, const int d,
const int remain,
const int64_t* labels,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, num) {
int idx_n = index / d;
int idx_remain = index % remain;
int idx_lbl = idx_n * remain + idx_remain;
int k = (index % d) / remain;
if (labels[idx_lbl] == ignore_index || labels[idx_lbl] != k) {
logit_grad[index] = static_cast<T>(0.);
} else {
logit_grad[index] *= loss_grad[idx_lbl];
}
}
}
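
The two kernels above implement the hard-label gradient of the use_softmax=false path in two passes: the first writes -1/p at each sample's labelled class (skipping ignore_index), the second zeroes every other class and scales by the upstream loss gradient. A single-pass NumPy sketch of the combined result (illustrative helper, not Paddle code):

```python
import numpy as np

def hard_label_grad_ref(probs, labels, loss_grad, axis_dim, ignore_index):
    # probs: [n, d] softmax values; labels, loss_grad: [n, remain]; d = axis_dim * remain
    n, d = probs.shape
    remain = d // axis_dim
    p = probs.reshape(n, axis_dim, remain)
    grad = np.zeros_like(p)                    # every non-labelled class gets 0
    for i in range(n):
        for j in range(remain):
            k = labels[i, j]
            if k != ignore_index:
                grad[i, k, j] = -loss_grad[i, j] / p[i, k, j]
    return grad.reshape(n, d)
```
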
} // namespace
static __device__ __forceinline__ platform::float16 exp_on_device(
......@@ -248,6 +299,160 @@ static __global__ void RowReductionForSoftmaxAndCrossEntropy(
if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}
// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForCrossEntropy(const T* logits_data,
const T* labels_data,
T* loss_data, int d,
int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits, softmax, labels data are viewed as [n, axis_dim, remain]
// loss_data is viewed as [n, 1, remain]
// gridDim = n * remain, split blockIdx into idx_n and idx_remain
int remain = d / axis_dim;
int idx_n = blockIdx.x / remain;
int idx_remain = blockIdx.x % remain;
int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int end_idx = (idx_n + 1) * d;
// log_diff_max_sum shares memory with loss
auto block_log_diff_max_sum = loss_data[blockIdx.x];
// when use_softmax is false, logits_data already holds the softmax output
auto tmp = log_on_device(logits_data[beg_idx]);
auto loss = -labels_data[beg_idx] * tmp;
int step = BlockDim * remain;
beg_idx += step;
while (beg_idx < end_idx) {
// when use_softmax is false, logits_data already holds the softmax output
tmp = log_on_device(logits_data[beg_idx]);
loss -= (labels_data[beg_idx] * tmp);
beg_idx += step;
}
loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}
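
A NumPy sketch of the reduction above: with the inputs viewed as [n, axis_dim, remain], each (sample, remain) position gets loss = -sum_k y_k * log(p_k), where the "Logits" input already holds the softmax output. The helper below is illustrative only.

```python
import numpy as np

def cross_entropy_rows_ref(probs, soft_labels, axis_dim):
    # probs, soft_labels: [n, d] with d = axis_dim * remain; returns loss of shape [n, remain]
    n, d = probs.shape
    remain = d // axis_dim
    p = probs.reshape(n, axis_dim, remain)
    y = soft_labels.reshape(n, axis_dim, remain)
    return -(y * np.log(p)).sum(axis=1)
```
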
template <typename T>
struct HardLabelCrossEntropyFunctor {
public:
HardLabelCrossEntropyFunctor(const int64_t* labels, T* loss,
const T* logits_data, int d, int axis_dim)
: labels_(labels),
loss_(loss),
logits_data_(logits_data),
d_(d),
axis_dim_(axis_dim) {}
__device__ void operator()(int idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int remain = d_ / axis_dim_;
int idx_n = idx / d_;
int idx_axis = (idx % d_) / remain;
int idx_remain = idx % remain;
// labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain;
// Labels outside range(class_num) are ignored (their loss entry is left untouched).
if (idx_axis == labels_[idx_lbl]) {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
}
}
private:
const int64_t* labels_;
T* loss_;
const T* logits_data_;
int d_;
int axis_dim_;
};
template <typename T>
struct HardLabelCrossEntropyFunctorWithIgnoreIdx {
public:
HardLabelCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss,
const T* logits_data, int d,
int axis_dim, int ignore_idx)
: labels_(labels),
loss_(loss),
logits_data_(logits_data),
d_(d),
axis_dim_(axis_dim),
ignore_idx_(ignore_idx) {}
__device__ void operator()(int idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int remain = d_ / axis_dim_;
int idx_n = idx / d_;
int idx_axis = (idx % d_) / remain;
int idx_remain = idx % remain;
// labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain;
if (idx_axis == ignore_idx_) {
loss_[idx_lbl] = 0;
return;
}
if (idx_axis == labels_[idx_lbl]) {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
}
}
private:
const int64_t* labels_;
T* loss_;
const T* logits_data_;
int d_;
int axis_dim_;
int ignore_idx_;
};
template <typename T>
static void HardLabelCrossEntropy(const platform::CUDADeviceContext& ctx,
const T* logits_data,
const int64_t* labels_data, T* loss_data,
int n, int d, int axis_dim, int ignore_idx) {
constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
int grid_dim = n * d / axis_dim;
auto stream = ctx.stream();
#define CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d); \
if (ignore_idx >= 0 && ignore_idx < axis_dim) { \
for_range(HardLabelCrossEntropyFunctorWithIgnoreIdx<T>( \
labels_data, loss_data, logits_data, d, axis_dim, ignore_idx)); \
} else { \
for_range(HardLabelCrossEntropyFunctor<T>(labels_data, loss_data, \
logits_data, d, axis_dim)); \
} \
} break
switch (block_dim) {
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW(platform::errors::Unavailable(
"Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
break;
}
#undef CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL
}
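
A hedged NumPy reference of the hard-label forward loss computed by the functors above when use_softmax is false: loss = -log(p[label]) per (sample, remain) position, with positions whose label equals ignore_index producing zero loss. The helper writes the full result directly instead of one element per thread.

```python
import numpy as np

def hard_label_loss_ref(probs, labels, axis_dim, ignore_index=-1):
    # probs: [n, d] softmax values; labels: [n, remain] int64; returns loss of shape [n, remain]
    n, d = probs.shape
    remain = d // axis_dim
    p = probs.reshape(n, axis_dim, remain)
    loss = np.zeros((n, remain), dtype=probs.dtype)
    for i in range(n):
        for j in range(remain):
            if labels[i, j] != ignore_index:
                loss[i, j] = -np.log(p[i, labels[i, j], j])
    return loss
```
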
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
public:
......@@ -420,6 +625,43 @@ static void SoftmaxWithCrossEntropyFusedKernel(
#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
// cross entropy only (used when use_softmax is false)
template <typename T>
static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data,
T* loss_data, int n, int d, int axis_dim,
cudaStream_t stream) {
constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
int grid_dim = n * d / axis_dim;
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
RowReductionForCrossEntropy<T, \
BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, d, axis_dim); \
break
switch (block_dim) {
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW(platform::errors::Unavailable(
"Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
break;
}
#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
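
Both fused launchers pick the block dimension the same way: cap it at 512, otherwise round axis_dim down to the nearest power of two so it matches one of the cases in the switch. A small sketch of that logic (the function name is illustrative):

```python
import math

def pick_block_dim(axis_dim, k_max=512):
    # cap at 512, otherwise the largest power of two <= axis_dim
    return k_max if axis_dim >= k_max else 1 << int(math.log2(axis_dim))

assert [pick_block_dim(x) for x in (2, 7, 32, 100, 600)] == [2, 4, 32, 64, 512]
```
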
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -428,6 +670,73 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
const bool use_softmax = context.Attr<bool>("use_softmax");
// use_softmax is false: skip the softmax step, the input already holds softmax output
if (!use_softmax) {
const Tensor* softmax = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax_out = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = softmax->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = softmax->dims()[axis];
const int n = SizeToAxis(axis, softmax->dims());
const int d = SizeFromAxis(axis, softmax->dims());
auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
if (axis_dim == 1) {
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), softmax_out,
static_cast<T>(1));
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
return;
}
auto soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index");
Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, 1});
softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});
// math::CrossEntropyFunctor only supports the case where axis is the last dimension
if (axis == -1) {
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
soft_label, ignore_index, axis_dim);
return;
}
// if axis is not the last dimension, we need a separate implementation
if (soft_label) {
auto* logits_data = softmax->data<T>();
auto* labels_data = labels->data<T>();
CrossEntropyFusedKernel(logits_data, labels_data, loss_data, n, d,
axis_dim,
context.cuda_device_context().stream());
} else { // HardLabel
auto* logits_data = softmax->data<T>();
auto* labels_data = labels->data<int64_t>();
HardLabelCrossEntropy<T>(context.cuda_device_context(), logits_data,
labels_data, loss_data, n, d, axis_dim,
ignore_index);
}
// since the input is already the softmax result,
// copy it to the Softmax output directly
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), softmax_out);
return;
}
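
Putting the pieces of this branch together, a rough NumPy sketch of the use_softmax=false forward semantics (ignore_index omitted for brevity; all names are illustrative): axis_dim == 1 short-circuits to Softmax = 1 and Loss = 0, otherwise cross entropy is computed directly on the given probabilities and the input is copied through to the Softmax output.

```python
import numpy as np

def forward_no_softmax_ref(softmax_in, labels, axis_dim, soft_label):
    # softmax_in: [n, d] probabilities; returns (Softmax output, Loss of shape [n, remain])
    n, d = softmax_in.shape
    remain = d // axis_dim
    if axis_dim == 1:
        return np.ones_like(softmax_in), np.zeros((n, remain), dtype=softmax_in.dtype)
    p = softmax_in.reshape(n, axis_dim, remain)
    if soft_label:
        y = labels.reshape(n, axis_dim, remain)
        loss = -(y * np.log(p)).sum(axis=1)
    else:
        lbl = labels.reshape(n, remain)
        loss = -np.log(np.take_along_axis(p, lbl[:, None, :], axis=1)[:, 0, :])
    return softmax_in.copy(), loss            # the input is passed through as "Softmax"
```
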
const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax = context.Output<Tensor>("Softmax");
......@@ -514,6 +823,34 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
int block = 512;
auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index");
auto use_softmax = context.Attr<bool>("use_softmax");
// use_softmax is false: skip the softmax step, the input already holds softmax output
if (!use_softmax) {
if (context.Attr<bool>("soft_label")) {
int grid = (n * d + block - 1) / block;
const T* label_data = labels->data<T>();
SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else {
Tensor logits_grad_2d;
logits_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
int grid = (n * remain + block - 1) / block;
const int64_t* label_data = labels->data<int64_t>();
HardLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, label_data, n, d, remain, ignore_index);
int num = n * d;
grid = (num + block - 1) / block;
ScaleCrossEntropyGradient<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, num, d, remain, label_data,
ignore_index);
}
return;
}
// use_softmax is true: continue with the fused softmax + cross entropy path
if (context.Attr<bool>("soft_label")) {
int64_t grid = (n * d + block - 1) / block;
const T* label_data = labels->data<T>();
......
......@@ -34,6 +34,46 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(context.GetPlace()), true,
platform::errors::Unimplemented("This kernel only runs on CPU."));
const bool use_softmax = context.Attr<bool>("use_softmax");
// use_softmax is false: skip the softmax step, the input already holds softmax output
if (!use_softmax) {
const Tensor* softmax = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax_out = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const bool soft_label = context.Attr<bool>("soft_label");
const int rank = softmax->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = softmax->dims()[axis];
softmax_out->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = SizeToAxis(axis, softmax->dims());
const int d = SizeFromAxis(axis, softmax->dims());
Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<int>("ignore_index"), axis_dim);
// since the input is already the softmax result,
// copy it to the Softmax output directly
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), softmax_out);
return;
}
const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax = context.Output<Tensor>("Softmax");
......@@ -76,7 +116,9 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* softmax = context.Input<Tensor>("Softmax");
if (logit_grad != softmax) {
const bool use_softmax = context.Attr<bool>("use_softmax");
if (logit_grad != softmax || !use_softmax) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
}
......@@ -99,28 +141,94 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (!use_softmax) {
// use_softmax == false, soft label: dL/dp = -y / p, then scale by the upstream dLoss
if (soft_label) {
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
(-lbl_mat / logit_grad_mat);  // dL/dp = -y / p (logit_grad currently holds p)
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
logit_grad_mat;
}
// use_softmax == false, hard label
else {
const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) {        // for each sample along the first dims
for (int j = 0; j < remain; j++) {   // for each position in the remaining dims
int idx = i * remain + j;  // index of this sample's label; for the 1-D case
// remain = 1 and j = 0, so idx = i
if (label_data[idx] == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// the one-hot label is 1 only at label_data[idx] and 0 elsewhere,
// so only that class contributes a nonzero gradient
logit_grad_data[i * d + label_data[idx] * remain + j] =
(-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
out_grad_data[idx];
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
if (k !=
label_data[idx]) { // label_data[idx]: this sample's label
logit_grad_data[i * d + k * remain + j] = 0;
}
}
}
}
}
}
return;
}
// for use_softmax=True, continue
if (soft_label) {
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
// when soft_label = True, ignore_index is not supported
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat);
(logit_grad_mat - lbl_mat);  // per-sample gradient
// 1) dLoss/dLogits = P - Y, where P = logit_grad_mat[i] holds all class
// probabilities of sample i and Y = lbl_mat[i] holds all class labels
// 2) apply the chain rule: multiply by the upstream gradient dy = out_grad_mat[i]
// for high-rank inputs, e.g. (n, c) or (n, d1, ..., dm, c), the gradient is
// computed with the same matrix operation
} else {
logit_grad_mat.device(place) =
logit_grad_mat *
logit_grad_mat * // element_wise multiply
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < remain; j++) {
int idx = i * remain + j;
for (int i = 0; i < n; ++i) {        // for each sample along the first dims
for (int j = 0; j < remain; j++) {   // for each position in the remaining dims
int idx = i * remain + j;  // index of this sample's label; for the 1-D case
// remain = 1 and j = 0, so idx = i
if (label_data[idx] == ignore_index) {
for (int k = 0; k < axis_dim; ++k) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// the one-hot label is 1 only at label_data[idx], so only that class's
// entry needs the extra term; for the 1-D case remain = 1 and j = 0, so
// [i * d + label_data[idx] * remain + j] = [i * d + label_data[idx]].
// Let idx_x = i * d + label_data[idx] * remain + j. The update is
// logit_grad_data[idx_x] -= out_grad_data[idx]; since logit_grad_mat was
// already multiplied by out_grad_mat above, this gives
// logit_grad_data[idx_x] = (p - 1) * out_grad_data[idx],
// i.e. dLoss/dLogits * dy = (p - y) * dy
logit_grad_data[i * d + label_data[idx] * remain + j] -=
out_grad_data[idx];
}
......
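
A tiny numeric check of the comment above for the use_softmax=true hard-label case: after the element-wise multiply, subtracting dLoss at the labelled class gives (p - y) * dLoss.

```python
import numpy as np

p = np.array([0.1, 0.7, 0.2])   # softmax output of one sample
label, dloss = 1, 2.0
grad = p * dloss                 # element-wise multiply step (logit_grad_mat * out_grad_mat)
grad[label] -= dloss             # extra term at the labelled class
assert np.allclose(grad, (p - np.eye(3)[label]) * dloss)   # (p - y) * dy
```
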
......@@ -55,6 +55,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
self.use_softmax = True
def setUp(self):
self.initParams()
......@@ -75,7 +76,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
self.inputs = {"Logits": logits, "Label": labels}
if not self.use_softmax:
self.inputs = {"Logits": softmax, "Label": labels}
else:
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
......@@ -84,6 +89,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
"ignore_index": self.ignore_index,
"use_softmax": self.use_softmax,
}
if self.axis != -1:
......@@ -93,7 +99,215 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=5e-5)
self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001)
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [13, 8]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [13, 8]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
##############################################################################
#NotWithSoftmax_SoftLabel_2D start
##############################################################################
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
##############################################################################
#NotWithSoftmax_SoftLabel_2D end
##############################################################################
##############################################################################
#NotWithSoftmax_HardLabel_2D start
##############################################################################
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
##############################################################################
#NotWithSoftmax_HardLabel_2D end
##############################################################################
##############################################################################
#NotWithSoftmax_HardLabel_2D_Ignore start
##############################################################################
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.soft_label = False
self.shape = [13, 8]
self.axis = -1
self.ignore_index = 2
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.soft_label = False
self.shape = [13, 8]
self.axis = 1
self.ignore_index = 2
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = 2
self.dtype = np.float64
self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 2
self.ignore_index = 2
self.shape = [3, 5, 7, 11]
self.use_softmax = False #default is true, means "with softmax"
##############################################################################
#NotWithSoftmax_HardLabel_2D_Ignore end
##############################################################################
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
......@@ -105,6 +319,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = True
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -182,6 +397,7 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
self.use_softmax = True
def test_check_output(self):
self.check_output()
......@@ -203,6 +419,7 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
self.ignore_index = 5
self.axis = -1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
......@@ -214,6 +431,7 @@ class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
self.ignore_index = 4
self.axis = -1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
......@@ -230,6 +448,7 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
self.axis = 0
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
......@@ -246,6 +465,7 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
......@@ -262,6 +482,7 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
......@@ -278,6 +499,7 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
......@@ -295,6 +517,7 @@ class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
self.axis = -1
self.ignore_index = -1
self.shape = [3, 5, 7, 1]
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
......@@ -307,6 +530,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
self.axis = 0
self.ignore_index = -1
self.dtype = np.float16
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
......@@ -319,6 +543,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
self.axis = 1
self.ignore_index = -1
self.dtype = np.float16
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
......@@ -331,6 +556,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
self.axis = 2
self.ignore_index = -1
self.dtype = np.float16
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
......@@ -343,6 +569,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
self.axis = 0
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
......@@ -355,6 +582,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
self.axis = 1
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
......@@ -367,6 +595,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
self.axis = 2
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
......@@ -379,6 +608,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
self.axis = 3
self.ignore_index = -1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
......@@ -391,6 +621,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
self.ignore_index = 1
self.axis = 0
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
......@@ -403,6 +634,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
self.ignore_index = 0
self.axis = 1
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
......@@ -415,6 +647,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
self.ignore_index = 3
self.axis = 2
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
......@@ -427,6 +660,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
self.ignore_index = 3
self.axis = 3
self.dtype = np.float64
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp):
......@@ -444,6 +678,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp):
self.ignore_index = -1
self.dtype = np.float64
self.logits = np.full(self.shape, -500.0).astype(self.dtype)
self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
......@@ -462,6 +697,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
self.dtype = np.float64
self.logits = np.full(self.shape, 1000.0).astype(self.dtype)
self.logits[:, :, 0, :] = -1000.0
self.use_softmax = True
if __name__ == "__main__":
......
......@@ -1388,8 +1388,6 @@ def cross_entropy(input,
"should be '-100', but received %s, which is not allowed." %
ignore_index)
softmax_switch = use_softmax
input_dims = len(list(input.shape))
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
......@@ -1402,7 +1400,7 @@ def cross_entropy(input,
_, out = core.ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'softmax_switch', softmax_switch)
'use_softmax', use_softmax)
if weight is not None:
......@@ -1484,7 +1482,7 @@ def cross_entropy(input,
'ignore_index': ignore_index,
'numeric_stable_mode': True,
'axis': axis,
'softmax_switch': softmax_switch
'use_softmax': use_softmax
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
......
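
A hedged usage sketch of the renamed switch at the Python level: per the functional-API change above, the flag is the `use_softmax` argument of `paddle.nn.functional.cross_entropy`; other argument names and defaults follow the Paddle 2.x signature and may differ slightly across versions.

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 10])
labels = paddle.randint(0, 10, [4], dtype='int64')

loss_a = F.cross_entropy(logits, labels, use_softmax=True)   # op applies softmax internally
probs = F.softmax(logits, axis=-1)
loss_b = F.cross_entropy(probs, labels, use_softmax=False)   # input already holds softmax
# loss_a and loss_b should agree up to floating-point error.
```
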