diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index e076444626e6afa7893b539481804964750e48bd..2c7fd8f4173ea72f0b40fe0d7620f168554fd33f 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel { "The Attr(groups) of Op(group_norm) must be " "greater than or equal to 1. But received: groups is [%s].", groups)); + PADDLE_ENFORCE_EQ( + channel_num % groups, 0, + platform::errors::InvalidArgument( + "Expected number of channels in input to be divisible by " + "num_groups, but got input channel is %d and num_groups is %d", + channel_num, groups)); if (ctx->HasInput("Scale")) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index f199bfeb9443b668b5833d9a011e36577b66bb4a..e029c84090af19e186e30be63f28b01270ef94c5 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -144,7 +144,8 @@ class GroupNormKernel const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; + const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2]); @@ -314,7 +315,7 @@ class GroupNormGradKernel const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; const int W = (data_layout == DataLayout::kNCHW ? 
x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2]); diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h index f2388699e266f52fd1b06612ee4f78fb4ec88b21..9cb451235f152cc855e4b47388b9ce13e7ff8911 100644 --- a/paddle/fluid/operators/group_norm_op.h +++ b/paddle/fluid/operators/group_norm_op.h @@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel { const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; y->mutable_data(ctx.GetPlace()); mean->mutable_data(ctx.GetPlace()); @@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel { int imid; for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M) { - // TODO(gaoxiang) :Because AVX/AVX2/AVX512 can not directly used + // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used // in template class/function, before we complete high // performance cpu vector extension, temporarily unrolling // loop to get high precision and performance @@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel { int imid; for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M * C) { - // TODO(gaoxiang) :Because AVX/AVX2/AVX512 can not directly used + // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used // in template class/function, before we complete high // performance cpu vector extension, temporarily unrolling // loop to get high precision and performance @@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel { const int C = (data_layout == DataLayout::kNCHW ? 
x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; d_x->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; diff --git a/paddle/fluid/operators/index_select_op_npu.cc b/paddle/fluid/operators/index_select_op_npu.cc index 8df6c4e5d9ea7203dee3958545c55a33899ae231..b624d03cc8555938ee6f527b890c0575d59799e3 100644 --- a/paddle/fluid/operators/index_select_op_npu.cc +++ b/paddle/fluid/operators/index_select_op_npu.cc @@ -21,12 +21,12 @@ namespace operators { template class IndexSelectNPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *index = ctx.Input("Index"); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* index = ctx.Input("Index"); auto dim = ctx.Attr("dim"); - auto *out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto stream = @@ -43,7 +43,104 @@ class IndexSelectNPUKernel : public framework::OpKernel { } }; -// todo: add class 'IndexSelectGradNPUKernel' here. 
+template +class IndexSelectGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_grad = ctx.Output(framework::GradVarName("X")); + auto* index = ctx.Input("Index"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + + auto stream = + ctx.template device_context() + .stream(); + + auto x_dims = x_grad->dims(); + auto out_dims = out_grad->dims(); + + int dim = ctx.Attr("dim"); + if (dim < 0) { + dim += out_dims.size(); + } + + Tensor casted_index; + if (index->type() != framework::proto::VarType::INT32) { + casted_index.mutable_data(index->dims(), ctx.GetPlace()); + const auto& cast_runner = NpuOpRunner("Cast", {*index}, {casted_index}, + {{"dst_type", ACL_INT32}}); + cast_runner.Run(stream); + } else { + casted_index.ShareDataWith(*index); + } + + if (dim == 0) { + x_grad->mutable_data(ctx.GetPlace()); + const auto& zeros_runner = NpuOpRunner("ZerosLike", {*x_grad}, {*x_grad}); + zeros_runner.Run(stream); + + NpuOpRunner runner; + runner.SetType("UnsortedSegmentSum") + .AddInput(*out_grad) + .AddInput(casted_index) + .AddInput(std::vector{x_dims[dim]}) + .AddOutput(*x_grad); + runner.Run(stream); + } else { + Tensor transed_out_grad; + std::vector in_trans_perm; + in_trans_perm.push_back(dim); + for (int i = 0; i < out_dims.size(); ++i) { + if (i == dim) continue; + in_trans_perm.push_back(i); + } + framework::DDim transed_out_dims(out_dims); + for (size_t i = 0; i < in_trans_perm.size(); ++i) { + transed_out_dims[i] = out_dims[in_trans_perm[i]]; + } + transed_out_grad.mutable_data(transed_out_dims, ctx.GetPlace()); + framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}}; + + const auto& in_trans_runner = NpuOpRunner( + "TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr); + in_trans_runner.Run(stream); + + Tensor sum_out; + framework::DDim sum_dims(x_dims); + sum_dims[0] = x_dims[dim]; + auto idx = 1; + for (int i = 0; i < x_dims.size(); ++i) { 
+ if (i == dim) continue; + sum_dims[idx++] = x_dims[i]; + } + sum_out.mutable_data(sum_dims, ctx.GetPlace()); + const auto& zeros_runner = NpuOpRunner("ZerosLike", {sum_out}, {sum_out}); + zeros_runner.Run(stream); + + NpuOpRunner runner; + runner.SetType("UnsortedSegmentSum") + .AddInput(transed_out_grad) + .AddInput(casted_index) + .AddInput(std::vector{x_dims[dim]}) + .AddOutput(sum_out); + runner.Run(stream); + + std::vector out_trans_perm; + for (int i = 1; i < 1 + dim; ++i) { + out_trans_perm.push_back(i); + } + out_trans_perm.push_back(0); + for (int i = 1 + dim; i < x_dims.size(); ++i) { + out_trans_perm.push_back(i); + } + framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}}; + x_grad->mutable_data(ctx.GetPlace()); + const auto& out_trans_runner = + NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr); + out_trans_runner.Run(stream); + } + } +}; } // namespace operators } // namespace paddle @@ -54,4 +151,8 @@ REGISTER_OP_NPU_KERNEL( ops::IndexSelectNPUKernel, ops::IndexSelectNPUKernel, ops::IndexSelectNPUKernel); -// todo: register npu index_select_grad kernel here. 
+REGISTER_OP_NPU_KERNEL( + index_select_grad, + ops::IndexSelectGradNPUKernel, + ops::IndexSelectGradNPUKernel, + ops::IndexSelectGradNPUKernel); diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index cc98d378f148949e4c443c4494e3936ed6a34e09..70c7b01b05ba3bd71c69aaf8c37ae9d5830c8fd7 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -43,6 +43,36 @@ from .flat import _flatten_batch, _restore_batch __all__ = ['get_worker_info'] +# NOTE: fix `terminate called without an active exception` +# if for loop break and program exit immediately(with no model +# layers processing) after iterating **the first few data** in +# distributed launch mode, distributed launch will call +# terminate() to kill main process on each device, but thread +# is still iterating to fulfill blocking queue caches, which +# may cause thread error `terminate called without an active +# exception` for terminate is a strong signal and `__del__` +# of DataLoader may not be called, so we add a global link to +# the last DataLoader instance to call `__del__` to clean up +# resources +# NOTE: cannot simply register `__del__` with CleanupFuncRegistrar, +# for this will retain a link to each DataLoader instance in +# global, and will preclude GC from auto collecting DataLoader +# instance and will cause memory leak +_loader = None + + +def _clear_loader(): + global _loader + if _loader is not None: + try: + _loader.__del__() + del _loader + except: + pass + + +CleanupFuncRegistrar.register(_clear_loader) + class _DataLoaderIterBase(object): """ @@ -100,6 +130,16 @@ class _DataLoaderIterBase(object): def __len__(self): return len(self._batch_sampler) + def _exit_thread_expectedly(self): + self._thread_done_event.set() + if self._blocking_queue: + self._blocking_queue.close() + + def _exit_thread_unexpectedly(self): + self._thread_done_event.set() + if self._blocking_queue: +
self._blocking_queue.kill() + class _DataLoaderIterSingleProcess(_DataLoaderIterBase): """ @@ -125,9 +165,13 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): # NOTE: len(self._places) batch data compose as an output # iteration, set blocking_queue can cache 2 iteration datas # at most here - self._blocking_queue_capacity = 2 * len(self._places) + self._blocking_queue_capacity = 1 * len(self._places) self._init_thread() + self._shutdown = False + + global _loader + _loader = self def _init_thread(self): self._var_names = [v.name for v in self._feed_list] @@ -151,22 +195,35 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): self._thread.start() def _thread_loop(self, legacy_expected_place): - try: - #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, - # and it will call platform::SetDeviceId() in c++ internally. - # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0, - # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda - # APIs in this thread. - _set_expected_place(legacy_expected_place) - - for indices in self._sampler_iter: + #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, + # and it will call platform::SetDeviceId() in c++ internally. + # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0, + # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda + # APIs in this thread. 
+ _set_expected_place(legacy_expected_place) + + while not self._thread_done_event.is_set(): + try: + indices = next(self._sampler_iter) + # read data from dataset in mini-batch - batch = self._dataset_fetcher.fetch(indices) + # with paddle.fluid.dygraph.guard(place=paddle.CPUPlace()): + # read data from dataset in mini-batch + batch = self._dataset_fetcher.fetch(indices, + self._thread_done_event) + except StopIteration: + self._exit_thread_expectedly() + return + + if batch is None or self._thread_done_event.is_set(): break + + # flat batch and record structure infos + batch, structure = _flatten_batch(batch) + self._structure_infos.append(structure) - # flat batch and record structure infos - batch, structure = _flatten_batch(batch) - self._structure_infos.append(structure) + if self._thread_done_event.is_set(): break + try: # pack as LoDTensorArray array = core.LoDTensorArray() for slot in batch: @@ -179,21 +236,18 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): array.append(slot) - if not self._blocking_queue.push(array): - break + if self._thread_done_event.is_set(): break - if self._thread_done_event.is_set(): - break + try: + self._blocking_queue.push(array) + except: + self._exit_thread_expectedly() - self._blocking_queue.close() - self._shutdown_thread() - except StopIteration: - self._blocking_queue.close() - except Exception: - self._blocking_queue.kill() - self._shutdown_thread() - logging.warning("DataLoader reader thread raised an exception.") - six.reraise(*sys.exc_info()) + except: + self._exit_thread_unexpectedly() + six.reraise(*sys.exc_info()) + + self._exit_thread_expectedly() def __next__(self): try: @@ -221,28 +275,46 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): return data except StopIteration: self._reader.shutdown() + self._try_shutdown_all() six.reraise(*sys.exc_info()) def _shutdown_thread(self): if self._thread: self._thread_done_event.set() - if self._thread is not threading.current_thread(): - 
self._thread.join() + # NOTE: we wait for _thread exit for 3 seconds, if + # thread not exit normally, force kill it + for _ in range(3): + if self._thread.is_alive(): + time.sleep(1) + else: + break + else: + if self._thread is not threading.current_thread(): + self._thread.join() + self._thread = None # python2 compatibility def next(self): return self.__next__() + def _try_shutdown_all(self): + if not self._shutdown: + try: + # # _blocking_queue in keep order mode holds sub-threads + # # need to release thread resources on unexpected exit + if self._blocking_queue: + self._blocking_queue.close() + self._blocking_queue = None + # NOTE: blocking queue should be closed firstly for + # blocking queue read may hang and _thread_done_event + # cannot be checked + self._shutdown_thread() + finally: + self._shutdown = True + def __del__(self): - # _blocking_queue in keep order mode holds sub-threads - # need to release thread resources on unexpected exit - if self._blocking_queue: - self._blocking_queue.close() - # NOTE: blocking queue should be closed firstly for - # blocking queue read may hang and _thread_done_event - # cannot be checked - self._shutdown_thread() + self._try_shutdown_all() class _DataLoaderIterMultiProcess(_DataLoaderIterBase): @@ -421,15 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): core._erase_process_pids(id(self)) self._shutdown = True - def _exit_thread_expectedly(self): - self._thread_done_event.set() - self._blocking_queue.close() - - def _exit_thread_unexpectedly(self): - self._thread_done_event.set() - self._blocking_queue.kill() - logging.error("DataLoader reader thread raised an exception!") - def _thread_loop(self, legacy_expected_place): #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, # and it will call platform::SetDeviceId() in c++ internally. 
diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 8ccec81810a0a60d75b2546bd7cad4ede226855b..ec3240a326b8eddfe04be5e8f1e8d785265b7690 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -26,7 +26,16 @@ class _DatasetFetcher(object): self.collate_fn = collate_fn self.drop_last = drop_last - def fetch(self, batch_indices): + # NOTE: fetch function here performs the whole pipeline of dataset + # reading and data transforms of a batch in each call, this + # may take a long time inside, if DataLoader has exited outside, + # fetch needs to perceive the exit situation, so we pass done_event + # here for fetch to check exit status + # NOTE: if DataLoader exits by `break`, performing GPU tensor operations, + # e.g. to_tensor may cause SIGSEGV in thread, so we pass the + # done_event argument to check DataLoader exit status between + # each sample processing in the batch + def fetch(self, batch_indices, done_event=None): raise NotImplementedError("'fetch' not implement for class {}".format( self.__class__.__name__)) @@ -69,15 +78,18 @@ class _IterableDatasetFetcher(_DatasetFetcher): dataset, auto_collate_batch, collate_fn, drop_last) self.dataset_iter = iter(dataset) - def fetch(self, batch_indices): + def fetch(self, batch_indices, done_event=None): if self.auto_collate_batch: data = [] for _ in batch_indices: - try: - data.append(next(self.dataset_iter)) - except StopIteration: - break + if done_event is None or not done_event.is_set(): + try: + data.append(next(self.dataset_iter)) + except StopIteration: + break + else: + return None if len(data) == 0 or (self.drop_last and len(data) < len(batch_indices)): @@ -101,9 +113,14 @@ class _MapDatasetFetcher(_DatasetFetcher): super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last) - def fetch(self, batch_indices): + def fetch(self, batch_indices, done_event=None): if self.auto_collate_batch: -
data = [self.dataset[idx] for idx in batch_indices] + data = [] + for idx in batch_indices: + if done_event is None or not done_event.is_set(): + data.append(self.dataset[idx]) + else: + return None global _WARNING_TO_LOG if not isinstance(data[0], (Sequence, Mapping)) \ diff --git a/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py index ff0d57d1d4da1028d0db28ee90f6a950ce33b9ea..57293ad5e56335aeb04949177b632e1e6763fefe 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py @@ -35,7 +35,10 @@ class TestNPUIndexSelect(OpTest): x_np = np.random.random(self.x_shape).astype(self.x_type) index_np = np.random.randint( - low=0, high=self.x_shape[self.dim], size=self.index_size) + low=0, + high=self.x_shape[self.dim], + size=self.index_size, + dtype=self.index_type) # compute real output as baseline. outer_loop = np.prod(self.x_shape[:self.dim]) @@ -56,18 +59,14 @@ class TestNPUIndexSelect(OpTest): self.attrs = {'dim': self.dim} self.outputs = {'Out': out} - # todo: comment second line when index_select grad npu op is ready. def set_npu(self): self.__class__.use_npu = True - self.__class__.no_need_check_grad = True def test_check_output(self): self.check_output_with_place(self.place) - # todo: replace first line with second line when index_select grad npu op is ready. 
def test_check_grad(self): - pass - #self.check_grad_with_place(self.place, ['X'], 'Out') + self.check_grad_with_place(self.place, ['X'], 'Out') def config(self): self.x_shape = (100, 4, 5) @@ -86,6 +85,24 @@ class TestNPUIndexSelectCase2(TestNPUIndexSelect): self.index_size = 10 +class TestNPUIndexSelectCase3(TestNPUIndexSelect): + def config(self): + self.dim = 0 + self.x_type = np.float32 + self.index_type = np.int32 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + +class TestNPUIndexSelectCase4(TestNPUIndexSelect): + def config(self): + self.dim = -1 + self.x_type = np.float32 + self.index_type = np.int32 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + class TestNPUIndexSelectAPI(unittest.TestCase): def input_data(self): self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], diff --git a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py index d2f4eadc9c56412adeec96c58d31bec5da3ab9ac..c54a1406e39bf3ffb97a050928630c3e0272fe7a 100644 --- a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py @@ -43,14 +43,18 @@ class TestDatasetAbstract(unittest.TestCase): class TestDatasetWithDiffOutputPlace(unittest.TestCase): def get_dataloader(self, num_workers): dataset = paddle.vision.datasets.MNIST( - mode='test', transform=transforms.ToTensor()) + mode='test', + transform=transforms.Compose([ + transforms.CenterCrop(20), transforms.RandomResizedCrop(14), + transforms.Normalize(), transforms.ToTensor() + ])) loader = paddle.io.DataLoader( dataset, batch_size=32, num_workers=num_workers, shuffle=True) return loader def run_check_on_cpu(self): paddle.set_device('cpu') - loader = self.get_dataloader(0) + loader = self.get_dataloader(1) for image, label in loader: self.assertTrue(image.place.is_cpu_place()) self.assertTrue(label.place.is_cpu_place()) @@ -66,12 +70,7 @@ class 
TestDatasetWithDiffOutputPlace(unittest.TestCase): for image, label in loader: self.assertTrue(image.place.is_gpu_place()) self.assertTrue(label.place.is_cuda_pinned_place()) - # FIXME(dkp): when input tensor is in GPU place and - # iteration break in the median, it seems the GPU - # tensor put into blocking_queue cannot be safely - # released and may cause ABRT/SEGV, this should - # be fixed - # break + break def test_multi_process(self): # DataLoader with multi-process mode is not supported on MacOs and Windows currently diff --git a/python/paddle/fluid/tests/unittests/test_segment_ops.py b/python/paddle/fluid/tests/unittests/test_segment_ops.py index b58d66676b05524766366d9587d395aadc32a7b4..e2aadbedbd07fdc77e9a6d1f6e2740a826032393 100644 --- a/python/paddle/fluid/tests/unittests/test_segment_ops.py +++ b/python/paddle/fluid/tests/unittests/test_segment_ops.py @@ -15,8 +15,11 @@ from __future__ import print_function import unittest -import numpy as np import sys + +import numpy as np +import paddle + from op_test import OpTest @@ -198,5 +201,62 @@ class TestSegmentMean2(TestSegmentMean): self.attrs = {'pooltype': "MEAN"} +class API_SegmentOpsTest(unittest.TestCase): + def test_static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[3], dtype='int32') + + res_sum = paddle.incubate.segment_sum(x, y) + res_mean = paddle.incubate.segment_mean(x, y) + res_max = paddle.incubate.segment_max(x, y) + res_min = paddle.incubate.segment_min(x, y) + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + data2 = np.array([0, 0, 1], dtype="int32") + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32") + np_min = np.array([[1, 2, 1], [4, 5, 6]], 
dtype="float32") + + ret = exe.run(feed={'x': data1, + 'y': data2}, + fetch_list=[res_sum, res_mean, res_max, res_min]) + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose( + np_res, ret_res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + def test_dygraph(self): + device = paddle.CPUPlace() + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor( + [[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + y = paddle.to_tensor([0, 0, 1], dtype="int32") + res_sum = paddle.incubate.segment_sum(x, y) + res_mean = paddle.incubate.segment_mean(x, y) + res_max = paddle.incubate.segment_max(x, y) + res_min = paddle.incubate.segment_min(x, y) + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose( + np_res, ret_res.numpy(), atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index efaeda272087fcce65cf9a4b174b491e7e60d097..644b934814020f9d781f771f19896126186e50cd 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -18,7 +18,18 @@ from .checkpoint import auto_checkpoint # noqa: F401 from ..fluid.layer_helper import LayerHelper # noqa: F401 from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 from .operators import softmax_mask_fuse # noqa: F401 +from .tensor import segment_sum +from .tensor import segment_mean +from .tensor import segment_max +from .tensor import segment_min -__all__ = [ # noqa - 'LookAhead', 'ModelAverage', 
'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse' +__all__ = [ + 'LookAhead', + 'ModelAverage', + 'softmax_mask_fuse_upper_triangle', + 'softmax_mask_fuse', + 'segment_sum', + 'segment_mean', + 'segment_max', + 'segment_min', ] diff --git a/python/paddle/incubate/tensor/__init__.py b/python/paddle/incubate/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1018409ab0fbfca789d7b7fb8bbf3f2202302c --- /dev/null +++ b/python/paddle/incubate/tensor/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .math import segment_sum +from .math import segment_mean +from .math import segment_max +from .math import segment_min + +__all__ = [ + 'segment_sum', + 'segment_mean', + 'segment_max', + 'segment_min', +] diff --git a/python/paddle/incubate/tensor/math.py b/python/paddle/incubate/tensor/math.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cb8d50514f013d38dc3e3d218b86139d564d6c --- /dev/null +++ b/python/paddle/incubate/tensor/math.py @@ -0,0 +1,225 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + 'segment_sum', + 'segment_mean', + 'segment_max', + 'segment_min', +] + +import paddle + +from paddle.fluid.layer_helper import LayerHelper, in_dygraph_mode +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle import _C_ops + + +def segment_sum(data, segment_ids, name=None): + """ + Segment Sum Operator. + + This operator sums the elements of input `data` which with + the same index in `segment_ids`. + It computes a tensor such that $out_i = \\sum_{j} data_{j}$ + where sum is over j such that `segment_ids[j] == i`. + + Args: + data (Tensor): A tensor, available data type float32, float64. + segment_ids (Tensor): A 1-D tensor, which have the same size + with the first dimension of input data. + Available data type is int32, int64. + Returns: + output (Tensor): the reduced result. + + Examples: + + .. 
code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.incubate.segment_sum(data, segment_ids) + #Outputs: [[4., 4., 4.], [4., 5., 6.]] + + """ + if in_dygraph_mode(): + out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM") + return out + + check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_sum", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op( + type="segment_pool", + inputs={"X": data, + "SegmentIds": segment_ids}, + outputs={"Out": out, + "SummedIds": summed_ids}, + attrs={"pooltype": "SUM"}) + return out + + +def segment_mean(data, segment_ids, name=None): + """ + Segment mean Operator. + + This operator calculates the mean value of input `data` which + with the same index in `segment_ids`. + It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$ + where sum is over j such that 'segment_ids[j] == i' and $n_i$ is the number + of all index 'segment_ids[j] == i'. + + Args: + data (tensor): a tensor, available data type float32, float64. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. + available data type is int32, int64. + + Returns: + output (Tensor): the reduced result. + + Examples: + + ..
code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.incubate.segment_mean(data, segment_ids) + #Outputs: [[2., 2., 2.], [4., 5., 6.]] + + """ + if in_dygraph_mode(): + out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN") + return out + + check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_mean", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op( + type="segment_pool", + inputs={"X": data, + "SegmentIds": segment_ids}, + outputs={"Out": out, + "SummedIds": summed_ids}, + attrs={"pooltype": "MEAN"}) + return out + + +def segment_min(data, segment_ids, name=None): + """ + Segment min operator. + + This operator calculate the minimum elements of input `data` which with + the same index in `segment_ids`. + It computes a tensor such that $out_i = \\min_{j} data_{j}$ + where min is over j such that `segment_ids[j] == i`. + + Args: + data (tensor): a tensor, available data type float32, float64. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. + available data type is int32, int64. + Returns: + output (Tensor): the reduced result. + + Examples: + + .. 
code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.incubate.segment_min(data, segment_ids) + #Outputs: [[1., 2., 1.], [4., 5., 6.]] + + """ + if in_dygraph_mode(): + out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN") + return out + + check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_min", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op( + type="segment_pool", + inputs={"X": data, + "SegmentIds": segment_ids}, + outputs={"Out": out, + "SummedIds": summed_ids}, + attrs={"pooltype": "MIN"}) + return out + + +def segment_max(data, segment_ids, name=None): + """ + Segment max operator. + + This operator calculates the maximum elements of input `data` which with + the same index in `segment_ids`. + It computes a tensor such that $out_i = \\max_{j} data_{j}$ + where max is over j such that `segment_ids[j] == i`. + + Args: + data (tensor): a tensor, available data type float32, float64. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. + available data type is int32, int64. + + Returns: + output (Tensor): the reduced result. + + Examples: + + ..
code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.incubate.segment_max(data, segment_ids) + #Outputs: [[3., 2., 3.], [4., 5., 6.]] + + """ + if in_dygraph_mode(): + out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") + return out + + check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_max", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op( + type="segment_pool", + inputs={"X": data, + "SegmentIds": segment_ids}, + outputs={"Out": out, + "SummedIds": summed_ids}, + attrs={"pooltype": "MAX"}) + return out diff --git a/python/setup.py.in b/python/setup.py.in index 6d3e6201dc772356396871283f028be76b3180e2..1b2897f230fbeebd071a9332c7a0f871b9238827 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -162,6 +162,7 @@ packages=['paddle', 'paddle.incubate.optimizer', 'paddle.incubate.checkpoint', 'paddle.incubate.operators', + 'paddle.incubate.tensor', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.elastic',