diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index 4cf4e8e2be2fd7f18079406b838d2757317c4ffb..b0e1b8e41a5320aa14e316a56dbfd01e43c6816b 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -36,6 +36,17 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
   }
 };
 
+class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
+ protected:
+  void InferShape(const InferShapeContext &ctx) const override {
+    auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto X = ctx.Input<Tensor>("X");
+
+    // TODO(superjom) add enforce here after helper functions ready
+    X_grad->Resize(X->dims());
+  }
+};
+
 class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
  public:
  OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -58,3 +69,7 @@ REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
             ops::OnehotCrossEntropyOpMaker);
 REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
                        ops::OnehotCrossEntropyOpKernel<float>);
+
+REGISTER_OP_CPU_KERNEL(
+    onehot_cross_entropy_grad,
+    ops::OnehotCrossEntropyGradientOpKernel<float>);
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index 7f7fb8d269df7eb46086ae143929d0bf6b808575..88d06e13469f8e6fc9e634d804c1fe0bed5e2d75 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -18,28 +18,53 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+static const float kCrossEntropyLogThreshold{1e-20};
+
 template <typename T>
 class OnehotCrossEntropyOpKernel : public OpKernel {
  public:
-  constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
-
   void Compute(const ExecutionContext& ctx) const override {
-    auto X = ctx.Input<Tensor>(0);
-    const T* X_data = X->data<T>();
+    auto X = ctx.Input<Tensor>("X");
+    const T* Xdata = X->data<T>();
     const int* label_data = ctx.Input<Tensor>(1)->data<int>();
-    auto Y = ctx.Output<Tensor>(0);
+    auto Y = ctx.Output<Tensor>("Y");
 
     Y->mutable_data<T>(ctx.GetPlace());
 
-    T* Y_data = Y->data<T>();
+    T* Ydata = Y->data<T>();
 
     int batch_size = X->dims()[0];
     int class_num = X->dims()[1];
 
     // Y[i] = -log(X[i][j])
     for (int i = 0; i < batch_size; ++i) {
-      Y_data[i] = -std::log(
-          std::max(X_data[i * class_num + label_data[i]], LOG_THRESHOLD()));
+      Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]],
+                                    kCrossEntropyLogThreshold));
+    }
+  }
+};
+
+template <typename T>
+class OnehotCrossEntropyGradientOpKernel : public OpKernel {
+ public:
+  void Compute(const ExecutionContext& ctx) const override {
+    auto X = ctx.Input<Tensor>("X");
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("label");
+
+    auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
+    auto* dYdata = dY->template data<T>();
+    auto* Xdata = X->template data<T>();
+    auto* label_data = label->data<int>();
+
+    const int batch_size = X->dims()[0];
+    const int class_num = X->dims()[1];
+
+    for (int i = 0; i < batch_size; ++i) {
+      dXdata[i * class_num + label_data[i]] =
+          -dYdata[i] / std::max(Xdata[i * class_num + label_data[i]],
+                                kCrossEntropyLogThreshold);
     }
   }
 };
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 609c56535ef0365dda728cba334d8b4d96312192..6d022f6bc0be60dbf2f796780a969bff0e8bfded 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -18,5 +18,7 @@ class TestSGD(unittest.TestCase):
         self.Y = numpy.array(Y).astype("float32")
 
 
+# TODO(superjom) add gradient check
+
 if __name__ == "__main__":
     unittest.main()
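
Note (not part of the patch): the new kernels compute Y[i] = -log(X[i][label[i]]) and its gradient dX[i][label[i]] = -dY[i] / X[i][label[i]], with X clipped at kCrossEntropyLogThreshold before the log and the division. A minimal NumPy sketch of the gradient check that the TODO above asks for might look like the following; the function names, sizes, epsilon, and the float64 dtype are illustrative choices here, not framework APIs.

import numpy

LOG_THRESHOLD = 1e-20  # mirrors kCrossEntropyLogThreshold


def forward(X, label):
    # Y[i] = -log(X[i][label[i]]), clipped like the CPU kernel
    idx = numpy.arange(len(label))
    return -numpy.log(numpy.maximum(X[idx, label], LOG_THRESHOLD))


def backward(X, label, dY):
    # dX[i][label[i]] = -dY[i] / X[i][label[i]]; other entries stay zero
    idx = numpy.arange(len(label))
    dX = numpy.zeros_like(X)
    dX[idx, label] = -dY / numpy.maximum(X[idx, label], LOG_THRESHOLD)
    return dX


batch_size, class_num = 4, 10
X = numpy.random.uniform(0.1, 1.0, (batch_size, class_num))  # float64 keeps the check clean
label = numpy.random.randint(0, class_num, batch_size)
dY = numpy.ones(batch_size)

# central finite differences on the labeled entries should match backward()
eps = 1e-5
dX = backward(X, label, dY)
for i in range(batch_size):
    Xp, Xm = X.copy(), X.copy()
    Xp[i, label[i]] += eps
    Xm[i, label[i]] -= eps
    numeric = (forward(Xp, label)[i] - forward(Xm, label)[i]) / (2 * eps)
    assert abs(numeric - dX[i, label[i]]) < 1e-4

One difference worth noting: the sketch zeroes dX before filling the labeled entries, while the CPU kernel above writes only the labeled entries of the freshly allocated gradient tensor.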