提交 6f4bf505 编写于 作者: C caoying03

fix softmax with cross entropy op.

上级 d05c182e
...@@ -114,21 +114,17 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -114,21 +114,17 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"where N is the batch size and D is the number of classes. " "where N is the batch size and D is the number of classes. "
"This input is a probability computed by the previous operator, " "This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator."); "which is almost always the result of a softmax operator.");
AddInput( AddInput("Label",
"Label", "(Tensor), the ground truth which is a 2-D tensor. When "
"(Tensor, default Tensor<int>), the ground truth which is " "soft_label is set to false, Label is a Tensor<int64> with shape "
"a 2-D tensor. " "[N x 1]. When soft_label is set to true, Label is a "
"When soft_label is set to false, Label is a Tensor<int> with shape " "Tensor<float/double> with shape [N x K].");
"[N x 1]. "
"When soft_label is set to true, Label is a Tensor<float/double> "
"with shape [N x K].");
AddOutput("Y", AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor " "(Tensor, default Tensor<float>), a 2-D tensor with shape "
"with shape [N x 1]. The cross entropy loss."); "[N x 1]. The cross entropy loss.");
AddAttr<bool>( AddAttr<bool>("soft_label",
"soft_label", "(bool, default false), a flag indicating whether to "
"(bool, default false), a flag to indicate whether to interpretate " "interpretate the given labels as soft labels.")
"the given labels as soft labels.")
.SetDefault(false); .SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
CrossEntropy Operator. CrossEntropy Operator.
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/softmax_with_cross_entropy_op.h" #include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h> #include <paddle/function/TensorType.h>
...@@ -30,12 +30,10 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -30,12 +30,10 @@ class SoftmaxWithCrossEntropyOpMaker
"which is a 2-D tensor with shape [N x K]. N is the batch_size, " "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number."); "and K is the class number.");
AddInput("Label", AddInput("Label",
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D " "(Tensor) The ground truth which is a 2-D tensor. If soft_label "
"tensor. " "is set to false, Label is a Tensor<int64> with shape [N x 1]. If "
"If softLabel is set to false, Label is a Tensor<int> with shape " "soft_label is set to true, Label is a Tensor<float/double> with "
"[N x 1]." "shape [N x K].");
"If softLabel is set to true, Label is a Tensor<float/double> "
"with shape [N x K].");
AddOutput( AddOutput(
"Softmax", "Softmax",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. " "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
...@@ -62,7 +60,7 @@ Because this operator performs a softmax on logits internally, it expects ...@@ -62,7 +60,7 @@ Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results. softmax operator since that would produce incorrect results.
When the attribute softLabel is set false, this operators expects mutually When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label. probability of 1.0. Each sample in the batch will have a single label.
...@@ -198,6 +196,8 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, ...@@ -198,6 +196,8 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
REGISTER_OPERATOR(softmax_with_cross_entropy_grad, REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad); ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>); ops::SoftmaxWithCrossEntropyKernel<float>,
ops::SoftmaxWithCrossEntropyKernel<double>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<float>); ops::SoftmaxWithCrossEntropyGradKernel<float>,
ops::SoftmaxWithCrossEntropyGradKernel<double>);
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
...@@ -24,7 +24,7 @@ using Tensor = framework::Tensor; ...@@ -24,7 +24,7 @@ using Tensor = framework::Tensor;
namespace { namespace {
template <typename T> template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad, __global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
const int* labels, const int batch_size, const int64_t* labels, const int batch_size,
const int class_num) { const int class_num) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int sample_idx = tid / class_num; int sample_idx = tid / class_num;
...@@ -50,7 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, ...@@ -50,7 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
int ids = blockIdx.x * blockDim.x + threadIdx.x; int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < batch_size * class_num) { if (ids < batch_size * class_num) {
int row_ids = ids / class_num; int row_ids = ids / class_num;
logit_grad[ids] = logit_grad[ids] * (loss_grad[row_ids] - labels[ids]); logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]);
} }
} }
} // namespace } // namespace
...@@ -104,7 +104,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> { ...@@ -104,7 +104,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
.stream()>>>(logit_grad_data, loss_grad_data, .stream()>>>(logit_grad_data, loss_grad_data,
label_data, batch_size, class_num); label_data, batch_size, class_num);
} else { } else {
const int* label_data = labels->data<int>(); const int64_t* label_data = labels->data<int64_t>();
CrossEntropyGrad<T><<< CrossEntropyGrad<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>( grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context()) context.device_context())
...@@ -119,6 +119,8 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> { ...@@ -119,6 +119,8 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyCUDAKernel<float>); ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad, REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>); ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
...@@ -60,25 +60,25 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -60,25 +60,25 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax")); logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
const int class_num = logit_grad->dims()[1]; const int class_num = logit_grad->dims()[1];
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
if (context.Attr<bool>("soft_label")) { if (context.Attr<bool>("soft_label")) {
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
auto lbl_mat = EigenMatrix<T>::From(*labels); auto lbl_mat = EigenMatrix<T>::From(*labels);
logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) = logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
logit_grad_mat * out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) *
(out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) - (logit_grad_mat - lbl_mat);
lbl_mat);
} else { } else {
logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num));
const int batch_size = logit_grad->dims()[0]; const int batch_size = logit_grad->dims()[0];
const int* label_data = labels->data<int>(); const int64_t* label_data = labels->data<int64_t>();
const T* out_grad_data = out_grad->data<T>();
T* logit_grad_data = logit_grad->data<T>(); T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; logit_grad_data[i * class_num + label_data[i]] -= out_grad_data[i];
logit_grad_data[index] =
out_grad_data[i] * (logit_grad_data[index] - 1.);
} }
} }
} }
......
...@@ -12,30 +12,30 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -12,30 +12,30 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = 3 batch_size = 2
class_num = 37 class_num = 37
logits = np.random.uniform(0.1, 1.0, logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 1, logits) softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int32") labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64")
cross_entropy = np.asmatrix( cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i][0]])] [[-np.log(softmax[i][labels[i][0]])]
for i in range(softmax.shape[0])], for i in range(softmax.shape[0])],
dtype="float32") dtype="float64")
self.inputs = {"Logits": logits, "Label": labels} self.inputs = {"Logits": logits, "Label": labels}
self.outputs = { self.outputs = {
"Softmax": softmax.astype('float32'), "Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype('float32') "Loss": cross_entropy.astype("float64")
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.05) self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOp2(OpTest): class TestSoftmaxWithCrossEntropyOp2(OpTest):
...@@ -49,19 +49,19 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest): ...@@ -49,19 +49,19 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
class_num = 37 class_num = 37
logits = np.random.uniform(0.1, 1.0, logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 1, logits) softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.uniform(0.1, 1.0, labels = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float64")
labels /= np.sum(labels, axis=1, keepdims=True) labels /= np.sum(labels, axis=1, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum( cross_entropy = (-labels * np.log(softmax)).sum(
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float64")
self.inputs = {"Logits": logits, "Label": labels} self.inputs = {"Logits": logits, "Label": labels}
self.outputs = { self.outputs = {
"Softmax": softmax.astype('float32'), "Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype('float32') "Loss": cross_entropy.astype("float64")
} }
self.attrs = {"soft_label": True} self.attrs = {"soft_label": True}
...@@ -69,9 +69,8 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest): ...@@ -69,9 +69,8 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.05) self.check_grad(["Logits"], "Loss")
if __name__ == "__main__": if __name__ == "__main__":
exit(0) # FIXME: xe has bug
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册