diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index 7e744e68e9737f9338d4b787aa28fd1834b145da..a617b9fb1d948340d25853252be79fdd08fe0438 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -248,10 +248,14 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
     PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                    "Output(XShape) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
+                   "Output(MatchX) should be not null.");
     auto x_dims = ctx->GetInputDim("X");
     auto x_dims_vec = framework::vectorize(x_dims);
     x_dims_vec.push_back(0);
     ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
+    x_dims[x_dims.size() - 1] = 1;
+    ctx->SetOutputDim("MatchX", x_dims);
     ctx->ShareLoD("X", /*->*/ "XShape");
   }
 
@@ -264,6 +268,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
 class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
  public:
   using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
 
  protected:
   virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
@@ -295,6 +303,8 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
           "with 'X' except that the last dimension size is 1. It "
           "represents the cross entropy loss.");
     AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
+    AddOutput("MatchX",
+              "X value that matches label, used for gradient computation.");
     AddAttr<int>("ignore_index",
                  "(int, default -100), Specifies a target value that is"
                  "ignored and does not contribute to the input gradient."
@@ -327,7 +337,7 @@ class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
     std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
     op->SetType("cross_entropy_grad2");
     op->SetInput("Label", Input("Label"));
-    op->SetInput("Y", Output("Y"));
+    op->SetInput("MatchX", Output("MatchX"));
     op->SetInput("XShape", Output("XShape"));
     op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
     op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h
index 05609e4bc20b1c75872be38e057de221a0188b88..7eb663773ed072760c47a2914377b5306ceeb7af 100644
--- a/paddle/fluid/operators/cross_entropy_op.h
+++ b/paddle/fluid/operators/cross_entropy_op.h
@@ -138,15 +138,48 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+struct HardLabelCrossEntropyForwardFunctor {
+  HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x,
+                                      const int64_t* label,
+                                      int64_t ignore_index,
+                                      int64_t feature_size)
+      : x_(x),
+        y_(y),
+        match_x_(match_x),
+        label_(label),
+        ignore_index_(ignore_index),
+        feature_size_(feature_size) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    auto label = label_[idx];
+    if (label != ignore_index_) {
+      auto match_x = x_[idx * feature_size_ + label];
+      y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
+      match_x_[idx] = match_x;
+    } else {
+      y_[idx] = 0;
+      match_x_[idx] = 0;  // any value is ok
+    }
+  }
+
+  const T* x_;
+  T* y_;
+  T* match_x_;
+  const int64_t* label_;
+  int64_t ignore_index_;
+  int64_t feature_size_;
+};
+
 template <typename T>
 struct HardLabelCrossEntropyBackwardFunctor {
-  HardLabelCrossEntropyBackwardFunctor(T* dx, const T* y, const T* dy,
+  HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x,
                                        const int64_t* label,
                                        int64_t ignore_index,
                                        int64_t feature_size)
       : dx_(dx),
-        y_(y),
         dy_(dy),
+        match_x_(match_x),
         label_(label),
         ignore_index_(ignore_index),
         feature_size_(feature_size) {}
@@ -156,15 +189,15 @@ struct HardLabelCrossEntropyBackwardFunctor {
     auto col_idx = idx % feature_size_;
     auto label = label_[row_idx];
     if (label == col_idx && label != ignore_index_) {
-      dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
+      dx_[idx] = -dy_[row_idx] / match_x_[row_idx];
     } else {
       dx_[idx] = 0;
     }
   }
 
   T* dx_;
-  const T* y_;
   const T* dy_;
+  const T* match_x_;
   const int64_t* label_;
   int64_t ignore_index_;
   int64_t feature_size_;
@@ -174,20 +207,26 @@ template <typename DeviceContext, typename T>
 class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x_original = ctx.Input<Tensor>("X");
-    int rank = x_original->dims().size();
-
-    auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
-    auto label =
-        framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
+    auto* x = ctx.Input<Tensor>("X");
+    auto* label = ctx.Input<Tensor>("Label");
     auto* y = ctx.Output<Tensor>("Y");
-    y->mutable_data<T>(ctx.GetPlace());
+    auto* match_x = ctx.Output<Tensor>("MatchX");
+
+    auto& x_dims = x->dims();
+    auto feature_size = x_dims[x_dims.size() - 1];
+    auto batch_size = framework::product(x->dims()) / feature_size;
+
+    auto* p_x = x->data<T>();
+    auto* p_label = label->data<int64_t>();
+    auto* p_y = y->mutable_data<T>(ctx.GetPlace());
+    auto* p_match_x = match_x->mutable_data<T>(ctx.GetPlace());
 
     auto ignore_index = ctx.Attr<int>("ignore_index");
-    math::CrossEntropyFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), y, &x, &label, false,
-        ignore_index);
+
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(), batch_size);
+    for_range(HardLabelCrossEntropyForwardFunctor<T>(
+        p_x, p_y, p_match_x, p_label, ignore_index, feature_size));
   }
 };
 
@@ -196,13 +235,13 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* y = ctx.Input<Tensor>("Y");
     auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* match_x = ctx.Input<Tensor>("MatchX");
     auto* label = ctx.Input<Tensor>("Label");
 
     auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
-    auto* p_y = y->data<T>();
     auto* p_dy = dy->data<T>();
+    auto* p_match_x = match_x->data<T>();
     auto* p_label = label->data<int64_t>();
 
     int64_t ignore_index = ctx.Attr<int>("ignore_index");
@@ -214,7 +253,7 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
         ctx.template device_context<DeviceContext>(), batch_size * feature_size);
     for_range(HardLabelCrossEntropyBackwardFunctor<T>(
-        p_dx, p_y, p_dy, p_label, ignore_index, feature_size));
+        p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size));
   }
 };
 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 67ff1425ac8df9f75275378f5e365655fedd4aba..9886f4e84c2ac9afba92ac7f98284b5a439c70b3 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -1450,11 +1450,13 @@ def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
     helper = LayerHelper('cross_entropy2', **locals())
     out = helper.create_variable_for_type_inference(dtype=input.dtype)
     xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
+    match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
     helper.append_op(
         type='cross_entropy2',
         inputs={'X': [input],
                 'Label': [label]},
         outputs={'Y': [out],
+                 'MatchX': [match_x],
                 'XShape': [xshape]},
         attrs={'ignore_index': ignore_index})
     return out