Commit 6735585b authored by caoying03

fix cpu kernel with soft labels.

Parent 30bfaab3
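For reference, the loss these kernels compute is the per-sample cross entropy. With soft labels (soft_label == true), Input(X) holds a probability row of width D per sample and Input(Label) holds a probability distribution per sample, so

    Y_i = -\sum_{j=1}^{D} \mathrm{Label}_{ij} \, \log X_{ij}, \qquad i = 1, \dots, N.

With hard labels (soft_label == false), Label stores a single class index l_i per sample and the sum reduces to Y_i = -\log X_{i, l_i}. This matches the NumPy reference used by the unit tests at the bottom of this diff: (-label * np.log(X)).sum(axis=1, keepdims=True).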
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
       return;
     }
-    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
-        num_samples, infer_width, inference_data, label_data, accuracy_data);
+    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
+        1, PADDLE_CUDA_NUM_THREADS, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            ctx.device_context())
+            .stream()>>>(num_samples, infer_width, inference_data, label_data,
+                         accuracy_data);
   }
 };
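Most hunks in this commit apply the same mechanical change shown above: instead of launching on the default stream, each kernel launch passes the device context's CUDA stream as the fourth <<<...>>> argument. A minimal standalone sketch of that launch pattern (plain CUDA runtime API, not Paddle code; the kernel and names here are illustrative only):

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel: scales n floats in place.
__global__ void scale(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

int main() {
  const int n = 1024;
  float* d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMemset(d_data, 0, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // <<<grid, block, shared_mem_bytes, stream>>> -- the diff threads the
  // CUDADeviceContext's stream() through this fourth launch parameter.
  scale<<<(n + 255) / 256, 256, 0, stream>>>(d_data, 2.0f, n);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(d_data);
  printf("launched on an explicit, non-default stream\n");
  return 0;
}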
......
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
+                            "Input(Label) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
+                            "Output(Y) should be not null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "If Attr(soft_label) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "If Attr(soft_label) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
     ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
@@ -57,35 +58,36 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
+                            "Input(Label) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
-                            "Input(Y@GRAD) must not be null.");
+                            "Input(Y@GRAD) should be not null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
     auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
+                      "Input(Y@Grad)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) must "
+                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) must be 1.");
+                      "The 2nd dimension of Input(Y@Grad) should be 1.");
     if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "When Attr(soft_label) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "When Attr(soft_label) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
     auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -98,12 +100,26 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
   CrossEntropyOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of CrossEntropyOp");
-    AddInput("Label", "The second input of CrossEntropyOp");
-    AddOutput("Y", "The output of CrossEntropyOp");
-    AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
+             "where N is the batch size and D is the number of classes. "
+             "This input is a probability computed by the previous operator, "
+             "which is almost always the result of a softmax operator.");
+    AddInput("Label",
+             "(Tensor, default Tensor<int>), the ground truth which is "
+             "a 1-D or 2-D tensor. "
+             "When soft_label is set to 0, `Label` is a Tensor<int> with shape "
+             "[N x 1]. "
+             "When soft_label is set to 1, `Label` is a Tensor<float/double> "
+             "with shape [N x K].");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a 1-D tensor "
+              "with shape [N x 1]. The cross entropy loss.");
+    AddAttr<bool>(
+        "soft_label",
+        "(bool, default false), a flag to indicate whether to interpret "
+        "the given labels as soft labels.")
         .SetDefault(false);
     AddComment(R"DOC(
 CrossEntropy Operator.
......
@@ -32,37 +32,71 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
+// This kernel is called when the class number is less than or equal to 512.
+template <typename T>
+__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
+                                        const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
+}
+
+// This kernel is called when the class number is larger than 512.
 template <typename T, int BlockSize>
-__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
-                                       const int N, const int D) {
+__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
+                                        const int class_num) {
   int tid = threadIdx.x;
   __shared__ T d_sum[BlockSize];
-  int next_idx = blockIdx.x * D + tid;
+  int next_idx = blockIdx.x * class_num + tid;
 
   d_sum[tid] = 0;
   int cur_idx = tid;
-  while (cur_idx < D) {
+  while (cur_idx < class_num) {
     d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
     next_idx += BlockSize;
     cur_idx += BlockSize;
   }
   __syncthreads();
 
-  for (int stride = BlockSize >> 1; stride > 0; stride >>= 1) {
-    __syncthreads();
-    if (tid < stride) {
-      next_idx = tid + stride;
-      d_sum[tid] += d_sum[next_idx];
-    }
+  for (unsigned int stride = BlockSize >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
   }
-  __syncthreads();
 
-  if (tid == 0) {
-    Y[blockIdx.x] = -d_sum[0];
-  }
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
 }
 
-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
 template <typename T>
 __global__ void zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
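The soft-label kernels added above reduce each sample's partial sums in two stages: a shared-memory tree reduction that stops once 32 partial sums remain, then sum_single_warp finishes the last warp with __shfl_down so no further __syncthreads is needed. A standalone sketch of that warp-shuffle step (assumes a pre-CUDA-9 toolchain where __shfl_down without the _sync suffix is available, as used in the diff; on CUDA 9+ this would be __shfl_down_sync(0xffffffff, ...)):

#include <cstdio>
#include <cuda_runtime.h>

// Sums the 32 lanes of one warp; after the loop, lane 0 holds the total.
__device__ __forceinline__ float warp_sum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down(val, offset);
  return val;
}

__global__ void warp_sum_demo(float* out) {
  // Every lane contributes its lane id: 0 + 1 + ... + 31 = 496.
  float v = warp_sum(static_cast<float>(threadIdx.x));
  if (threadIdx.x == 0) *out = v;
}

int main() {
  float* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(float));
  warp_sum_demo<<<1, 32>>>(d_out);
  float h_out = 0.f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum of lane ids = %g (expect 496)\n", h_out);
  cudaFree(d_out);
  return 0;
}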
@@ -88,11 +122,9 @@ template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                                const T* label, const int N,
                                                const int D) {
-  int row_ids = blockIdx.x * blockDim.x + threadIdx.x;
-  int col_ids = blockIdx.y * blockDim.y + threadIdx.y;
-  int ids = row_ids * D + col_ids;
-
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
   if (ids < N * D) {
+    int row_ids = ids / D;
     dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
@@ -112,20 +144,34 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     y->mutable_data<T>(ctx.GetPlace());
     auto* y_data = y->data<T>();
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
     int block = 512;
-    int grid = (n + block - 1) / block;
-    // TODO(qingqing) launch kernel on specified stream
-    // base on ExecutionContext.
+
     if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      grid = d;
-      SoftCrossEntropyKernel<T, 512><<<grid, block>>>(y_data, x_data,
-                                                      label_data, n, d);
+      if (class_num > 512) {
+        SoftCrossEntropyKernel2<
+            T, 512><<<batch_size, block, 0,
+                      reinterpret_cast<const platform::CUDADeviceContext&>(
+                          ctx.device_context())
+                          .stream()>>>(y_data, x_data, label_data, class_num);
+      } else {
+        int block_size = pow(2, int(std::log2(class_num)));
+        SoftCrossEntropyKernel1<
+            T><<<batch_size, block_size, block_size * sizeof(T),
+                 reinterpret_cast<const platform::CUDADeviceContext&>(
+                     ctx.device_context())
+                     .stream()>>>(y_data, x_data, label_data, class_num);
+      }
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(y_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
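For the shared-memory variant (SoftCrossEntropyKernel1), the launch above rounds class_num down to a power of two for the block size and allocates block_size * sizeof(T) bytes of dynamic shared memory, so the tree reduction halves cleanly at every step. A small host-only sketch of that computation (standalone and illustrative, not Paddle code):

#include <cmath>
#include <cstdio>

int main() {
  // class_num values mirroring the unit tests: one warp, several warps, the 512 limit.
  const int class_nums[] = {17, 37, 512};
  for (int class_num : class_nums) {
    // Round class_num down to the nearest power of two, as in the launch code.
    int block_size = static_cast<int>(std::pow(2, static_cast<int>(std::log2(class_num))));
    printf("class_num=%3d -> block_size=%3d threads, %4zu bytes of shared memory (float)\n",
           class_num, block_size, block_size * sizeof(float));
  }
  return 0;
}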
@@ -148,25 +194,27 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     int n = x->dims()[0];
     int d = x->dims()[1];
     int block = 512;
     int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block>>>(dx_data, n * d);
-    grid = (n + block - 1) / block;
-    // TODO(qingqing): launch kernel on specified stream
-    // base on ExecutionContext.
+    zero<T><<<grid, block, 0,
+              reinterpret_cast<const platform::CUDADeviceContext&>(
+                  ctx.device_context())
+                  .stream()>>>(dx_data, n * d);
     if (ctx.Attr<bool>("soft_label")) {
-      int block_x = 32;
-      int block_y = 32;
-      dim3 block(block_x, block_y);
-      dim3 grid((n + block_x - 1) / block_x, (d + block_y - 1) / block_y);
       auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
-          dx_data, dy_data, x_data, label_data, n, d);
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           n, d);
     } else {
       auto* label_data = label->data<int>();
-      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
-                                                     label_data, n, d);
+      CrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           n, d);
     }
   }
 };
......
@@ -31,12 +31,8 @@ struct TolerableValue {
     PADDLE_ASSERT(std::is_floating_point<T>::value);
     const T kApproInf = 1e20;
 
-    if (x == INFINITY) {
-      return kApproInf;
-    }
-    if (x == -INFINITY) {
-      return -kApproInf;
-    }
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
     return x;
   }
 };
@@ -58,11 +54,8 @@ class CrossEntropyOpKernel : public framework::OpKernel {
       auto lbl_mat = EigenMatrix<T>::From(*labels);
       auto loss = EigenMatrix<T>::From(*y);
 
-      // loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-      //     prob.log().unaryExpr(TolerableValue<T>());
       loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-          -((lbl_mat * prob.log())
+          -((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
                 .sum(Eigen::DSizes<int, 1>(1))
                 .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
     } else {
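The CPU-side fix named in the commit message is the hunk above: the soft-label branch now passes prob.log() through TolerableValue before multiplying by the labels, so a zero predicted probability (log(0) = -inf) is clamped to -1e20 instead of poisoning the row sum with inf/NaN. A standalone sketch of the failure mode and the guard (not Paddle code; TolerableValueSketch simply mirrors the functor above):

#include <cmath>
#include <cstdio>

// Mirrors the TolerableValue functor above: clamp +/-inf to +/-1e20.
template <typename T>
struct TolerableValueSketch {
  T operator()(const T& x) const {
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};

int main() {
  // A class with zero predicted probability and zero soft-label weight.
  double prob = 0.0, label = 0.0;
  double raw = label * std::log(prob);  // 0 * -inf = NaN, which sinks the whole row sum
  double guarded = label * TolerableValueSketch<double>()(std::log(prob));  // 0 * -1e20 = -0.0, finite
  printf("raw = %f, guarded = %f\n", raw, guarded);
  return 0;
}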
......
@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
+    LookupTable<T, 128, 8, 8><<<
+        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                               context.device_context())
+                               .stream()>>>(output, table, ids, N, K, D);
   }
 };
@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
-                                                      K, D);
+    LookupTableGrad<T, 128, 8, 8><<<
+        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                               context.device_context())
+                               .stream()>>>(d_table, d_output, ids, N, K, D);
   }
 };
......
@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
     // NOTE: pass lds and dim same to input width.
     // NOTE: old matrix implementation of stride is different to eigen.
-    // TODO(typhoonzero): launch kernel on specified stream.
     // TODO(typhoonzero): refine this kernel.
     dim3 threads(256, 1);
     dim3 grid(input_height, 1);
-    KeMatrixTopK<T, 5, 256><<<grid, threads>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, int(k));
+    KeMatrixTopK<T, 5, 256><<<
+        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(output_data, output->dims()[1],
+                                           indices_data, input_data,
+                                           input_width, input_width, int(k));
   }
 };
......
@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
             dtype="float32")
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {'soft_label': False}
+        self.attrs = {"soft_label": False}
 
     def test_check_output(self):
         self.check_output()
@@ -34,7 +34,8 @@ class TestCrossEntropyOp2(OpTest):
     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 13
+        batch_size = 5
+        # this setting tests threads in more than one warp.
         class_num = 37
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
@@ -43,9 +44,9 @@ class TestCrossEntropyOp2(OpTest):
         label /= label.sum(axis=1, keepdims=True)
         cross_entropy = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
-        self.inputs = {'X': X, 'Label': label}
-        self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': True}
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": True}
 
     def test_check_output(self):
         self.check_output()
@@ -61,8 +62,9 @@ class TestCrossEntropyOp3(OpTest):
     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 13
-        class_num = 37
+        batch_size = 5
+        # this setting tests all threads in one warp.
+        class_num = 17
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label_index = np.random.randint(
@@ -74,9 +76,36 @@ class TestCrossEntropyOp3(OpTest):
             dtype="float32")
         cross_entropy2 = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
-        self.inputs = {'X': X, 'Label': label}
-        self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': True}
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": True}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", max_relative_error=0.05)
+
+
+class TestCrossEntropyOp4(OpTest):
+    """Test soft-label cross-entropy.
+    This unittest tests the gpu kernel for layer size exceeding 512.
+    """
+
+    def setUp(self):
+        self.op_type = "cross_entropy"
+        batch_size = 2
+        class_num = 517
+        X = np.random.uniform(0.1, 1.0,
+                              [batch_size, class_num]).astype("float32")
+        label = np.random.uniform(0.1, 1.0,
+                                  [batch_size, class_num]).astype("float32")
+        label /= label.sum(axis=1, keepdims=True)
+        cross_entropy = (-label * np.log(X)).sum(
+            axis=1, keepdims=True).astype("float32")
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": True}
 
     def test_check_output(self):
         self.check_output()
......