Commit 513bc997 authored by caoying03

softmax with cross entropy as a cost operator.

Parent 2070bc93
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/softmax_op.h"
#include "paddle/operators/softmax_with_cross_entropy_op.h"
namespace paddle {
namespace operators {

-class SoftmaxWithLossOp : public framework::OperatorWithKernel {
+class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto logits = ctx.Input<Tensor>("logits");
-    PADDLE_ENFORCE(logits->dims().size() == 2UL,
-                   "The input of softmax_with_loss_op should be a 2-d tensor.");
+    PADDLE_ENFORCE(
+        logits->dims().size() == 2UL,
+        "The input of softmax_with_cross_entropy should be a 2-d tensor.");
    PADDLE_ENFORCE(ctx.Input<Tensor>("label")->dims().size() == 1UL,
                   "The label should be a 1-d tensor.");
-    ctx.Output<Tensor>("loss")->Resize({logits->dims()[0]});
+    ctx.Output<Tensor>("Y")->Resize({logits->dims()[0]});
  }
};
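// Shape example for the InferShape above (illustrative numbers, not from the
// commit): with a batch of N = 32 examples over K = 10 classes, "logits" is
// [32 x 10], "label" is [32], and the per-example loss "Y" is resized to [32].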
-class SoftmaxWithLossOpMaker : public framework::OpProtoAndCheckerMaker {
+class SoftmaxWithCrossEntropyOpMaker
+    : public framework::OpProtoAndCheckerMaker {
 public:
-  SoftmaxWithLossOpMaker(framework::OpProto *proto,
+  SoftmaxWithCrossEntropyOpMaker(framework::OpProto *proto,
                                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("logits",
             "The unscaled log probabilities, a 2-D tensor<float> with "
             "shape [N x K]. N is the batch_size, and K is the class number.");
    AddInput("label", "The ground truth. A 1-D tensor<int> with shape N.");
-    AddOutput("loss", "A 1-D tensor<float> with shape N.");
+    AddOutput("Y", "A 1-D tensor<float> with shape N.");
    AddComment(R"DOC(
Cross entropy loss with softmax is used extensively as the output layer. This
operator computes the softmax normalized values for each row of the input
@@ -60,12 +61,28 @@ and only one label.
  }
};
-class SoftmaxWithLossOpGrad : public framework::OperatorWithKernel {
+class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {}
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
+                            "Input(Y@GRAD) should not be null.");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Y")->dims(),
+                      ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
+                      "Input(Y) and its gradients should have the same shape.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("label"),
+                            "Input(label) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("logits")),
+                            "Input(logits@GRAD) should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx.Input<Tensor>("logits")->dims(),
+        ctx.Input<Tensor>(framework::GradVarName("logits"))->dims(),
+        "Input(logits) and its gradients should have the same shape.");
+  }
};
} // namespace operators
@@ -73,10 +90,13 @@ class SoftmaxWithLossOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
-REGISTER_OP(softmax, ops::SoftmaxWithLossOp, ops::SoftmaxWithLossOpMaker,
-            softmax_grad, ops::SoftmaxWithLossOpGrad);
+REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
+            ops::SoftmaxWithCrossEntropyOpMaker,
+            softmax_with_cross_entropy_grad,
+            ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(
-    softmax, ops::SoftmaxWithLossKernel<paddle::platform::CPUPlace, float>);
+    softmax_with_cross_entropy,
+    ops::SoftmaxWithCrossEntropyKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
-    softmax_grad,
-    ops::SoftmaxWithLossGradKernel<paddle::platform::CPUPlace, float>);
+    softmax_with_cross_entropy_grad,
+    ops::SoftmaxWithCrossEntropyGradKernel<paddle::platform::CPUPlace, float>);
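The kernels registered above are still empty stubs in this commit (see the header file below). As a minimal standalone sketch of the row-wise forward computation the DOC comment describes — assuming float logits and an integer label index, with a hypothetical helper name that is not part of Paddle:

// Illustrative sketch only: computes -log(softmax(logits)[label]) for one
// row of the [N x K] input, using the max-subtraction trick for stability.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float SoftmaxCrossEntropyRow(const std::vector<float>& logits, int label) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum_exp = 0.0f;
  for (float v : logits) sum_exp += std::exp(v - max_logit);
  // loss = log(sum_j exp(x_j - max)) - (x_label - max)
  return std::log(sum_exp) - (logits[label] - max_logit);
}

int main() {
  const std::vector<float> logits = {1.0f, 2.0f, 3.0f};
  std::printf("%f\n", SoftmaxCrossEntropyRow(logits, 2));  // ~0.407606
  return 0;
}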
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "softmax_with_cross_entropy_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    softmax_with_cross_entropy,
    ops::SoftmaxWithCrossEntropyKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
@@ -25,13 +25,13 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
-class SoftmaxWithLossKernel : public framework::OpKernel {
+class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {}
};

template <typename Place, typename T>
-class SoftmaxWithLossGradKernel : public framework::OpKernel {
+class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {}
};
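Both Compute bodies above are intentionally left empty in this commit. For orientation, the backward rule the grad kernel would eventually implement is dL/dlogits_j = softmax(logits)_j - 1 when j equals the label, and softmax(logits)_j otherwise. A minimal standalone sketch under the same assumptions as the forward sketch earlier (hypothetical helper, not Paddle API):

// Illustrative sketch only: gradient of -log(softmax(logits)[label]) with
// respect to one row of logits.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> SoftmaxCrossEntropyGradRow(const std::vector<float>& logits,
                                              int label) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> grad(logits.size());
  float sum_exp = 0.0f;
  for (std::size_t j = 0; j < logits.size(); ++j) {
    grad[j] = std::exp(logits[j] - max_logit);  // unnormalized softmax
    sum_exp += grad[j];
  }
  for (std::size_t j = 0; j < grad.size(); ++j) {
    grad[j] /= sum_exp;  // normalized softmax
    if (static_cast<int>(j) == label) grad[j] -= 1.0f;  // subtract one-hot
  }
  return grad;
}

Checking this rule numerically against finite differences of the forward sketch is a quick way to validate an eventual kernel implementation.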
@@ -37,7 +37,7 @@ USE_OP(mul);
USE_OP(mean);
USE_OP(sigmoid);
USE_OP(softmax);
-USE_OP(softmax_with_loss);
+USE_OP(softmax_with_cross_entropy);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_NO_KERNEL_OP(recurrent);