diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 0a6a0fd15c73330902552f7a9aa6339de24c1a18..75e8a989036f0b818687e1fec3e600bb90e86b22 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { return; } - AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>( - num_samples, infer_width, inference_data, label_data, accuracy_data); + AccuracyCudaKernel<<< + 1, PADDLE_CUDA_NUM_THREADS, 0, + reinterpret_cast( + ctx.device_context()) + .stream()>>>(num_samples, infer_width, inference_data, label_data, + accuracy_data); } }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index b11dc1472d153dd188a0b3553d6950774216a3fd..2e16201e74c153888594ebe6679fb0036734dad4 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null."); + "Input(Label) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) should be not null."); auto x = ctx.Input("X"); auto label = ctx.Input("Label"); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank must be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], - "The 1st dimension of Input(X) and Input(Label) must " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); - if (ctx.Attr("soft_label")) { + if (ctx.Attr("softLabel")) { PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], - "If Attr(soft_label) == true, The 2nd dimension of " - "Input(X) and Input(Label) must be equal."); + "If Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); } else { PADDLE_ENFORCE_EQ(label->dims()[1], 1, - "If Attr(soft_label) == false, The 2nd dimension of " - "Input(Label) must be 1."); + "If Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); } ctx.Output("Y")->Resize({x->dims()[0], 1}); @@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) must not be null."); + "Input(Label) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) must not be null."); + "Input(Y@GRAD) shoudl be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); auto x = ctx.Input("X"); auto label = ctx.Input("Label"); auto dy = ctx.Input(framework::GradVarName("Y")); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy->dims().size(), 2, + "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank must be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], - "The 1st dimension of Input(X) and Input(Label) must " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], - "The 1st dimension of Input(X) and Input(Y@Grad) must " + "The 1st dimension of Input(X) and Input(Y@Grad) should " "be equal."); PADDLE_ENFORCE_EQ(dy->dims()[1], 1, - "The 2nd dimension of Input(Y@Grad) must be 1."); - if (ctx.Attr("soft_label")) { + "The 2nd dimension of Input(Y@Grad) should be 1."); + if (ctx.Attr("softLabel")) { PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], - "If Attr(soft_label) == true, The 2nd dimension of " - "Input(X) and Input(Label) must be equal."); + "When Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); } else { PADDLE_ENFORCE_EQ(label->dims()[1], 1, - "If Attr(soft_label) == false, The 2nd dimension of " - "Input(Label) must be 1."); + "When Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); } auto dx = ctx.Output(framework::GradVarName("X")); @@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { CrossEntropyOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of CrossEntropyOp"); - AddInput("Label", "The second input of CrossEntropyOp"); - AddOutput("Y", "The output of CrossEntropyOp"); - AddAttr("soft_label", "Is soft label. Default zero.") + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape N x D, " + "where N is the batch size and D is the number of classes. " + "This input is a probability computed by the previous operator, " + "which is almost always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor, default Tensor), the ground truth which is " + "a 2-D tensor. " + "When softLabel is set to false, `Label` is a Tensor with shape " + "[N x 1]. " + "When softLabel is set to true, `Label` is a Tensor " + "with shape [N x K]."); + AddOutput("Y", + "(Tensor, default Tensor), a 2-D tensor " + "with shape [N x 1]. The cross entropy loss."); + AddAttr( + "softLabel", + "(bool, default false), a flag to indicate whether to interpretate " + "the given labels as soft labels.") .SetDefault(false); - AddComment(R"DOC( CrossEntropy Operator. It supports both standard cross-entropy and soft-label cross-entropy loss computation. 1) One-hot cross-entropy: - soft_label = False, Label[i, 0] indicates the class index for sample i: + softLabel = false, Label[i, 0] indicates the class index for sample i: Y[i] = -log(X[i, Label[i]]) 2) Soft-label cross-entropy: - soft_label = True, Label[i, j] indicates the soft label of class j + softLabel = true, Label[i, j] indicates the soft label of class j for sample i: Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 1d6361a81472a49729958120c52060b1dff803f2..18e44d77c9f62b296dc57952e546f844670c7d57 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -tolerable_value(log(X[i * D + label[i]])); + Y[i] = -TolerableValue()(log(X[i * D + label[i]])); } } +template +__device__ __forceinline__ T sum_single_warp(T val) { + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + return val; +} + template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, - const int N, const int D) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - T sum = static_cast(0); - for (int j = 0; j < D; j++) { - sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); - } - Y[i] = -sum; + const int class_num) { + int tid = threadIdx.x; + extern __shared__ T d_sum[]; + d_sum[tid] = 0; + + int cur_idx = tid; + int next_idx = blockIdx.x * class_num + tid; + while (cur_idx < class_num) { + d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; + next_idx += blockDim.x; + cur_idx += blockDim.x; + } + __syncthreads(); + + for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { + if (tid < stride) d_sum[tid] += d_sum[tid + stride]; + __syncthreads(); } + + T val = d_sum[tid]; + val = sum_single_warp(val); + if (tid == 0) Y[blockIdx.x] = -val; } -// TODO(qingqing): make zero setting an common function. +// TODO(qingqing): make zero setting a common function. template -__global__ void zero(T* X, const int N) { +__global__ void Zero(T* X, const int N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { X[i] = 0.0; @@ -71,13 +94,10 @@ template __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const T* label, const int N, const int D) { - // TOOD(qingqing): optimize for this kernel - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - for (int j = 0; j < D; ++j) { - int idx = i * D + j; - dX[idx] = -label[idx] * dY[i] / X[idx]; - } + int ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < N * D) { + int row_ids = ids / D; + dX[ids] = -label[ids] * dY[row_ids] / X[ids]; } } @@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); + "This kernel only runs on GPU device."); - auto x = ctx.Input("X"); - auto y = ctx.Output("Y"); - auto label = ctx.Input("Label"); + const Tensor* x = ctx.Input("X"); + const Tensor* label = ctx.Input("Label"); + Tensor* y = ctx.Output("Y"); - auto* x_data = x->data(); - y->mutable_data(ctx.GetPlace()); - auto* y_data = y->data(); + const T* x_data = x->data(); + T* y_data = y->mutable_data(ctx.GetPlace()); - int n = x->dims()[0]; - int d = x->dims()[1]; - int block = 512; - int grid = (n + block - 1) / block; - // TODO(qingqing) launch kernel on specified stream - // base on ExecutionContext. - if (ctx.Attr("soft_label")) { + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; + + if (ctx.Attr("softLabel")) { auto* label_data = ctx.Input("Label")->data(); - SoftCrossEntropyKernel<<>>(y_data, x_data, label_data, n, - d); + int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); + + SoftCrossEntropyKernel< + T><<( + ctx.device_context()) + .stream()>>>(y_data, x_data, label_data, class_num); } else { auto* label_data = ctx.Input("Label")->data(); - CrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); + int block = 512; + int grid = (batch_size + block - 1) / block; + CrossEntropyKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(y_data, x_data, label_data, + batch_size, class_num); } } }; @@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); + "This kernel only runs on GPU device."); + + const Tensor* x = ctx.Input("X"); + const Tensor* label = ctx.Input("Label"); + Tensor* dx = ctx.Output(framework::GradVarName("X")); - auto x = ctx.Input("X"); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("Label"); + const T* dy_data = + ctx.Input(framework::GradVarName("Y"))->data(); + T* dx_data = dx->mutable_data(ctx.GetPlace()); + const T* x_data = x->data(); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); - auto* x_data = x->data(); + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; - int n = x->dims()[0]; - int d = x->dims()[1]; int block = 512; - int grid = (n * d + block - 1) / block; - zero<<>>(dx_data, n * d); - grid = (n + block - 1) / block; - // TODO(qingqing): launch kernel on specified stream - // base on ExecutionContext. - if (ctx.Attr("soft_label")) { + int grid = (batch_size * class_num + block - 1) / block; + + if (ctx.Attr("softLabel")) { auto* label_data = label->data(); - SoftCrossEntropyGradientKernel<<>>( - dx_data, dy_data, x_data, label_data, n, d); + SoftCrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(dx_data, dy_data, x_data, label_data, + batch_size, class_num); } else { + Zero<<( + ctx.device_context()) + .stream()>>>(dx_data, batch_size * class_num); + auto* label_data = label->data(); - CrossEntropyGradientKernel<<>>(dx_data, dy_data, x_data, - label_data, n, d); + grid = (batch_size + block - 1) / block; + CrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(dx_data, dy_data, x_data, label_data, + batch_size, class_num); } } }; diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 69caba5ff31f60df2c24cef0e6331f058f6ba8d6..255b2e9f5ea7566cca7fd3914e38da804b7c7006 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/platform/hostdevice.h" @@ -20,53 +21,51 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; template -HOSTDEVICE T tolerable_value(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - if (x == INFINITY) { - return kApproInf; +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; } - if (x == -INFINITY) { - return -kApproInf; - } - return x; -} +}; template class CrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto x = ctx.Input("X"); - auto y = ctx.Output("Y"); - - auto* x_data = x->data(); - y->mutable_data(ctx.GetPlace()); - auto* y_data = y->data(); - - int batch_size = x->dims()[0]; - int class_num = x->dims()[1]; - - if (ctx.Attr("soft_label")) { - auto* label_data = ctx.Input("Label")->data(); - int index = 0; - for (int i = 0; i < batch_size; ++i) { - T sum = static_cast(0); - for (int j = 0; j < class_num; ++j) { - sum += label_data[index] * tolerable_value(std::log(x_data[index])); - y_data[i] = -sum; - index++; - } - } + "This kernel only runs on CPU."); + const Tensor* x = ctx.Input("X"); + const Tensor* labels = ctx.Input("Label"); + Tensor* y = ctx.Output("Y"); + T* y_data = y->mutable_data(ctx.GetPlace()); + + const int batch_size = x->dims()[0]; + if (ctx.Attr("softLabel")) { + auto prob = EigenMatrix::From(*x); + auto lbl_mat = EigenMatrix::From(*labels); + auto loss = EigenMatrix::From(*y); + + loss.device(ctx.GetEigenDevice()) = + -((lbl_mat * prob.log().unaryExpr(TolerableValue())) + .sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(batch_size, 1))); } else { - auto* label_data = ctx.Input("Label")->data(); + const int class_num = x->dims()[1]; + const T* x_data = x->data(); + + const int* label_data = labels->data(); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; - y_data[i] = -tolerable_value(std::log(x_data[index])); + y_data[i] = -TolerableValue()(std::log(x_data[index])); } } } @@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto x = ctx.Input("X"); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("Label"); + "This kernel only runs on CPU."); + const Tensor* x = ctx.Input("X"); + const Tensor* dy = ctx.Input(framework::GradVarName("Y")); + const Tensor* label = ctx.Input("Label"); + Tensor* dx = ctx.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); - auto* x_data = x->data(); - - int batch_size = x->dims()[0]; int class_num = x->dims()[1]; - - // TODO(qingqing): make zero setting an common function. - if (ctx.Attr("soft_label")) { - auto* label_data = ctx.Input("Label")->data(); - int index = 0; - for (int i = 0; i < batch_size; ++i) { - for (int j = 0; j < class_num; ++j) { - dx_data[index] = -label_data[index] * dy_data[i] / x_data[index]; - index++; - } - } + if (ctx.Attr("softLabel")) { + auto x_mat = EigenMatrix::From(*x); + auto dy_mat = EigenMatrix::From(*dy); + auto lbl_mat = EigenMatrix::From(*label); + auto dx_mat = EigenMatrix::From(*dx); + + dx_mat.device(ctx.GetEigenDevice()) = + -(lbl_mat * dy_mat.broadcast(Eigen::DSizes(1, class_num)) / + x_mat); } else { - auto* label_data = label->data(); + int batch_size = x->dims()[0]; + const T* dy_data = dy->data(); + const T* x_data = x->data(); + const int* label_data = label->data(); + + // TODO(qingqing): make zero setting a common function. memset(dx_data, 0, sizeof(T) * batch_size * class_num); + for (int i = 0; i < batch_size; ++i) { PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); int index = i * class_num + label_data[i]; diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 708344046760691aa2da562eb1ee3d8b130c5f18..62f63b4f3c876e084e2468001e8bcb9310d16a82 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTable<<>>(output, table, ids, N, K, D); + LookupTable<<< + grids, threads, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(output, table, ids, N, K, D); } }; @@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTableGrad<<>>(d_table, d_output, ids, N, - K, D); + LookupTableGrad<<< + grids, threads, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(d_table, d_output, ids, N, K, D); } }; diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu index afe4d149c53819c45e20353bc9d16393f3f61e0f..53fe505b77bfac8a33803f082f8e935d3ed403b6 100644 --- a/paddle/operators/top_k_op.cu +++ b/paddle/operators/top_k_op.cu @@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel { // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. - // TODO(typhoonzero): launch kernel on specified stream. // TODO(typhoonzero): refine this kernel. dim3 threads(256, 1); dim3 grid(input_height, 1); - KeMatrixTopK<<>>( - output_data, output->dims()[1], indices_data, input_data, input_width, - input_width, int(k)); + KeMatrixTopK<<< + grid, threads, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(output_data, output->dims()[1], + indices_data, input_data, + input_width, input_width, int(k)); } }; diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index f10db783225c07be9ffde25267fdfe096e97ecac..1de514dff487158e0823fd628d9b3b50f36fdd9b 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -4,22 +4,24 @@ from op_test import OpTest class TestCrossEntropyOp1(OpTest): - """Test standard cross-entropy, with index representation of labels. + """Test cross-entropy with discrete one-hot labels. """ def setUp(self): self.op_type = "cross_entropy" batch_size = 30 class_num = 10 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") cross_entropy = np.asmatrix( [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], dtype="float32") + self.inputs = {"X": X, "Label": label} self.outputs = {"Y": cross_entropy} - self.attrs = {'soft_label': False} + self.attrs = {"softLabel": False} def test_check_output(self): self.check_output() @@ -29,13 +31,14 @@ class TestCrossEntropyOp1(OpTest): class TestCrossEntropyOp2(OpTest): - """Test soft-label cross-entropy, with vecterized soft labels. + """Test cross-entropy with vectorized soft labels. """ def setUp(self): self.op_type = "cross_entropy" - batch_size = 10 - class_num = 5 + batch_size = 5 + class_num = 37 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label = np.random.uniform(0.1, 1.0, @@ -43,46 +46,49 @@ class TestCrossEntropyOp2(OpTest): label /= label.sum(axis=1, keepdims=True) cross_entropy = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") - self.inputs = {'X': X, 'Label': label} - self.outputs = {'Y': cross_entropy} - self.attrs = {'soft_label': True} + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"softLabel": True} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y", max_relative_error=0.05) class TestCrossEntropyOp3(OpTest): - """Test one-hot cross-entropy, with vecterized one-hot representation of - labels. + """Test cross-entropy with vectorized one-hot representation of labels. """ def setUp(self): self.op_type = "cross_entropy" - batch_size = 30 - class_num = 10 + batch_size = 5 + class_num = 17 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label_index = np.random.randint( 0, class_num, (batch_size), dtype="int32") label = np.zeros(X.shape) label[np.arange(batch_size), label_index] = 1 + cross_entropy = np.asmatrix( [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], dtype="float32") cross_entropy2 = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") - self.inputs = {'X': X, 'Label': label} - self.outputs = {'Y': cross_entropy} - self.attrs = {'soft_label': True} + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"softLabel": True} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y", max_relative_error=0.05) if __name__ == "__main__":