提交 9e25988f 编写于 作者: D dongzhihong

"net op alias"

上级 610801b5
......@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1,
......@@ -36,9 +36,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
......@@ -47,9 +47,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
......@@ -65,11 +66,12 @@ OnehotCrossEntropy Operator.
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>);
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
......@@ -15,5 +15,7 @@
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
......@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static const float kCrossEntropyLogThreshold{1e-20};
template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel {
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>();
......@@ -45,9 +47,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
};
template <typename Place, typename T>
class OnehotCrossEntropyGradientOpKernel : public OpKernel {
class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
......
......@@ -15,7 +15,6 @@
*/
#include "paddle/operators/net_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
......
......@@ -14,13 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......
......@@ -2,9 +2,6 @@
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册