提交 5b42d2b2 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #4081 from xinghai-sun/soft_label_cross_entropy

Add soft-label support for cross-entropy operator.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cross_entropy_op.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
class CrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal.");
if (ctx.Attr<int>("soft_label") == 1) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1.");
}
ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1});
}
};
class CrossEntropyGradientOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) must not be null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
"The 1st dimension of Input(X) and Input(Y@Grad) must "
"be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) must be 1.");
if (ctx.Attr<int>("soft_label") == 1) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1.");
}
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
dx->Resize(x->dims());
}
};
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of CrossEntropyOp");
AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp");
AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
AddComment(R"DOC(
CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
soft_label = 0, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy:
soft_label = 1, Label[i, j] indicates the soft label of class j
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Please make sure that in this case the summuation of each row of Label
equals one.
3) One-hot cross-entropy with vecterized Input(Label):
As a special case of 2), when each row of Input(Label) has only one
non-zero element (equals 1), soft-label cross-entropy degenerates to a
one-hot cross-entropy with one-hot label representation.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<float>);
......@@ -13,27 +13,13 @@
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__host__ __device__ T clipping_log(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
T v = log(x);
if (v == INFINITY) {
return kApproInf;
}
if (v == -INFINITY) {
return -kApproInf;
}
return v;
}
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
const int N, const int D) {
......@@ -42,7 +28,20 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -clipping_log(X[i * D + label[i]]);
Y[i] = -tolerable_value(log(X[i * D + label[i]]));
}
}
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int N, const int D) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0);
for (int j = 0; j < D; j++) {
sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
}
Y[i] = -sum;
}
}
......@@ -69,57 +68,84 @@ __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
}
template <typename T>
class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N,
const int D) {
// TOOD(qingqing): optimize for this kernel
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
for (int j = 0; j < D; ++j) {
int idx = i * D + j;
dX[idx] = -label[idx] * dY[i] / X[idx];
}
}
}
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace());
T* Ydata = Y->data<T>();
auto x = ctx.Input<Tensor>("X");
auto y = ctx.Output<Tensor>("Y");
auto label = ctx.Input<Tensor>("Label");
int N = X->dims()[0];
int D = X->dims()[1];
auto* x_data = x->data<T>();
y->mutable_data<T>(ctx.GetPlace());
auto* y_data = y->data<T>();
int n = x->dims()[0];
int d = x->dims()[1];
int block = 512;
int grid = (N + block - 1) / block;
int grid = (n + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
d);
} else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
}
}
};
template <typename T>
class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("label");
auto x = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("Label");
auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
auto* dYdata = dY->template data<T>();
auto* Xdata = X->template data<T>();
auto* label_data = label->data<int>();
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->data<T>();
auto* x_data = x->data<T>();
int N = X->dims()[0];
int D = X->dims()[1];
int n = x->dims()[0];
int d = x->dims()[1];
int block = 512;
int grid = (N * D + block - 1) / block;
zero<T><<<grid, block>>>(dXdata, N * D);
grid = (N + block - 1) / block;
int grid = (n * d + block - 1) / block;
zero<T><<<grid, block>>>(dx_data, n * d);
grid = (n + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
label_data, N, D);
if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
dx_data, dy_data, x_data, label_data, n, d);
} else {
auto* label_data = label->data<int>();
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
label_data, n, d);
}
}
};
......@@ -127,7 +153,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
......@@ -21,75 +22,93 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, "
"double and double double.");
HOSTDEVICE T tolerable_value(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
template <typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
class CrossEntropyOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace());
T* Ydata = Y->data<T>();
int batch_size = X->dims()[0];
int class_num = X->dims()[1];
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
Ydata[i] = -tolerable_value(std::log(Xdata[index]));
auto x = ctx.Input<Tensor>("X");
auto y = ctx.Output<Tensor>("Y");
auto* x_data = x->data<T>();
y->mutable_data<T>(ctx.GetPlace());
auto* y_data = y->data<T>();
int batch_size = x->dims()[0];
int class_num = x->dims()[1];
if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0;
for (int i = 0; i < batch_size; ++i) {
T sum = static_cast<T>(0);
for (int j = 0; j < class_num; ++j) {
sum += label_data[index] * tolerable_value(std::log(x_data[index]));
y_data[i] = -sum;
index++;
}
}
} else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
y_data[i] = -tolerable_value(std::log(x_data[index]));
}
}
}
};
template <typename T>
class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
class CrossEntropyGradientOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("label");
auto x = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("Label");
auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
auto* dYdata = dY->template data<T>();
auto* Xdata = X->template data<T>();
auto* label_data = label->data<int>();
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->data<T>();
auto* x_data = x->data<T>();
const int batch_size = X->dims()[0];
const int class_num = X->dims()[1];
int batch_size = x->dims()[0];
int class_num = x->dims()[1];
// TODO(qingqing): make zero setting an common function.
memset(dXdata, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0;
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < class_num; ++j) {
dx_data[index] = -label_data[index] * dy_data[i] / x_data[index];
index++;
}
}
} else {
auto* label_data = label->data<int>();
memset(dx_data, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
int index = i * class_num + label_data[i];
dx_data[index] = -dy_data[i] / x_data[index];
}
}
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/onehot_cross_entropy_op.h"
namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("X"),
"Input(X) of OnehotCrossEntropyOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("label"),
"Input(label) of OnehotCrossEntropyOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Y"),
"Output(Y) of OnehotCrossEntropyOp should not be null.");
auto *X = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("label");
PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
ctx.Output<framework::LoDTensor>("Y")->Resize({X->dims()[0], 1});
}
};
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dX = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
dX->Resize(X->dims());
}
};
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp");
AddComment(R"DOC(
OnehotCrossEntropy Operator.
Y[i] = -log(X[i][j])
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<float>);
import unittest
import numpy as np
from op_test import OpTest
class TestCrossEntropyOp1(OpTest):
"""Test standard cross-entropy, with index representation of labels.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 30
class_num = 10
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
cross_entropy = np.asmatrix(
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float32")
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
self.attrs = {'soft_label': 0}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y")
class TestCrossEntropyOp2(OpTest):
"""Test soft-label cross-entropy, with vecterized soft labels.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 10
class_num = 5
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True)
cross_entropy = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
class TestCrossEntropyOp3(OpTest):
"""Test one-hot cross-entropy, with vecterized one-hot representation of
labels.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 30
class_num = 10
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label_index = np.random.randint(
0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape)
label[np.arange(batch_size), label_index] = 1
cross_entropy = np.asmatrix(
[[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
dtype="float32")
cross_entropy2 = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
if __name__ == "__main__":
unittest.main()
......@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
def cross_entropy_layer(net, input, label):
cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator(
"onehot_cross_entropy", X=input, label=label, Y=cost_name)
"cross_entropy", X=input, Label=label, Y=cost_name)
net.append_op(cross_entropy_op)
scope.new_var(cost_name)
net.infer_shape(scope)
......@@ -181,7 +181,7 @@ def error_rate(predict, label):
images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
labels = data_layer(name="label", dims=[BATCH_SIZE])
labels = data_layer(name="label", dims=[BATCH_SIZE, 1])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
......@@ -215,6 +215,7 @@ def test(cost_name):
for data in test_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
......@@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM):
for data in train_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
......
import unittest
import numpy
from op_test import OpTest
class TestOnehotCrossEntropyOp(OpTest):
def setUp(self):
self.op_type = "onehot_cross_entropy"
batch_size = 30
class_num = 10
X = numpy.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
cross_entropy = numpy.asmatrix(
[[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
dtype="float32")
self.inputs = {"X": X, "label": labels}
self.outputs = {"Y": cross_entropy}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册