From f4c750d721a1226738bea382f6c0cf725cca8481 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Tue, 22 Sep 2020 10:28:42 +0800 Subject: [PATCH] Add the cpu version of segment sum mean max min op Add the cpu version of segment sum mean max min op --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/math/CMakeLists.txt | 1 + .../fluid/operators/math/segment_pooling.cc | 148 +++++++++++++ paddle/fluid/operators/math/segment_pooling.h | 46 ++++ paddle/fluid/operators/segment_pool_op.cc | 166 ++++++++++++++ paddle/fluid/operators/segment_pool_op.h | 130 +++++++++++ .../fluid/tests/unittests/test_segment_ops.py | 202 ++++++++++++++++++ 7 files changed, 694 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/math/segment_pooling.cc create mode 100644 paddle/fluid/operators/math/segment_pooling.h create mode 100644 paddle/fluid/operators/segment_pool_op.cc create mode 100644 paddle/fluid/operators/segment_pool_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_segment_ops.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index f0a04d850d..53e6f4aa6e 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -92,7 +92,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project -sequence_pooling executor device_memory_aligment generator) +sequence_pooling segment_pooling executor device_memory_aligment generator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 10d335b828..24ed4fcf66 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -76,6 +76,7 @@ math_library(prelu) math_library(bert_encoder_functor) math_library(tree2col DEPS math_function) math_library(matrix_inverse) +math_library(segment_pooling) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/math/segment_pooling.cc b/paddle/fluid/operators/math/segment_pooling.cc new file mode 100644 index 0000000000..3c77d3d4cf --- /dev/null +++ b/paddle/fluid/operators/math/segment_pooling.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/operators/math/segment_pooling.h"
+#include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, typename IndexT>
+class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& segments, framework::Tensor* output,
+                  framework::Tensor* index,
+                  const std::string pooltype = "SUM") {
+    const IndexT* segment_ids = segments.data<IndexT>();
+    auto current_id = segment_ids[0];
+    int64_t last_idx = 0;
+    int64_t w = input.numel() / input.dims()[0];
+    auto& place = *context.eigen_device();
+    for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
+      if (idx < segments.numel()) {
+        if (segment_ids[idx] == current_id) continue;
+        PADDLE_ENFORCE_GE(segment_ids[idx], current_id,
+                          platform::errors::InvalidArgument(
+                              "The segment ids should be sorted, but got "
+                              "segment_ids[%d]:%d > segment_ids[%d]:%d.",
+                              idx - 1, current_id, idx, segment_ids[idx]));
+      }
+
+      Tensor out_t = output->Slice(current_id, current_id + 1);
+      Tensor in_t = input.Slice(last_idx, idx);
+
+      int64_t h = idx - last_idx;
+      auto in_e =
+          framework::EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
+      auto out_e = framework::EigenVector<T>::Flatten(out_t);
+
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      if (pooltype == "MEAN") {
+        out_e.device(place) = in_e.mean(reduce_dim);
+      } else if (pooltype == "SUM") {
+        out_e.device(place) = in_e.sum(reduce_dim);
+      } else if (pooltype == "MAX") {
+        out_e.device(place) = in_e.maximum(reduce_dim);
+      } else if (pooltype == "MIN") {
+        out_e.device(place) = in_e.minimum(reduce_dim);
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
+            "available, but got %s.",
+            pooltype));
+      }
+
+      last_idx = idx;
+      if (idx < segments.numel()) current_id = segment_ids[idx];
+    }
+  }
+};
+
+template <typename T, typename IndexT>
+class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& output,
+                  const framework::Tensor& out_grad,
+                  const framework::Tensor& segments, framework::Tensor* in_grad,
+                  const framework::Tensor* index = nullptr,
+                  const std::string pooltype = "SUM") {
+    const IndexT* segment_ids = segments.data<IndexT>();
+    auto& place = *context.eigen_device();
+    auto current_id = segment_ids[0];
+    int64_t last_idx = 0;
+    int64_t w = in_grad->numel() / in_grad->dims()[0];
+    for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
+      if (idx < segments.numel()) {
+        if (segment_ids[idx] == current_id) continue;
+        PADDLE_ENFORCE_GE(segment_ids[idx], current_id,
+                          platform::errors::InvalidArgument(
+                              "The segment ids should be sorted, but got "
+                              "segment_ids[%d]:%d > segment_ids[%d]:%d.",
+                              idx - 1, current_id, idx, segment_ids[idx]));
+      }
+
+      Tensor out_g_t = out_grad.Slice(current_id, current_id + 1);
+      Tensor in_g_t = in_grad->Slice(last_idx, idx);
+
+      int64_t h = idx - last_idx;
+      auto in_g_e = framework::EigenMatrix<T>::From(in_g_t, {h, w});
+      auto out_g_e = framework::EigenMatrix<T>::From(out_g_t, {1, w});
+      Eigen::DSizes<int, 2> bcast(h, 1);
+
+      if (pooltype == "MEAN") {
+        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
+      } else if (pooltype == "SUM") {
+        in_g_e.device(place) = out_g_e.broadcast(bcast);
+      } else if (pooltype == "MAX" || pooltype == "MIN") {
+        Tensor out_t = output.Slice(current_id, current_id + 1);
+        Tensor in_t = input.Slice(last_idx, idx);
+        auto in_e = framework::EigenMatrix<T>::From(in_t, {h, w});
+        auto out_e = framework::EigenMatrix<T>::From(out_t, {1, w});
+        in_g_e.device(place) =
+            (in_e == out_e.broadcast(bcast)).template cast<T>() *
+            out_g_e.broadcast(bcast);
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
+            "available, but got %s.",
+            pooltype));
+      }
+
+      last_idx = idx;
+      if (idx < segments.numel()) current_id = segment_ids[idx];
+    }
+  }
+};
+
+using CPU = platform::CPUDeviceContext;
+template class SegmentPoolFunctor<CPU, float, int>;
+template class SegmentPoolFunctor<CPU, float, int64_t>;
+template class SegmentPoolFunctor<CPU, double, int>;
+template class SegmentPoolFunctor<CPU, double, int64_t>;
+template class SegmentPoolGradFunctor<CPU, float, int>;
+template class SegmentPoolGradFunctor<CPU, float, int64_t>;
+template class SegmentPoolGradFunctor<CPU, double, int>;
+template class SegmentPoolGradFunctor<CPU, double, int64_t>;
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/segment_pooling.h b/paddle/fluid/operators/math/segment_pooling.h
new file mode 100644
index 0000000000..561fad6921
--- /dev/null
+++ b/paddle/fluid/operators/math/segment_pooling.h
@@ -0,0 +1,46 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T, typename IndexT>
+class SegmentPoolFunctor {
+ public:
+  /* mean pool has summed_ids output */
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  const framework::Tensor& segments, framework::Tensor* output,
+                  framework::Tensor* summed_ids = nullptr,
+                  const std::string pooltype = "SUM");
+};
+
+template <typename DeviceContext, typename T, typename IndexT>
+class SegmentPoolGradFunctor {
+ public:
+  /* mean pool has summed_ids output */
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  const framework::Tensor& output,
+                  const framework::Tensor& out_grad,
+                  const framework::Tensor& segments, framework::Tensor* in_grad,
+                  const framework::Tensor* summed_ids = nullptr,
+                  const std::string pooltype = "SUM");
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/segment_pool_op.cc b/paddle/fluid/operators/segment_pool_op.cc
new file mode 100644
index 0000000000..322cd97f01
--- /dev/null
+++ b/paddle/fluid/operators/segment_pool_op.cc
@@ -0,0 +1,166 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/segment_pool_op.h"
+#include <memory>
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+class SegmentPoolOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool");
+    OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds",
+                   "SegmentPool");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool");
+    auto dims = ctx->GetInputDim("X");
+    dims[0] = -1;
+    ctx->SetOutputDim("Out", dims);
+
+    if (ctx->Attrs().Get<std::string>("pooltype") == "MEAN") {
+      OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds",
+                     "SegmentPool");
+      ctx->SetOutputDim("SummedIds", {-1, 1});
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class SegmentPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) The input data of SegmentPoolOp");
+    AddInput("SegmentIds",
+             "(Tensor) 1-D tensor which has the same size as the first "
+             "dimension of input X.");
+    AddOutput("Out", "(Tensor) The output of SegmentPoolOp.");
+    AddOutput("SummedIds",
+              "(Tensor) This tensor records the counts of the segment ids, "
+              "which is used by the backward of the mean pool.")
+        .AsIntermediate();
+    AddAttr<std::string>(
+        "pooltype",
+        "(string, default 'SUM') the pooling type of SegmentPoolOp.")
+        .SetDefault("SUM")
+        .InEnum({"SUM", "MEAN", "MIN", "MAX"});
+    AddComment(R"DOC(
+Segment Pool Operator.
+
+This operator pools the elements of input `X` which have the same index
+in `SegmentIds`.
+
+For the SUM operation, it computes a tensor such that $Out_i = \sum_{j} X_{j}$
+where the sum is over all j such that `SegmentIds[j] == i`.
+
+For the MEAN operation, it computes a tensor such that
+$Out_i = \frac{1}{n_i} \sum_{j} X_{j}$ where the sum is over all j such that
+`SegmentIds[j] == i` and $n_i$ is the number of indices j with that segment id.
+
+For the MIN operation, it computes a tensor such that $Out_i = \min_{j} X_{j}$
+where the min is over all j such that `SegmentIds[j] == i`.
+
+For the MAX operation, it computes a tensor such that $Out_i = \max_{j} X_{j}$
+where the max is over all j such that `SegmentIds[j] == i`.
+    )DOC");
+  }
+};
+
+class SegmentPoolGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "SegmentPoolGrad");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPoolGrad");
+    auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    auto x_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
+                      platform::errors::InvalidArgument(
+                          "The rank of the output grad must equal the rank of "
+                          "Input(X). But received: input rank %u, input shape "
+                          "[%s].",
+                          og_dims.size(), og_dims));
+    for (int64_t i = 1; i < og_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(
+          og_dims[i], x_dims[i],
+          platform::errors::InvalidArgument(
+              "The dimension mismatch between Input(OUT@GRAD) and "
+              "Input(X). Received Input(OUT@GRAD): input rank %u, "
+              "input shape [%s]; received Input(X): input rank %u, "
+              "input shape [%s].",
+              og_dims.size(), og_dims, x_dims.size(), x_dims));
+    }
+
+    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
+template <typename T>
+class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op_desc_ptr) const override {
+    op_desc_ptr->SetType("segment_pool_grad");
+    op_desc_ptr->SetInput("X", this->Input("X"));
+    op_desc_ptr->SetInput("SegmentIds", this->Input("SegmentIds"));
+    op_desc_ptr->SetInput("Out", this->Output("Out"));
+    if (BOOST_GET_CONST(std::string, this->GetAttr("pooltype")) == "MEAN") {
+      op_desc_ptr->SetInput("SummedIds", this->Output("SummedIds"));
+    }
+    op_desc_ptr->SetInput(framework::GradVarName("Out"),
+                          this->OutputGrad("Out"));
+    op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op_desc_ptr->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker,
+                  ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
+                  ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    segment_pool,
+    ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    segment_pool_grad,
+    ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/segment_pool_op.h b/paddle/fluid/operators/segment_pool_op.h
new file mode 100644
index 0000000000..a505946b9f
--- /dev/null
+++ b/paddle/fluid/operators/segment_pool_op.h
@@ -0,0 +1,130 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/segment_pooling.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T, typename IndexT>
+void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
+  auto* input = context.Input<Tensor>("X");
+  auto* segment = context.Input<Tensor>("SegmentIds");
+  auto* output = context.Output<Tensor>("Out");
+  std::string pooltype = context.Attr<std::string>("pooltype");
+  Tensor* summed_ids = nullptr;
+
+  int64_t num_indices = segment->numel();
+  PADDLE_ENFORCE_EQ(
+      num_indices, input->dims()[0],
+      platform::errors::InvalidArgument(
+          "Segment_ids should be the same size as dimension 0 of input X."));
+  PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0],
+                    platform::errors::InvalidArgument(
+                        "Segment_ids should be a 1-D tensor, or its other "
+                        "dimension sizes should be 1. Segment_ids's shape is: "
+                        "[%s].",
+                        segment->dims()));
+
+  if (input->numel() == 0 || segment->numel() == 0) {
+    return;
+  }
+
+  bool cpu_place = context.GetPlace().type() == typeid(platform::CPUPlace);
+  if (cpu_place) {
+    auto dims = input->dims();
+    auto* segment_ids = segment->data<IndexT>();
+    dims[0] = static_cast<int64_t>(segment_ids[segment->numel() - 1] + 1);
+    PADDLE_ENFORCE_GT(
+        dims[0], 0,
+        platform::errors::InvalidArgument(
+            "Segment ids must be >= 0, but got last id %d", dims[0]));
+    output->Resize({dims});
+    output->mutable_data<T>(context.GetPlace());
+    math::SetConstant<DeviceContext, T> set_zero;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    set_zero(dev_ctx, output, static_cast<T>(0));
+  }
+
+  SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
+
+  pool(context.template device_context<DeviceContext>(), *input, *segment,
+       output, summed_ids, pooltype);
+}
+
+template <typename DeviceContext, typename T>
+class SegmentPoolKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* segment = context.Input<Tensor>("SegmentIds");
+    auto index_type = segment->type();
+    if (index_type == framework::proto::VarType::INT32) {
+      SegmentKernelLaunchHelper<DeviceContext, T, int>(context);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      SegmentKernelLaunchHelper<DeviceContext, T, int64_t>(context);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported index type, expected int32 or int64, but got %s.",
+          index_type));
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SegmentPoolGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("X");
+    auto* output = context.Input<Tensor>("Out");
+    auto* segment = context.Input<Tensor>("SegmentIds");
+    auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* in_g = context.Output<Tensor>(framework::GradVarName("X"));
+    std::string pooltype = context.Attr<std::string>("pooltype");
+
+    const Tensor* summed_ids = nullptr;
+    if (pooltype == "MEAN") {
+      summed_ids = context.Input<Tensor>("SummedIds");
+    }
+
+    in_g->mutable_data<T>(context.GetPlace());
+    math::SetConstant<DeviceContext, T> set_zero;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    set_zero(dev_ctx, in_g, static_cast<T>(0));
+
+    auto index_type = segment->type();
+    if (index_type == framework::proto::VarType::INT32) {
+      SegmentPoolGradFunctor<DeviceContext, T, int> pool;
+      pool(context.template device_context<DeviceContext>(), *input, *output,
+           *out_g, *segment, in_g, summed_ids, pooltype);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      SegmentPoolGradFunctor<DeviceContext, T, int64_t> pool;
+      pool(context.template device_context<DeviceContext>(), *input, *output,
+           *out_g, *segment, in_g, summed_ids, pooltype);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported index type, expected int32 or int64, but got %s.",
+          index_type));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_segment_ops.py b/python/paddle/fluid/tests/unittests/test_segment_ops.py
new file mode 100644
index 0000000000..b58d66676b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_segment_ops.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+from op_test import OpTest
+
+
+def compute_segment_sum(x, segment_ids):
+    length = segment_ids[-1] + 1
+    target_shape = list(x.shape)
+    target_shape[0] = length
+    results = np.zeros(target_shape, dtype=x.dtype)
+    for index, ids in enumerate(segment_ids):
+        results[ids, :] += x[index, :]
+    return results
+
+
+def compute_segment_mean(x, segment_ids):
+    length = segment_ids[-1] + 1
+    target_shape = list(x.shape)
+    target_shape[0] = length
+    results = np.zeros(target_shape, dtype=x.dtype)
+    count = np.zeros(length, dtype=x.dtype) + 1e-8
+    for index, ids in enumerate(segment_ids):
+        results[ids, :] += x[index, :]
+        count[ids] += 1
+    results = results / count.reshape([-1, 1])
+    return results
+
+
+def compute_segment_min_max(x, segment_ids, pooltype="MAX"):
+    length = segment_ids[-1] + 1
+    target_shape = list(x.shape)
+    target_shape[0] = length
+    gradient = np.zeros_like(x)
+    results = np.zeros(target_shape, dtype=x.dtype)
+    last_idx = 0
+    current_id = segment_ids[0]
+    for idx in range(1, len(segment_ids) + 1):
+        if idx < len(segment_ids):
+            if segment_ids[idx] == current_id:
+                continue
+        sub_x = x[last_idx:idx, :]
+        if pooltype == "MAX":
+            results[current_id] = np.amax(sub_x, axis=0)
+        elif pooltype == "MIN":
+            results[current_id] = np.amin(sub_x, axis=0)
+        else:
+            raise ValueError("Invalid pooltype, only MAX, MIN supported!")
+        gradient[last_idx:idx, :][sub_x == results[current_id]] = 1
+        last_idx = idx
+        if idx < len(segment_ids):
+            current_id = segment_ids[idx]
+
+    return results, gradient / results.size
+
+
+class TestSegmentOps(OpTest):
+    def set_data(self):
+        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        segment_ids = self.set_segment(len(x), len(x) // 5 + 1)
+        return x, segment_ids
+
+    def set_segment(self, origin_len, reduce_len):
+        segment = np.zeros(reduce_len, dtype='int64')
+        segment = np.random.randint(0, reduce_len, size=[origin_len])
+        segment = np.sort(segment)
+        return segment.astype('int64')
+
+    def compute(self, x, segment_ids):
+        return compute_segment_sum(x, segment_ids)
+
+    def prepare(self):
+        self.op_type = "segment_pool"
+        self.dtype = np.float64
+        self.shape = [30, 15]
+        self.attrs = {"pooltype": "SUM"}
+
+    def setUp(self):
+        self.prepare()
+        x, segment_ids = self.set_data()
+        result = self.compute(x, segment_ids)
+        self.inputs = {
+            'X':
x.astype(self.dtype), + 'SegmentIds': segment_ids.astype(np.int64) + } + self.outputs = {'Out': result.astype(self.dtype)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSegmentSum2(TestSegmentOps): + def prepare(self): + super(TestSegmentSum2, self).prepare() + self.shape = [40, 20] + self.dtype = np.float32 + + def setUp(self): + self.prepare() + x, segment_ids = self.set_data() + result = self.compute(x, segment_ids) + self.inputs = { + 'X': x.astype(self.dtype), + 'SegmentIds': segment_ids.astype(np.int32) + } + self.outputs = {'Out': result.astype(self.dtype)} + + +class TestSegmentMax(TestSegmentOps): + def compute(self, x, segment_ids): + return compute_segment_min_max(x, segment_ids, pooltype="MAX") + + def prepare(self): + super(TestSegmentMax, self).prepare() + self.shape = [40, 20] + self.attrs = {'pooltype': "MAX"} + + def setUp(self): + self.prepare() + x, segment_ids = self.set_data() + result, self.gradient = self.compute(x, segment_ids) + self.inputs = { + 'X': x.astype(self.dtype), + 'SegmentIds': segment_ids.astype(np.int32) + } + self.outputs = {'Out': result.astype(self.dtype)} + + def test_check_grad(self): + self.check_grad(["X"], "Out", user_defined_grads=[self.gradient]) + + +class TestSegmentMax2(TestSegmentMax): + def prepare(self): + super(TestSegmentMax2, self).prepare() + self.dtype = np.float32 + + +class TestSegmentMin(TestSegmentMax): + def compute(self, x, segment_ids): + return compute_segment_min_max(x, segment_ids, pooltype="MIN") + + def prepare(self): + super(TestSegmentMin, self).prepare() + self.attrs = {'pooltype': "MIN"} + + +class TestSegmentMin2(TestSegmentMin): + def prepare(self): + super(TestSegmentMin2, self).prepare() + self.dtype = np.float32 + + +class TestSegmentMean(TestSegmentOps): + def compute(self, x, segment_ids): + return compute_segment_mean(x, segment_ids) + + def prepare(self): + super(TestSegmentMean, self).prepare() + self.shape = [40, 20] + self.attrs = {'pooltype': "MEAN"} + + def setUp(self): + self.prepare() + x, segment_ids = self.set_data() + result = self.compute(x, segment_ids) + self.inputs = {'X': x, 'SegmentIds': segment_ids} + self.outputs = { + 'Out': result, + 'SummedIds': compute_segment_sum( + np.ones([len(x), 1]).astype(self.dtype), segment_ids) + } + + +class TestSegmentMean2(TestSegmentMean): + def prepare(self): + super(TestSegmentMean2, self).prepare() + self.dtype = np.float32 + self.shape = [30, 20] + self.attrs = {'pooltype': "MEAN"} + + +if __name__ == '__main__': + unittest.main() -- GitLab
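
For reference, the pooling semantics described in the operator's DOC comment can be sketched in a few lines of NumPy, mirroring the reference functions in test_segment_ops.py above. This is an illustrative sketch only, not part of the patch; the sample values and variable names are invented, and MIN is analogous to MAX:

    import numpy as np

    # Illustrative data: rows 0-1 belong to segment 0, row 2 to segment 1.
    x = np.array([[1., 2.], [3., 4.], [5., 6.]])
    segment_ids = np.array([0, 0, 1])  # must be sorted, as the CPU kernel enforces

    num_segments = segment_ids[-1] + 1
    out_sum = np.zeros((num_segments, x.shape[1]))
    out_max = np.full((num_segments, x.shape[1]), -np.inf)
    counts = np.zeros(num_segments)
    for j, i in enumerate(segment_ids):
        out_sum[i] += x[j]                         # SUM: Out_i = sum over rows with id i
        out_max[i] = np.maximum(out_max[i], x[j])  # MAX: elementwise max over rows with id i
        counts[i] += 1                             # n_i, the SummedIds counts for MEAN
    out_mean = out_sum / counts[:, None]           # MEAN: SUM divided by n_i

    print(out_sum)   # [[4. 6.]  [5. 6.]]
    print(out_mean)  # [[2. 3.]  [5. 6.]]
    print(out_max)   # [[3. 4.]  [5. 6.]]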