diff --git a/paddle/fluid/operators/label_smooth_op.cc b/paddle/fluid/operators/label_smooth_op.cc
index 6d0af573184b10a783f9c5802d1db3630eb55538..588582266c49d7c15c71fe38ddb591e36a606496 100644
--- a/paddle/fluid/operators/label_smooth_op.cc
+++ b/paddle/fluid/operators/label_smooth_op.cc
@@ -37,7 +37,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
       auto noise_dims = ctx->GetInputDim("PriorDist");
       auto noise_numel = paddle::framework::product(noise_dims);
       PADDLE_ENFORCE(
-          in_dims[1] == noise_numel,
+          in_dims[in_dims.size() - 1] == noise_numel,
           "The number of elements in Input(PriorDist) must be equal to the "
           "dimension of each label.");
     }
diff --git a/paddle/fluid/operators/label_smooth_op.cu b/paddle/fluid/operators/label_smooth_op.cu
index 89f1d28e9988281c77e0cefa349bd7181b432c20..33ae35a81f848129223815c5c08f5de82b329f92 100644
--- a/paddle/fluid/operators/label_smooth_op.cu
+++ b/paddle/fluid/operators/label_smooth_op.cu
@@ -34,7 +34,7 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
                                          const T* dist_data, T* dst) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < N; idx += blockDim.x * gridDim.x) {
-    int dist_idx = idx - (idx / dist_numel) * dist_numel;
+    int dist_idx = idx % dist_numel;
     dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
                static_cast<T>(epsilon) * dist_data[dist_idx];
   }
@@ -56,7 +56,7 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
     auto* out_t = ctx.Output<framework::LoDTensor>("Out");
     auto* in_t = ctx.Input<framework::LoDTensor>("X");
     auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
-    auto label_dim = in_t->dims()[1];
+    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
     auto epsilon = ctx.Attr<float>("epsilon");
     auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
     auto size_prob = in_t->numel();
diff --git a/paddle/fluid/operators/label_smooth_op.h b/paddle/fluid/operators/label_smooth_op.h
index f3da17de011053fa118b5a4257bb5c3b00084741..760d542505ec138eca032b82a7e7902db9c0887a 100644
--- a/paddle/fluid/operators/label_smooth_op.h
+++ b/paddle/fluid/operators/label_smooth_op.h
@@ -27,7 +27,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
     auto* out_t = ctx.Output<framework::LoDTensor>("Out");
     auto* in_t = ctx.Input<framework::LoDTensor>("X");
     auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
-    auto label_dim = in_t->dims()[1];
+    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
     out_t->mutable_data<T>(ctx.GetPlace());
 
     auto epsilon = ctx.Attr<float>("epsilon");
@@ -39,7 +39,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
       out.device(dev) =
           static_cast<T>(1 - epsilon) * in +
           static_cast<T>(epsilon) *
-              dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel()));
+              dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
     } else {
       out.device(dev) = static_cast<T>(1 - epsilon) * in +
                         static_cast<T>(epsilon / label_dim);
diff --git a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
index 62d385bc52cfb3a9fe15a82096ff33abc1bcc552..222e1321fecf8a383ebdba66ff17e760bab6ae51 100644
--- a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
+++ b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
@@ -23,7 +23,7 @@ class TestLabelSmoothOp(OpTest):
     def config(self):
         self.op_type = "label_smooth"
         self.epsilon = 0.1
-        batch_size, self.label_dim = 5, 10
+        batch_size, self.label_dim = 10, 12
         self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
         nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
         self.label[np.arange(batch_size), nonzero_index] = 1
@@ -53,5 +53,23 @@ class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
         self.outputs = {'Out': smoothed_label}
 
 
+class TestLabelSmoothOp3D(TestLabelSmoothOp):
+    def setUp(self):
+        super(TestLabelSmoothOp3D, self).setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]])
+        self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
+                                                          .shape)
+
+
+class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist):
+    def setUp(self):
+        super(TestLabelSmoothOpWithPriorDist3D, self).setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]])
+        self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
+                                                          .shape)
+
+
 if __name__ == '__main__':
     unittest.main()