Unverified commit f373269d, authored by Qi Li, committed by GitHub

update histogram op for performance optimization, test=develop (#24912)

Parent 4d5ddbf1
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#define EIGEN_USE_GPU
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/histogram_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
......@@ -32,28 +30,38 @@ inline int GET_BLOCKS(const int N) {
 }
 
 template <typename T, typename IndexType>
-__device__ static IndexType GetBin(T bVal, T minvalue, T maxvalue,
+__device__ static IndexType GetBin(T input_value, T min_value, T max_value,
                                    int64_t nbins) {
-  IndexType bin =
-      static_cast<int>((bVal - minvalue) * nbins / (maxvalue - minvalue));
-  if (bin == nbins) bin -= 1;
-  return bin;
+  IndexType bin = static_cast<int>((input_value - min_value) * nbins /
+                                   (max_value - min_value));
+  IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
+  return output_index;
 }
 
 template <typename T, typename IndexType>
-__global__ void KernelHistogram(const T* input, const int totalElements,
-                                const int64_t nbins, const T minvalue,
-                                const T maxvalue, int64_t* output) {
-  CUDA_KERNEL_LOOP(linearIndex, totalElements) {
-    const IndexType inputIdx = threadIdx.x + blockIdx.x * blockDim.x;
-    const auto inputVal = input[inputIdx];
-    if (inputVal >= minvalue && inputVal <= maxvalue) {
-      const IndexType bin =
-          GetBin<T, IndexType>(inputVal, minvalue, maxvalue, nbins);
-      const IndexType outputIdx = bin < nbins - 1 ? bin : nbins - 1;
-      paddle::platform::CudaAtomicAdd(&output[outputIdx], 1);
+__global__ void KernelHistogram(const T* input, const int total_elements,
+                                const int64_t nbins, const T min_value,
+                                const T max_value, int64_t* output) {
+  extern __shared__ int64_t buf_hist[];
+  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+    buf_hist[i] = 0;
+  }
+  __syncthreads();
+
+  CUDA_KERNEL_LOOP(input_index, total_elements) {
+    // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto input_value = input[input_index];
+    if (input_value >= min_value && input_value <= max_value) {
+      const IndexType output_index =
+          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
+      paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
     }
   }
+
+  __syncthreads();
+  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+    paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
+  }
 }
 
 template <typename DeviceContext, typename T>
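What the rewrite above does: instead of every thread issuing an atomic add directly into the global `output` array (the old code), each thread block now builds its own histogram in dynamically sized shared memory (`buf_hist`), zeroed cooperatively at the start, filled with cheap shared-memory atomics, and merged into global memory with one atomic add per bin per block after a `__syncthreads()`. This sharply reduces contention on global-memory atomics, which is where the performance optimization comes from. The following is a minimal, standalone CUDA sketch of the same pattern, not the PaddlePaddle code itself: `ComputeBin` and `SharedMemHistogram` are illustrative names, and plain `atomicAdd` on `unsigned long long` stands in for Paddle's `CudaAtomicAdd` wrapper on `int64_t`.

// histogram_shared.cu -- illustrative only; build with: nvcc histogram_shared.cu
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Map a value to a bin index; clamp so that x == max_value lands in the last bin.
__device__ int ComputeBin(float x, float min_value, float max_value, int nbins) {
  int bin = static_cast<int>((x - min_value) * nbins / (max_value - min_value));
  return bin < nbins - 1 ? bin : nbins - 1;
}

// Same structure as the patched KernelHistogram: per-block shared-memory
// histogram, then one global atomic add per bin per block.
__global__ void SharedMemHistogram(const float* input, int n, int nbins,
                                   float min_value, float max_value,
                                   unsigned long long* output) {
  extern __shared__ unsigned long long buf_hist[];
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) buf_hist[i] = 0;
  __syncthreads();

  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
       idx += blockDim.x * gridDim.x) {
    float v = input[idx];
    if (v >= min_value && v <= max_value) {
      atomicAdd(&buf_hist[ComputeBin(v, min_value, max_value, nbins)], 1ULL);
    }
  }
  __syncthreads();

  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
    atomicAdd(&output[i], buf_hist[i]);
  }
}

int main() {
  const int n = 1 << 20, nbins = 8;
  std::vector<float> host_input(n);
  for (int i = 0; i < n; ++i) host_input[i] = static_cast<float>(i % 10);

  float* d_input = nullptr;
  unsigned long long* d_output = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d_input), n * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&d_output), nbins * sizeof(unsigned long long));
  cudaMemcpy(d_input, host_input.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_output, 0, nbins * sizeof(unsigned long long));

  // Third launch argument: bytes of dynamic shared memory backing buf_hist,
  // mirroring `nbins * sizeof(int64_t)` in the patched launch.
  SharedMemHistogram<<<256, 256, nbins * sizeof(unsigned long long)>>>(
      d_input, n, nbins, 0.f, 9.f, d_output);
  cudaDeviceSynchronize();

  std::vector<unsigned long long> result(nbins);
  cudaMemcpy(result.data(), d_output, nbins * sizeof(unsigned long long),
             cudaMemcpyDeviceToHost);
  for (int i = 0; i < nbins; ++i) printf("bin %d: %llu\n", i, result[i]);

  cudaFree(d_input);
  cudaFree(d_output);
  return 0;
}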
......@@ -125,8 +133,9 @@ class HistogramCUDAKernel : public framework::OpKernel<T> {
     auto stream =
         context.template device_context<platform::CUDADeviceContext>().stream();
-    KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
-                                    PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+    KernelHistogram<
+        T, IndexType><<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS,
+                        nbins * sizeof(int64_t), stream>>>(
         input_data, input_numel, nbins, output_min, output_max, out_data);
   }
 };
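The launch change above is small but essential: the third `<<< >>>` argument goes from `0` to `nbins * sizeof(int64_t)`, so the kernel now requests enough dynamic shared memory per block to back `extern __shared__ int64_t buf_hist[]`. One design note (an observation, not something enforced by this patch): dynamic shared memory is capped per block (48 KB by default on most GPUs), so a sufficiently large `bins` value would exceed the allocation. A hypothetical guard might look like the sketch below; it is not part of the commit and uses `PADDLE_ENFORCE_LE` purely as an illustration.

// Hypothetical pre-launch check, not in the patch: make sure the per-block
// shared-memory histogram fits into the default 48 KB limit.
const size_t shared_mem_bytes = nbins * sizeof(int64_t);
PADDLE_ENFORCE_LE(shared_mem_bytes, 48 * 1024,
                  platform::errors::InvalidArgument(
                      "bins is too large for the shared-memory histogram kernel."));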
......
......@@ -58,12 +58,66 @@ class TestHistogramOpAPI(unittest.TestCase):
                 msg='histogram output is wrong, out =' + str(actual.numpy()))
 
 
+class TestHistogramOpError(unittest.TestCase):
+    """Test histogram op error."""
+
+    def run_network(self, net_func):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            net_func()
+            exe = fluid.Executor()
+            exe.run(main_program)
+
+    def test_bins_error(self):
+        """Test bins should be greater than or equal to 1."""
+
+        def net_func():
+            input_value = paddle.fill_constant(
+                shape=[3, 4], dtype='float32', value=3.0)
+            paddle.histogram(input=input_value, bins=-1, min=1, max=5)
+
+        with self.assertRaises(fluid.core.EnforceNotMet):
+            self.run_network(net_func)
+
+    def test_min_max_error(self):
+        """Test max must be larger or equal to min."""
+
+        def net_func():
+            input_value = paddle.fill_constant(
+                shape=[3, 4], dtype='float32', value=3.0)
+            paddle.histogram(input=input_value, bins=1, min=5, max=1)
+
+        with self.assertRaises(fluid.core.EnforceNotMet):
+            self.run_network(net_func)
+
+    def test_min_max_range_error(self):
+        """Test range of min, max is not finite"""
+
+        def net_func():
+            input_value = paddle.fill_constant(
+                shape=[3, 4], dtype='float32', value=3.0)
+            paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5)
+
+        with self.assertRaises(fluid.core.EnforceNotMet):
+            self.run_network(net_func)
+
+    def test_type_errors(self):
+        with program_guard(Program()):
+            # The input type must be Variable.
+            self.assertRaises(
+                TypeError, paddle.histogram, 1, bins=5, min=1, max=5)
+            # The input type must be 'int32', 'int64', 'float32', 'float64'
+            x_bool = fluid.data(name='x_bool', shape=[4, 3], dtype='bool')
+            self.assertRaises(
+                TypeError, paddle.histogram, x_bool, bins=5, min=1, max=5)
+
+
 class TestHistogramOp(OpTest):
     def setUp(self):
         self.op_type = "histogram"
         self.init_test_case()
-        np_input = np.random.randint(
-            low=0, high=20, size=self.in_shape, dtype=np.int64)
+        np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape)
         self.inputs = {"X": np_input}
         self.init_attrs()
         Out, _ = np.histogram(
......
......
......@@ -862,41 +862,23 @@ def histogram(input, bins=100, min=0, max=0):
     If min and max are both zero, the minimum and maximum values of the data are used.
 
     Args:
-        input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
+        input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
             should be float32, float64, int32, int64.
         bins (int): number of histogram bins
         min (int): lower end of the range (inclusive)
         max (int): upper end of the range (inclusive)
 
     Returns:
-        Variable: Tensor or LoDTensor calculated by histogram layer. The data type is int64.
+        Tensor: data type is int64, shape is (nbins,).
 
-    Code Example 1:
-        .. code-block:: python
-
-            import paddle
-            import numpy as np
-            startup_program = paddle.static.Program()
-            train_program = paddle.static.Program()
-            with paddle.static.program_guard(train_program, startup_program):
-                inputs = paddle.data(name='input', dtype='int32', shape=[2,3])
-                output = paddle.histogram(inputs, bins=5, min=1, max=5)
-                place = paddle.CPUPlace()
-                exe = paddle.static.Executor(place)
-                exe.run(startup_program)
-                img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int32)
-                res = exe.run(train_program,
-                              feed={'input': img},
-                              fetch_list=[output])
-                print(np.array(res[0])) # [0,3,0,2,1]
-
-    Code Example 2:
+    Examples:
         .. code-block:: python
 
             import paddle
 
             paddle.disable_static(paddle.CPUPlace())
             inputs = paddle.to_tensor([1, 2, 1])
             result = paddle.histogram(inputs, bins=4, min=0, max=3)
             print(result) # [0, 2, 1, 0]
             paddle.enable_static()
     """
     if in_dygraph_mode():
         return core.ops.histogram(input, "bins", bins, "min", min, "max", max)
......
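As a closing note on the semantics (a restatement of the GetBin logic shown above, not text from the commit): a value x with min <= x <= max is counted in bin

\mathrm{bin}(x) = \min\!\left(\left\lfloor \frac{(x - \mathrm{min}) \cdot \mathrm{bins}}{\mathrm{max} - \mathrm{min}} \right\rfloor,\ \mathrm{bins} - 1\right)

so every bin is half-open on the right except the last, which also includes x = max, and values outside [min, max] are ignored entirely. For the documented example (bins=4, min=0, max=3, input [1, 2, 1]), the values fall into bins 1, 2 and 1, giving the output [0, 2, 1, 0].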