diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index b0e1b8e41a5320aa14e316a56dbfd01e43c6816b..b5a9ca271360c607a177f5a9717d56dbb5867e20 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -17,9 +17,9 @@ limitations under the License. */ namespace paddle { namespace operators { -class OnehotCrossEntropyOp : public OperatorWithKernel { +class OnehotCrossEntropyOp : public framework::OperatorWithKernel { protected: - void InferShape(const InferShapeContext &ctx) const override { + void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of OnehotCrossEntropyOp must be two"); PADDLE_ENFORCE(ctx.OutputSize() == 1, @@ -36,9 +36,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel { } }; -class OnehotCrossEntropyGradientOp : public OperatorWithKernel { +class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { protected: - void InferShape(const InferShapeContext &ctx) const override { + void InferShape(const framework::InferShapeContext &ctx) const override { auto X_grad = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); @@ -47,9 +47,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel { } }; -class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker { +class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { public: - OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker) + OnehotCrossEntropyOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The first input of OnehotCrossEntropyOp"); AddInput("label", "The second input of OnehotCrossEntropyOp"); @@ -65,11 +66,12 @@ OnehotCrossEntropy Operator. } // namespace operators } // namespace paddle +namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, ops::OnehotCrossEntropyOpMaker); -REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); - +REGISTER_OP_CPU_KERNEL( + onehot_cross_entropy, + ops::OnehotCrossEntropyOpKernel); REGISTER_OP_CPU_KERNEL( onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); + ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 2f453f8379ca7ce0612fed757719acb2d2cf0ad8..4bbc8f093a794d46737a16488684a6a0cc25e285 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -15,5 +15,7 @@ #define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + onehot_cross_entropy, + ops::OnehotCrossEntropyOpKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 88d06e13469f8e6fc9e634d804c1fe0bed5e2d75..15907158721e063cdd0caf1b51a314b2f120f8cb 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -13,17 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/operators/type_alias.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; + static const float kCrossEntropyLogThreshold{1e-20}; template -class OnehotCrossEntropyOpKernel : public OpKernel { +class OnehotCrossEntropyOpKernel : public framework::OpKernel { public: - void Compute(const ExecutionContext& ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { auto X = ctx.Input("X"); const T* Xdata = X->data(); const int* label_data = ctx.Input(1)->data(); @@ -45,9 +47,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel { }; template -class OnehotCrossEntropyGradientOpKernel : public OpKernel { +class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { public: - void Compute(const ExecutionContext& ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { auto X = ctx.Input("X"); auto dX = ctx.Output(framework::GradVarName("X")); auto dY = ctx.Input(framework::GradVarName("Y")); diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index fbc98e09923bda7f3baee04e02df9076247bff0b..a466c4f30fe87db4ad2a44518e083b57f3cbc2ed 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -15,7 +15,6 @@ */ #include "paddle/operators/net_op.h" -#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 6e7af7f02ae23ec65459dfd15d950a43e96fec4d..3342f40f51f1fbd8ac406cf88834d8c1d53fb57d 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -14,13 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/framework/op_desc.pb.h" -#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/scope.h" -#include "paddle/operators/type_alias.h" -#include "paddle/platform/device_context.h" namespace paddle { namespace operators { diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index c0a345464a34329d42c7bf753ca94fd07195b8e0..f823f36234051330c7395026220ca59ac9985944 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -2,9 +2,6 @@ #include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" - namespace paddle { namespace operators {