diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index 0a6a0fd15c73330902552f7a9aa6339de24c1a18..75e8a989036f0b818687e1fec3e600bb90e86b22 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
       return;
     }
 
-    AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(
-        num_samples, infer_width, inference_data, label_data, accuracy_data);
+    AccuracyCudaKernel<<<
+        1, PADDLE_CUDA_NUM_THREADS, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            ctx.device_context())
+            .stream()>>>(num_samples, infer_width, inference_data, label_data,
+                         accuracy_data);
   }
 };
 
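Every launch site in this patch changes the same way: the execution configuration picks up its optional third and fourth arguments, the dynamic shared-memory byte count and a cudaStream_t, so the kernel is queued on the stream owned by the framework's platform::CUDADeviceContext rather than on the default stream (this is what the removed TODO comments below were asking for). A minimal standalone sketch of the pattern with the plain CUDA runtime API; the Scale kernel and every name in it are illustrative, not part of the patch:

#include <cuda_runtime.h>

__global__ void Scale(float* x, float a, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= a;  // one element per thread
}

int main() {
  const int n = 1 << 20;
  float* d_x = nullptr;
  cudaMalloc(&d_x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // <<<grid, block, dynamic_shared_bytes, stream>>>; the last two slots are
  // exactly what the patch adds at each launch site.
  int block = 512;
  int grid = (n + block - 1) / block;
  Scale<<<grid, block, 0, stream>>>(d_x, 2.0f, n);

  cudaStreamSynchronize(stream);  // waits only for this stream's work
  cudaStreamDestroy(stream);
  cudaFree(d_x);
  return 0;
}

Queuing on an explicit stream keeps an operator's kernels ordered with the rest of the work on that device context, while still letting independent streams overlap on the device.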
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index b11dc1472d153dd188a0b3553d6950774216a3fd..80f7b69c142eb02c64076c8724a16d09e31c72d0 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
+                            "Input(Label) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
+                            "Output(Y) should not be null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "If Attr(soft_label) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "If Attr(soft_label) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
 
     ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
@@ -57,35 +58,36 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
+                            "Input(Label) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
-                            "Input(Y@GRAD) must not be null.");
+                            "Input(Y@GRAD) should not be null.");
 
     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
     auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
+                      "Input(Y@Grad)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) must "
+                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) must be 1.");
+                      "The 2nd dimension of Input(Y@Grad) should be 1.");
     if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "When Attr(soft_label) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "When Attr(soft_label) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
 
     auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -98,12 +100,26 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
   CrossEntropyOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of CrossEntropyOp");
-    AddInput("Label", "The second input of CrossEntropyOp");
-    AddOutput("Y", "The output of CrossEntropyOp");
-    AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
+             "where N is the batch size and D is the number of classes. "
+             "This input is a probability computed by the previous operator, "
+             "which is almost always the result of a softmax operator.");
+    AddInput("Label",
+             "(Tensor, default Tensor<int>), the ground truth which is "
+             "a 2-D tensor. "
+             "When soft_label is set to 0, `Label` is a Tensor<int> with "
+             "shape [N x 1]. "
+             "When soft_label is set to 1, `Label` is a Tensor<float> "
+             "with shape [N x K].");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a 2-D tensor "
+              "with shape [N x 1]. The cross entropy loss.");
+    AddAttr<bool>(
+        "soft_label",
+        "(bool, default false), a flag to indicate whether to interpret "
+        "the given labels as soft labels.")
         .SetDefault(false);
-
     AddComment(R"DOC(
 CrossEntropy Operator.
 
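For reference, the two label modes that the checks above enforce compute the same loss in two ways: with hard labels, Y[i] = -log(X[i][label[i]]); with soft labels, Y[i] = -sum_j Label[i][j] * log(X[i][j]). A host-side C++ sketch of both, assuming row-major N x D buffers (the function names are hypothetical, written only to mirror the operator's semantics):

#include <cmath>
#include <vector>

// Hard labels (soft_label == false): Label holds one class index per sample.
// Y[i] = -log(X[i][label[i]])
std::vector<float> HardLabelXent(const std::vector<float>& x,
                                 const std::vector<int>& label, int n, int d) {
  std::vector<float> y(n);
  for (int i = 0; i < n; ++i) y[i] = -std::log(x[i * d + label[i]]);
  return y;
}

// Soft labels (soft_label == true): Label holds a distribution per sample.
// Y[i] = -sum_j Label[i][j] * log(X[i][j])
std::vector<float> SoftLabelXent(const std::vector<float>& x,
                                 const std::vector<float>& label, int n,
                                 int d) {
  std::vector<float> y(n, 0.0f);
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < d; ++j)
      y[i] -= label[i * d + j] * std::log(x[i * d + j]);
  return y;
}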
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index d837f49126e91bc937e346a829c204bd0bcd38fa..3f34a2d52d68ca1f6f8f8656bb423f2ac40cbbff 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -32,37 +32,71 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
+// This kernel is called when the class number is less than or equal to 512.
+template <typename T>
+__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
+                                        const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
+}
+
+// This kernel is called when the class number is larger than 512.
 template <typename T, int BlockSize>
-__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
-                                       const int N, const int D) {
+__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
+                                        const int class_num) {
   int tid = threadIdx.x;
   __shared__ T d_sum[BlockSize];
-  int next_idx = blockIdx.x * D + tid;
+  int next_idx = blockIdx.x * class_num + tid;
 
   d_sum[tid] = 0;
   int cur_idx = tid;
-  while (cur_idx < D) {
+  while (cur_idx < class_num) {
     d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
     next_idx += BlockSize;
     cur_idx += BlockSize;
   }
   __syncthreads();
 
-  for (int stride = BlockSize >> 1; stride > 0; stride >>= 1) {
+  for (unsigned int stride = BlockSize >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
     __syncthreads();
-    if (tid < stride) {
-      next_idx = tid + stride;
-      d_sum[tid] += d_sum[next_idx];
-    }
   }
-  __syncthreads();
 
-  if (tid == 0) {
-    Y[blockIdx.x] = -d_sum[0];
-  }
+  T val = d_sum[tid];
+  val = sum_single_warp(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
 }
 
-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
 template <typename T>
 __global__ void zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
@@ -88,11 +122,9 @@ template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                                const T* label, const int N,
                                                const int D) {
-  int row_ids = blockIdx.x * blockDim.x + threadIdx.x;
-  int col_ids = blockIdx.y * blockDim.y + threadIdx.y;
-  int ids = row_ids * D + col_ids;
-
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
   if (ids < N * D) {
+    int row_ids = ids / D;
     dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
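sum_single_warp finishes the reduction after the shared-memory tree in the kernels above has folded the partial sums down to one warp's worth of values: __shfl_down(val, offset) reads val from the lane offset positions higher within the same warp, so the five halving steps accumulate lanes 0-31 into lane 0 with no shared memory and no __syncthreads(). Note that __shfl_down is the pre-CUDA-9 intrinsic; on CUDA 9 and later the equivalent call is __shfl_down_sync(0xffffffff, val, offset). A self-contained block-sum kernel built on the same two-stage scheme, using the _sync variant (illustrative, not the patch's kernel):

#include <cuda_runtime.h>

__device__ __forceinline__ float WarpSum(float val) {
  // Fold the 32 lanes of a warp into lane 0.
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

template <int BlockSize>
__global__ void BlockSum(const float* x, float* out, int n) {
  __shared__ float smem[BlockSize];
  int tid = threadIdx.x;

  // Stage 0: each thread accumulates a strided slice of the input.
  float v = 0.f;
  for (int i = blockIdx.x * BlockSize + tid; i < n; i += BlockSize * gridDim.x)
    v += x[i];
  smem[tid] = v;
  __syncthreads();

  // Stage 1: shared-memory tree until one warp of partials remains.
  for (int stride = BlockSize / 2; stride >= 32; stride >>= 1) {
    if (tid < stride) smem[tid] += smem[tid + stride];
    __syncthreads();
  }

  // Stage 2: finish inside the first warp with shuffles.
  if (tid < 32) {
    float s = WarpSum(smem[tid]);
    if (tid == 0) atomicAdd(out, s);
  }
}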
@@ -112,20 +144,34 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     y->mutable_data<T>(ctx.GetPlace());
     auto* y_data = y->data<T>();
 
-    int n = x->dims()[0];
-    int d = x->dims()[1];
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
     int block = 512;
-    int grid = (n + block - 1) / block;
-    // TODO(qingqing) launch kernel on specified stream
-    // base on ExecutionContext.
+
     if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      grid = d;
-      SoftCrossEntropyKernel<T, 512><<<grid, block>>>(y_data, x_data,
-                                                      label_data, n, d);
+      if (class_num > 512) {
+        SoftCrossEntropyKernel2<
+            T, 512><<<batch_size, 512, 0,
+                      reinterpret_cast<const platform::CUDADeviceContext&>(
+                          ctx.device_context())
+                          .stream()>>>(y_data, x_data, label_data, class_num);
+      } else {
+        int block_size = pow(2, int(std::log2(class_num)));
+        SoftCrossEntropyKernel1<
+            T><<<batch_size, block_size, block_size * sizeof(T),
+                 reinterpret_cast<const platform::CUDADeviceContext&>(
+                     ctx.device_context())
+                     .stream()>>>(y_data, x_data, label_data, class_num);
+      }
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(y_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
@@ -148,25 +194,27 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 
     int n = x->dims()[0];
     int d = x->dims()[1];
+
     int block = 512;
     int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block>>>(dx_data, n * d);
-    grid = (n + block - 1) / block;
-    // TODO(qingqing): launch kernel on specified stream
-    // base on ExecutionContext.
+    zero<T><<<grid, block, 0,
+              reinterpret_cast<const platform::CUDADeviceContext&>(
+                  ctx.device_context())
+                  .stream()>>>(dx_data, n * d);
 
     if (ctx.Attr<bool>("soft_label")) {
-      int block_x = 32;
-      int block_y = 32;
-      dim3 block(block_x, block_y);
-      dim3 grid((n + block_x - 1) / block_x, (d + block_y - 1) / block_y);
-
       auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
-          dx_data, dy_data, x_data, label_data, n, d);
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           n, d);
     } else {
       auto* label_data = label->data<int>();
-      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
-                                                     label_data, n, d);
+      CrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           n, d);
     }
   }
 };
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index cfb00fe647fe2ecda5323e9d170715d852e8043c..6b3f8c95bee58a3bf7d5f40ba84656beee6cdc97 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -31,12 +31,8 @@ struct TolerableValue {
     PADDLE_ASSERT(std::is_floating_point<T>::value);
     const T kApproInf = 1e20;
 
-    if (x == INFINITY) {
-      return kApproInf;
-    }
-    if (x == -INFINITY) {
-      return -kApproInf;
-    }
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
     return x;
   }
 };
@@ -58,11 +54,8 @@ class CrossEntropyOpKernel : public framework::OpKernel {
       auto lbl_mat = EigenMatrix<T>::From(*labels);
       auto loss = EigenMatrix<T>::From(*y);
 
-      // loss.device(ctx.GetEigenDevice<Place>()) =
-      //     prob.log().unaryExpr(TolerableValue<T>());
-
       loss.device(ctx.GetEigenDevice<Place>()) =
-          -((lbl_mat * prob.log())
+          -((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
                 .sum(Eigen::DSizes<int, 1>(1))
                 .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
     } else {
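The gradient kernel above still zeroes dX with the file-local zero<T> kernel, and the patch keeps the TODO to make zero setting a common function. A sketch of what such a shared helper could look like; SetZero and FillKernel are hypothetical names, not APIs this patch adds, and for trivially zeroable element types a cudaMemsetAsync call would do the same job with less code:

#include <cuda_runtime.h>

// Grid-stride fill: correct for any grid size and element count.
template <typename T>
__global__ void FillKernel(T* data, T value, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x)
    data[i] = value;
}

// Hypothetical shared helper the TODO asks for: zero a device buffer on the
// caller's stream, mirroring how the patch launches zero<T>.
template <typename T>
void SetZero(T* data, int n, cudaStream_t stream) {
  const int block = 512;
  const int grid = (n + block - 1) / block;
  FillKernel<T><<<grid, block, 0, stream>>>(data, static_cast<T>(0), n);
}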
diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu
index 708344046760691aa2da562eb1ee3d8b130c5f18..62f63b4f3c876e084e2468001e8bcb9310d16a82 100644
--- a/paddle/operators/lookup_table_op.cu
+++ b/paddle/operators/lookup_table_op.cu
@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
+    LookupTable<T, 128, 8, 8><<<
+        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                               context.device_context())
+                               .stream()>>>(output, table, ids, N, K, D);
   }
 };
 
@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
-                                                      K, D);
+    LookupTableGrad<T, 128, 8, 8><<<
+        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                               context.device_context())
+                               .stream()>>>(d_table, d_output, ids, N, K, D);
   }
 };
 
diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu
index afe4d149c53819c45e20353bc9d16393f3f61e0f..53fe505b77bfac8a33803f082f8e935d3ed403b6 100644
--- a/paddle/operators/top_k_op.cu
+++ b/paddle/operators/top_k_op.cu
@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
 
     // NOTE: pass lds and dim same to input width.
     // NOTE: old matrix implementation of stride is different to eigen.
-    // TODO(typhoonzero): launch kernel on specified stream.
     // TODO(typhoonzero): refine this kernel.
     dim3 threads(256, 1);
     dim3 grid(input_height, 1);
 
-    KeMatrixTopK<T, 5, 256><<<grid, threads>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, int(k));
+    KeMatrixTopK<T, 5, 256><<<
+        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(output_data, output->dims()[1],
+                                           indices_data, input_data,
+                                           input_width, input_width, int(k));
   }
 };
 
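The lookup-table and top-k launches keep their existing dim3 configurations and only gain the stream argument. As a reminder of the addressing such a shape implies: dim3 threads(128, 8) gives each block 1024 threads indexed by threadIdx.x and threadIdx.y, and kernels of this shape typically let the y-dimension pick a row while the x-dimension strides across its columns. A toy kernel with the same layout (illustrative, not the patch's LookupTable):

#include <cuda_runtime.h>

// blockDim = (128, 8): threadIdx.y selects a row slice, threadIdx.x strides
// across the columns of that row.
__global__ void AddBias(float* out, const float* bias, int rows, int cols) {
  for (int r = blockIdx.x * blockDim.y + threadIdx.y; r < rows;
       r += gridDim.x * blockDim.y) {
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
      out[r * cols + c] += bias[c];
    }
  }
}

void LaunchAddBias(float* out, const float* bias, int rows, int cols,
                   cudaStream_t stream) {
  dim3 threads(128, 8);  // 128 lanes per row slice, 8 rows per block
  dim3 grids(8, 1);
  AddBias<<<grids, threads, 0, stream>>>(out, bias, rows, cols);
}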
+ """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 2 + class_num = 517 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label /= label.sum(axis=1, keepdims=True) + cross_entropy = (-label * np.log(X)).sum( + axis=1, keepdims=True).astype("float32") + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": True} def test_check_output(self): self.check_output()