diff --git a/paddle/fluid/operators/histogram_op.cu b/paddle/fluid/operators/histogram_op.cu index 3de24ead0de36245f96af4bb7b6c72209b37f885..5f86f8d72c079dd554482685403a74d14934336e 100644 --- a/paddle/fluid/operators/histogram_op.cu +++ b/paddle/fluid/operators/histogram_op.cu @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU - #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/histogram_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -32,28 +30,38 @@ inline int GET_BLOCKS(const int N) { } template -__device__ static IndexType GetBin(T bVal, T minvalue, T maxvalue, +__device__ static IndexType GetBin(T input_value, T min_value, T max_value, int64_t nbins) { - IndexType bin = - static_cast((bVal - minvalue) * nbins / (maxvalue - minvalue)); - if (bin == nbins) bin -= 1; - return bin; + IndexType bin = static_cast((input_value - min_value) * nbins / + (max_value - min_value)); + IndexType output_index = bin < nbins - 1 ? bin : nbins - 1; + return output_index; } template -__global__ void KernelHistogram(const T* input, const int totalElements, - const int64_t nbins, const T minvalue, - const T maxvalue, int64_t* output) { - CUDA_KERNEL_LOOP(linearIndex, totalElements) { - const IndexType inputIdx = threadIdx.x + blockIdx.x * blockDim.x; - const auto inputVal = input[inputIdx]; - if (inputVal >= minvalue && inputVal <= maxvalue) { - const IndexType bin = - GetBin(inputVal, minvalue, maxvalue, nbins); - const IndexType outputIdx = bin < nbins - 1 ? bin : nbins - 1; - paddle::platform::CudaAtomicAdd(&output[outputIdx], 1); +__global__ void KernelHistogram(const T* input, const int total_elements, + const int64_t nbins, const T min_value, + const T max_value, int64_t* output) { + extern __shared__ int64_t buf_hist[]; + for (int i = threadIdx.x; i < nbins; i += blockDim.x) { + buf_hist[i] = 0; + } + __syncthreads(); + + CUDA_KERNEL_LOOP(input_index, total_elements) { + // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x; + const auto input_value = input[input_index]; + if (input_value >= min_value && input_value <= max_value) { + const IndexType output_index = + GetBin(input_value, min_value, max_value, nbins); + paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1); } } + __syncthreads(); + + for (int i = threadIdx.x; i < nbins; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]); + } } template @@ -125,8 +133,9 @@ class HistogramCUDAKernel : public framework::OpKernel { auto stream = context.template device_context().stream(); - KernelHistogram<<>>( + KernelHistogram< + T, IndexType><<>>( input_data, input_numel, nbins, output_min, output_max, out_data); } }; diff --git a/python/paddle/fluid/tests/unittests/test_histogram_op.py b/python/paddle/fluid/tests/unittests/test_histogram_op.py index 0f880f2b03563c975790866a854ba5973f3730af..0ccb6fce8e4eda39f88a54141127d5bb1234ff51 100644 --- a/python/paddle/fluid/tests/unittests/test_histogram_op.py +++ b/python/paddle/fluid/tests/unittests/test_histogram_op.py @@ -58,12 +58,66 @@ class TestHistogramOpAPI(unittest.TestCase): msg='histogram output is wrong, out =' + str(actual.numpy())) +class TestHistogramOpError(unittest.TestCase): + """Test histogram op error.""" + + def run_network(self, net_func): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + net_func() + exe = fluid.Executor() + exe.run(main_program) + + def test_bins_error(self): + """Test bins should be greater than or equal to 1.""" + + def net_func(): + input_value = paddle.fill_constant( + shape=[3, 4], dtype='float32', value=3.0) + paddle.histogram(input=input_value, bins=-1, min=1, max=5) + + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_min_max_error(self): + """Test max must be larger or equal to min.""" + + def net_func(): + input_value = paddle.fill_constant( + shape=[3, 4], dtype='float32', value=3.0) + paddle.histogram(input=input_value, bins=1, min=5, max=1) + + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_min_max_range_error(self): + """Test range of min, max is not finite""" + + def net_func(): + input_value = paddle.fill_constant( + shape=[3, 4], dtype='float32', value=3.0) + paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5) + + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_type_errors(self): + with program_guard(Program()): + # The input type must be Variable. + self.assertRaises( + TypeError, paddle.histogram, 1, bins=5, min=1, max=5) + # The input type must be 'int32', 'int64', 'float32', 'float64' + x_bool = fluid.data(name='x_bool', shape=[4, 3], dtype='bool') + self.assertRaises( + TypeError, paddle.histogram, x_bool, bins=5, min=1, max=5) + + class TestHistogramOp(OpTest): def setUp(self): self.op_type = "histogram" self.init_test_case() - np_input = np.random.randint( - low=0, high=20, size=self.in_shape, dtype=np.int64) + np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape) self.inputs = {"X": np_input} self.init_attrs() Out, _ = np.histogram( diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index db0233e4423b3988895a57eae00a4c07217a246b..c41c9226d16b41934f738719ae1251127d439ccf 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -862,41 +862,23 @@ def histogram(input, bins=100, min=0, max=0): If min and max are both zero, the minimum and maximum values of the data are used. Args: - input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor + input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor should be float32, float64, int32, int64. bins (int): number of histogram bins min (int): lower end of the range (inclusive) max (int): upper end of the range (inclusive) Returns: - Variable: Tensor or LoDTensor calculated by histogram layer. The data type is int64. + Tensor: data type is int64, shape is (nbins,). - Code Example 1: - .. code-block:: python - import paddle - import numpy as np - startup_program = paddle.static.Program() - train_program = paddle.static.Program() - with paddle.static.program_guard(train_program, startup_program): - inputs = paddle.data(name='input', dtype='int32', shape=[2,3]) - output = paddle.histogram(inputs, bins=5, min=1, max=5) - place = paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(startup_program) - img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int32) - res = exe.run(train_program, - feed={'input': img}, - fetch_list=[output]) - print(np.array(res[0])) # [0,3,0,2,1] - - Code Example 2: + Examples: .. code-block:: python + import paddle - paddle.disable_static(paddle.CPUPlace()) + inputs = paddle.to_tensor([1, 2, 1]) result = paddle.histogram(inputs, bins=4, min=0, max=3) print(result) # [0, 2, 1, 0] - paddle.enable_static() """ if in_dygraph_mode(): return core.ops.histogram(input, "bins", bins, "min", min, "max", max)