未验证 提交 9a589de8 编写于 作者: C chajchaj 提交者: GitHub

cherry-pick:change softmax_with_cross_entropy_op's parameter name from...

cherry-pick:change softmax_with_cross_entropy_op's parameter name from softmax_switch to use_softmax (#32750)

* change parameter name from softmax_switch to use_softmax, test=develop

* cherry-pick:change parameter name from softmax_switch to use_softmax, test=develop
上级 0bb079cd
...@@ -55,7 +55,7 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -55,7 +55,7 @@ class SoftmaxWithCrossEntropyOpMaker
"the given labels as soft labels.") "the given labels as soft labels.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>( AddAttr<bool>(
"softmax_switch", "use_softmax",
"(bool, default: true), A flag to indicate whether to do softmax ") "(bool, default: true), A flag to indicate whether to do softmax ")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>( AddAttr<bool>(
...@@ -320,7 +320,6 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, ...@@ -320,7 +320,6 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
REGISTER_OP_VERSION(softmax_with_cross_entropy) REGISTER_OP_VERSION(softmax_with_cross_entropy)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
Add a new attribute [softmax_switch] )ROC", Add a new attribute [use_softmax] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr( paddle::framework::compatible::OpVersionDesc().NewAttr(
"softmax_switch", "A flag to indicate whether to do softmax", "use_softmax", "A flag to indicate whether to do softmax", true));
true));
...@@ -772,10 +772,10 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> { ...@@ -772,10 +772,10 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
platform::is_gpu_place(context.GetPlace()), true, platform::is_gpu_place(context.GetPlace()), true,
platform::errors::Unavailable("softmax_with_cross_entropy operator's " platform::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device.")); "CUDA kernel only runs on GPU device."));
const bool softmax_switch = context.Attr<bool>("softmax_switch"); const bool use_softmax = context.Attr<bool>("use_softmax");
// do not with softmax op, and input is softmax // do not with softmax op, and input is softmax
if (!softmax_switch) { if (!use_softmax) {
const Tensor* softmax = context.Input<Tensor>("Logits"); const Tensor* softmax = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label"); const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax_out = context.Output<Tensor>("Softmax"); Tensor* softmax_out = context.Output<Tensor>("Softmax");
...@@ -925,10 +925,10 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> { ...@@ -925,10 +925,10 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
int block = 512; int block = 512;
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index"); auto ignore_index = context.Attr<int>("ignore_index");
auto softmax_switch = context.Attr<bool>("softmax_switch"); auto use_softmax = context.Attr<bool>("use_softmax");
// do not with softmax op, and input is softmax // do not with softmax op, and input is softmax
if (!softmax_switch) { if (!use_softmax) {
if (context.Attr<bool>("soft_label")) { if (context.Attr<bool>("soft_label")) {
int grid = (n * d + block - 1) / block; int grid = (n * d + block - 1) / block;
const T* label_data = labels->data<T>(); const T* label_data = labels->data<T>();
......
...@@ -31,10 +31,10 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> { ...@@ -31,10 +31,10 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_cpu_place(context.GetPlace()), true, platform::is_cpu_place(context.GetPlace()), true,
platform::errors::Unimplemented("This kernel only runs on CPU.")); platform::errors::Unimplemented("This kernel only runs on CPU."));
const bool softmax_switch = context.Attr<bool>("softmax_switch"); const bool use_softmax = context.Attr<bool>("use_softmax");
// do not with softmax op, and input is softmax // do not with softmax op, and input is softmax
if (!softmax_switch) { if (!use_softmax) {
const Tensor* softmax = context.Input<Tensor>("Logits"); const Tensor* softmax = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label"); const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax_out = context.Output<Tensor>("Softmax"); Tensor* softmax_out = context.Output<Tensor>("Softmax");
...@@ -113,9 +113,9 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -113,9 +113,9 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Logits")); context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* softmax = context.Input<Tensor>("Softmax"); const Tensor* softmax = context.Input<Tensor>("Softmax");
const bool softmax_switch = context.Attr<bool>("softmax_switch"); const bool use_softmax = context.Attr<bool>("use_softmax");
if (logit_grad != softmax || !softmax_switch) { if (logit_grad != softmax || !use_softmax) {
framework::TensorCopy(*softmax, context.GetPlace(), framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad); context.device_context(), logit_grad);
} }
...@@ -138,8 +138,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -138,8 +138,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d); auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>() auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device(); .eigen_device();
if (!softmax_switch) { if (!use_softmax) {
// softmax_switch step1 // use_softmax step1
if (soft_label) { if (soft_label) {
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d); auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) = logit_grad_mat.device(place) =
...@@ -148,7 +148,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -148,7 +148,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) * out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
logit_grad_mat; logit_grad_mat;
} }
// softmax_switch step2 // use_softmax step2
else { else {
const int64_t* label_data = labels->data<int64_t>(); const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>(); T* logit_grad_data = logit_grad->data<T>();
...@@ -181,7 +181,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -181,7 +181,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
return; return;
} }
// for softmax_switch=False, continue // for use_softmax=False, continue
if (soft_label) { if (soft_label) {
// when soft_label = True, ignore_index is not supported // when soft_label = True, ignore_index is not supported
......
...@@ -56,7 +56,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -56,7 +56,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [41, 37] self.shape = [41, 37]
self.softmax_switch = True self.use_softmax = True
def setUp(self): def setUp(self):
self.initParams() self.initParams()
...@@ -77,7 +77,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -77,7 +77,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
loss = cross_entropy(softmax, labels, self.soft_label, self.axis, loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index) self.ignore_index)
if self.softmax_switch == False: if self.use_softmax == False:
self.inputs = {"Logits": softmax, "Label": labels} self.inputs = {"Logits": softmax, "Label": labels}
else: else:
self.inputs = {"Logits": logits, "Label": labels} self.inputs = {"Logits": logits, "Label": labels}
...@@ -90,7 +90,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -90,7 +90,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
"numeric_stable_mode": self.numeric_stable_mode, "numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label, "soft_label": self.soft_label,
"ignore_index": self.ignore_index, "ignore_index": self.ignore_index,
"softmax_switch": self.softmax_switch, "use_softmax": self.use_softmax,
} }
if self.axis != -1: if self.axis != -1:
...@@ -117,7 +117,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D( ...@@ -117,7 +117,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
...@@ -130,7 +130,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D( ...@@ -130,7 +130,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
############################################################################## ##############################################################################
...@@ -146,7 +146,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D( ...@@ -146,7 +146,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D(
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
...@@ -159,7 +159,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2( ...@@ -159,7 +159,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
...@@ -172,7 +172,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3( ...@@ -172,7 +172,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
...@@ -185,7 +185,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4( ...@@ -185,7 +185,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
self.axis = 3 self.axis = 3
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
############################################################################## ##############################################################################
...@@ -207,7 +207,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D( ...@@ -207,7 +207,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D(
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
...@@ -220,7 +220,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2( ...@@ -220,7 +220,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
...@@ -233,7 +233,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3( ...@@ -233,7 +233,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
...@@ -246,7 +246,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4( ...@@ -246,7 +246,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
self.axis = 3 self.axis = 3
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
############################################################################## ##############################################################################
...@@ -268,7 +268,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore( ...@@ -268,7 +268,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore(
self.axis = -1 self.axis = -1
self.ignore_index = 2 self.ignore_index = 2
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
...@@ -281,7 +281,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis( ...@@ -281,7 +281,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
self.axis = 1 self.axis = 1
self.ignore_index = 2 self.ignore_index = 2
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
...@@ -294,7 +294,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore( ...@@ -294,7 +294,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
self.axis = -1 self.axis = -1
self.ignore_index = 2 self.ignore_index = 2
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
...@@ -307,7 +307,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( ...@@ -307,7 +307,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
self.axis = 2 self.axis = 2
self.ignore_index = 2 self.ignore_index = 2
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = False #default is true, means "with softmax" self.use_softmax = False #default is true, means "with softmax"
############################################################################## ##############################################################################
...@@ -324,7 +324,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): ...@@ -324,7 +324,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
...@@ -403,7 +403,7 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp): ...@@ -403,7 +403,7 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [41, 37] self.shape = [41, 37]
self.softmax_switch = True self.use_softmax = True
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -429,7 +429,7 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): ...@@ -429,7 +429,7 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
self.ignore_index = 5 self.ignore_index = 5
self.axis = -1 self.axis = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
...@@ -441,7 +441,7 @@ class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): ...@@ -441,7 +441,7 @@ class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
self.ignore_index = 4 self.ignore_index = 4
self.axis = -1 self.axis = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
...@@ -458,7 +458,7 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): ...@@ -458,7 +458,7 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
self.axis = 0 self.axis = 0
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
...@@ -475,7 +475,7 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): ...@@ -475,7 +475,7 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
...@@ -492,7 +492,7 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): ...@@ -492,7 +492,7 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
...@@ -509,7 +509,7 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): ...@@ -509,7 +509,7 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
self.axis = 3 self.axis = 3
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
...@@ -527,7 +527,7 @@ class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( ...@@ -527,7 +527,7 @@ class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 1] self.shape = [3, 5, 7, 1]
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
...@@ -540,7 +540,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( ...@@ -540,7 +540,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
self.axis = 0 self.axis = 0
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float16 self.dtype = np.float16
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
...@@ -553,7 +553,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( ...@@ -553,7 +553,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float16 self.dtype = np.float16
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
...@@ -566,7 +566,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( ...@@ -566,7 +566,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float16 self.dtype = np.float16
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
...@@ -579,7 +579,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( ...@@ -579,7 +579,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
self.axis = 0 self.axis = 0
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
...@@ -592,7 +592,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( ...@@ -592,7 +592,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
...@@ -605,7 +605,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( ...@@ -605,7 +605,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
...@@ -618,7 +618,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( ...@@ -618,7 +618,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
self.axis = 3 self.axis = 3
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
...@@ -631,7 +631,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( ...@@ -631,7 +631,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
self.ignore_index = 1 self.ignore_index = 1
self.axis = 0 self.axis = 0
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
...@@ -644,7 +644,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( ...@@ -644,7 +644,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
self.ignore_index = 0 self.ignore_index = 0
self.axis = 1 self.axis = 1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
...@@ -657,7 +657,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( ...@@ -657,7 +657,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
self.ignore_index = 3 self.ignore_index = 3
self.axis = 2 self.axis = 2
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
...@@ -670,7 +670,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( ...@@ -670,7 +670,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
self.ignore_index = 3 self.ignore_index = 3
self.axis = 3 self.axis = 3
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp):
...@@ -688,7 +688,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): ...@@ -688,7 +688,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp):
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.logits = np.full(self.shape, -500.0).astype(self.dtype) self.logits = np.full(self.shape, -500.0).astype(self.dtype)
self.softmax_switch = True self.use_softmax = True
class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
...@@ -707,7 +707,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): ...@@ -707,7 +707,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.logits = np.full(self.shape, 1000.0).astype(self.dtype) self.logits = np.full(self.shape, 1000.0).astype(self.dtype)
self.logits[:, :, 0, :] = -1000.0 self.logits[:, :, 0, :] = -1000.0
self.softmax_switch = True self.use_softmax = True
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -1371,8 +1371,6 @@ def cross_entropy(input, ...@@ -1371,8 +1371,6 @@ def cross_entropy(input,
"should be '-100', but received %s, which is not allowed." % "should be '-100', but received %s, which is not allowed." %
ignore_index) ignore_index)
softmax_switch = use_softmax
input_dims = len(list(input.shape)) input_dims = len(list(input.shape))
label_dims = len(list(label.shape)) label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims: if input_dims - 1 != label_dims and input_dims != label_dims:
...@@ -1385,7 +1383,7 @@ def cross_entropy(input, ...@@ -1385,7 +1383,7 @@ def cross_entropy(input,
_, out = core.ops.softmax_with_cross_entropy( _, out = core.ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index', input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis, ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'softmax_switch', softmax_switch) 'use_softmax', use_softmax)
if weight is not None: if weight is not None:
...@@ -1467,7 +1465,7 @@ def cross_entropy(input, ...@@ -1467,7 +1465,7 @@ def cross_entropy(input,
'ignore_index': ignore_index, 'ignore_index': ignore_index,
'numeric_stable_mode': True, 'numeric_stable_mode': True,
'axis': axis, 'axis': axis,
'softmax_switch': softmax_switch 'use_softmax': use_softmax
} }
helper = LayerHelper('softmax_with_cross_entropy', **locals()) helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=input.dtype) softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册