提交 3df05389 编写于 作者: J jerrywgz

replace -100 to kIgnoreIndex

上级 13e254fa
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
const int kIgnoreIndex = -100;
class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
public: public:
...@@ -100,11 +101,11 @@ class SigmoidCrossEntropyWithLogitsOpMaker ...@@ -100,11 +101,11 @@ class SigmoidCrossEntropyWithLogitsOpMaker
AddOutput("Out", AddOutput("Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D " "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D "
" of elementwise logistic losses."); " of elementwise logistic losses.");
AddAttr<int>( AddAttr<int>("ignore_index",
"ignore_index", "(int, default kIgnoreIndex), Specifies a target value that "
"(int, default -100), Specifies a target value that is ignored and" "is ignored and"
"does not contribute to the input gradient.") "does not contribute to the input gradient.")
.SetDefault(-100); .SetDefault(kIgnoreIndex);
AddComment(R"DOC( AddComment(R"DOC(
SigmoidCrossEntropyWithLogits Operator. SigmoidCrossEntropyWithLogits Operator.
......
...@@ -170,6 +170,8 @@ __all__ = [ ...@@ -170,6 +170,8 @@ __all__ = [
'bilinear_tensor_product', 'bilinear_tensor_product',
] ]
kIgnoreIndex = -100
def fc(input, def fc(input,
size, size,
...@@ -1103,7 +1105,7 @@ def dropout(x, ...@@ -1103,7 +1105,7 @@ def dropout(x,
return out return out
def cross_entropy(input, label, soft_label=False, ignore_index=-100): def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
""" """
**Cross Entropy Layer** **Cross Entropy Layer**
...@@ -4796,7 +4798,7 @@ def multiplex(inputs, index): ...@@ -4796,7 +4798,7 @@ def multiplex(inputs, index):
def softmax_with_cross_entropy(logits, def softmax_with_cross_entropy(logits,
label, label,
soft_label=False, soft_label=False,
ignore_index=-100, ignore_index=kIgnoreIndex,
numeric_stable_mode=False, numeric_stable_mode=False,
return_softmax=False): return_softmax=False):
""" """
...@@ -7892,7 +7894,10 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -7892,7 +7894,10 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
@templatedoc() @templatedoc()
def sigmoid_cross_entropy_with_logits(x, label, ignore_index=-100, name=None): def sigmoid_cross_entropy_with_logits(x,
label,
ignore_index=kIgnoreIndex,
name=None):
""" """
${comment} ${comment}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册