提交 9e25988f 编写于 作者: D dongzhihong

"net op alias"

上级 610801b5
...@@ -17,9 +17,9 @@ limitations under the License. */ ...@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, PADDLE_ENFORCE(ctx.InputSize() == 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, PADDLE_ENFORCE(ctx.OutputSize() == 1,
...@@ -36,9 +36,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel { ...@@ -36,9 +36,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
} }
}; };
class OnehotCrossEntropyGradientOp : public OperatorWithKernel { class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
...@@ -47,9 +47,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel { ...@@ -47,9 +47,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
} }
}; };
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker { class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker) OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp"); AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp"); AddInput("label", "The second input of OnehotCrossEntropyOp");
...@@ -65,11 +66,12 @@ OnehotCrossEntropy Operator. ...@@ -65,11 +66,12 @@ OnehotCrossEntropy Operator.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, REGISTER_OP_CPU_KERNEL(
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>); onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>); ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
...@@ -15,5 +15,7 @@ ...@@ -15,5 +15,7 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, namespace ops = paddle::operators;
ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
...@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/type_alias.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
static const float kCrossEntropyLogThreshold{1e-20}; static const float kCrossEntropyLogThreshold{1e-20};
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel { class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>(); const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>(); const int* label_data = ctx.Input<Tensor>(1)->data<int>();
...@@ -45,9 +47,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel { ...@@ -45,9 +47,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyGradientOpKernel : public OpKernel { class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X")); auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y")); auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
*/ */
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,13 +14,7 @@ limitations under the License. */ ...@@ -14,13 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册