未验证 提交 f4c750d7 编写于 作者: Z Zhong Hui 提交者: GitHub

Add the cpu version of segment sum mean max min op

Add the cpu version of segment sum mean max min op
上级 afe94903
......@@ -92,7 +92,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows
lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling executor device_memory_aligment generator)
sequence_pooling segment_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
......
......@@ -76,6 +76,7 @@ math_library(prelu)
math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
math_library(matrix_inverse)
math_library(segment_pooling)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/segment_pooling.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename IndexT>
class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& segments, framework::Tensor* output,
framework::Tensor* index,
const std::string pooltype = "SUM") {
const IndexT* segment_ids = segments.data<IndexT>();
auto curent_id = segment_ids[0];
int64_t last_idx = 0;
int64_t w = input.numel() / input.dims()[0];
auto& place = *context.eigen_device();
for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
if (idx < segments.numel()) {
if (segment_ids[idx] == curent_id) continue;
PADDLE_ENFORCE_GE(segment_ids[idx], curent_id,
platform::errors::InvalidArgument(
"The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d.",
idx - 1, curent_id, idx, segment_ids[idx]));
}
Tensor out_t = output->Slice(curent_id, curent_id + 1);
Tensor in_t = input.Slice(last_idx, idx);
int64_t h = idx - last_idx;
auto in_e =
framework::EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = framework::EigenVector<T>::Flatten(out_t);
auto reduce_dim = Eigen::array<int, 1>({{0}});
if (pooltype == "MEAN") {
out_e.device(place) = in_e.mean(reduce_dim);
} else if (pooltype == "SUM") {
out_e.device(place) = in_e.sum(reduce_dim);
} else if (pooltype == "MAX") {
out_e.device(place) = in_e.maximum(reduce_dim);
} else if (pooltype == "MIN") {
out_e.device(place) = in_e.minimum(reduce_dim);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
"available, but got %s.",
pooltype));
}
last_idx = idx;
if (idx < segments.numel()) curent_id = segment_ids[idx];
}
}
};
template <typename T, typename IndexT>
class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& out_grad,
const framework::Tensor& segments, framework::Tensor* in_grad,
const framework::Tensor* index = nullptr,
const std::string pooltype = "SUM") {
const IndexT* segment_ids = segments.data<IndexT>();
auto& place = *context.eigen_device();
auto curent_id = segment_ids[0];
int64_t last_idx = 0;
int64_t w = in_grad->numel() / in_grad->dims()[0];
for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
if (idx < segments.numel()) {
if (segment_ids[idx] == curent_id) continue;
PADDLE_ENFORCE_GE(segment_ids[idx], curent_id,
platform::errors::InvalidArgument(
"The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d.",
idx - 1, curent_id, idx, segment_ids[idx]));
}
Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1);
Tensor in_g_t = in_grad->Slice(last_idx, idx);
int64_t h = idx - last_idx;
auto in_g_e = framework::EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = framework::EigenMatrix<T>::From(out_g_t, {1, w});
Eigen::DSizes<int, 2> bcast(h, 1);
if (pooltype == "MEAN") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
in_g_e.device(place) = out_g_e.broadcast(bcast);
} else if (pooltype == "MAX" || pooltype == "MIN") {
Tensor out_t = output.Slice(curent_id, curent_id + 1);
Tensor in_t = input.Slice(last_idx, idx);
auto in_e = framework::EigenMatrix<T>::From(in_t, {h, w});
auto out_e = framework::EigenMatrix<T>::From(out_t, {1, w});
in_g_e.device(place) =
(in_e == out_e.broadcast(bcast)).template cast<T>() *
out_g_e.broadcast(bcast);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
"available, but got %s.",
pooltype));
}
last_idx = idx;
if (idx < segments.numel()) curent_id = segment_ids[idx];
}
}
};
using CPU = platform::CPUDeviceContext;
template class SegmentPoolFunctor<CPU, float, int>;
template class SegmentPoolFunctor<CPU, float, int64_t>;
template class SegmentPoolFunctor<CPU, double, int>;
template class SegmentPoolFunctor<CPU, double, int64_t>;
template class SegmentPoolGradFunctor<CPU, float, int>;
template class SegmentPoolGradFunctor<CPU, float, int64_t>;
template class SegmentPoolGradFunctor<CPU, double, int>;
template class SegmentPoolGradFunctor<CPU, double, int64_t>;
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T, typename IndexT>
class SegmentPoolFunctor {
public:
/* mean pool has summed_ids output */
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& segments, framework::Tensor* output,
framework::Tensor* summed_ids = nullptr,
const std::string pooltype = "SUM");
};
template <typename DeviceContext, typename T, typename IndexT>
class SegmentPoolGradFunctor {
public:
/* mean pool has summed_ids output */
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& out_grad,
const framework::Tensor& segments, framework::Tensor* in_grad,
const framework::Tensor* summed_ids = nullptr,
const std::string pooltype = "SUM");
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/segment_pool_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
class SegmentPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool");
OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds",
"SegmentPool");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool");
auto dims = ctx->GetInputDim("X");
dims[0] = -1;
ctx->SetOutputDim("Out", dims);
if (ctx->Attrs().Get<std::string>("pooltype") == "MEAN") {
OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds",
"SegmentPool");
ctx->SetOutputDim("SummedIds", {-1, 1});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class SegmentPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input data of SegmentPoolOp");
AddInput("SegmentIds",
"(Tensor) 1-D tensor which have the same size with the fist "
"dimension of input X.");
AddOutput("Out", "(Tensor) The output of SegmentPoolOp.");
AddOutput("SummedIds",
"(Tensor) This tensor is used to counts of segment ids for the "
"backward of the mean pool.")
.AsIntermediate();
AddAttr<std::string>(
"pooltype",
"(string, default 'SUM') the pooling type of SegmentPoolOp.")
.SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"});
AddComment(R"DOC(
Segment Pool Operator.
This operator will pool the elements of input `X` which with the same index
in `SegmentIds`.
For SUM operation, it computes a tensor such that $Out_i = \sum_{j} X_{j}$
where sum is over j such that `SegmentIds[j] == i`.
For MEAN operation, it computes a tensor such that
$Out_i = \frac{1}{n_i} \sum_{j} X_{j}$ where sum is over j such that
`SegmentIds[j] == i` and $n_i$ is the number of all index `SegmentIds[j] == i`.
For MIN operation, it computes a tensor such that $Out_i = \min_{j} X_{j}$
where min is over j such that `SegmentIds[j] == i`.
For MAX operation, it computes a tensor such that $Out_i = \max_{j} X_{j}$
where max is over j such that `SegmentIds[j] == i`.
)DOC");
}
};
class SegmentPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "SegmentPoolGrad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPoolGrad");
auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
platform::errors::InvalidArgument(
"The rank of output grad must equal to Input(X). But "
"received: input rank %u, input shape [%s].",
og_dims.size(), og_dims));
for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
og_dims[i], x_dims[i],
platform::errors::InvalidArgument(
"The dimension mismatch between Input(OUT@GRAD) and "
"Input(X). Received Input(OUT@GRAD): input rank %u, "
"input shape [%s]; received Input(X): input rank %u, "
"input shape [%s].",
og_dims.size(), og_dims, x_dims.size(), x_dims));
}
ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op_desc_ptr) const override {
op_desc_ptr->SetType("segment_pool_grad");
op_desc_ptr->SetInput("X", this->Input("X"));
op_desc_ptr->SetInput("SegmentIds", this->Input("SegmentIds"));
op_desc_ptr->SetInput("Out", this->Output("Out"));
if (BOOST_GET_CONST(std::string, this->GetAttr("pooltype")) == "MEAN") {
op_desc_ptr->SetInput("SummedIds", this->Output("SummedIds"));
}
op_desc_ptr->SetInput(framework::GradVarName("Out"),
this->OutputGrad("Out"));
op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op_desc_ptr->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker,
ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);
REGISTER_OP_CPU_KERNEL(
segment_pool,
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
segment_pool_grad,
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T, typename IndexT>
void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
auto* input = context.Input<Tensor>("X");
auto* segment = context.Input<Tensor>("SegmentIds");
auto* output = context.Output<Tensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
Tensor* summed_ids = nullptr;
int64_t num_indices = segment->numel();
PADDLE_ENFORCE_EQ(
num_indices, input->dims()[0],
platform::errors::InvalidArgument(
"Segment_ids should be the same size as dimension 0 of input X."));
PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0],
platform::errors::InvalidArgument(
"Segment_ids should be 1-D tensor, or it's other "
"dimension size is 1. Segment_ids's shape is: [%s].",
segment->dims()));
if (input->numel() == 0 || segment->numel() == 0) {
return;
}
bool cpu_place = context.GetPlace().type() == typeid(platform::CPUPlace);
if (cpu_place) {
auto dims = input->dims();
auto* segment_ids = segment->data<IndexT>();
dims[0] = static_cast<int64_t>(segment_ids[segment->numel() - 1] + 1);
PADDLE_ENFORCE_GT(
dims[0], 0,
platform::errors::InvalidArgument(
"Segment ids must be >= 0, but got last id %d", dims[0]));
output->Resize({dims});
output->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, output, static_cast<T>(0));
}
SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
pool(context.template device_context<DeviceContext>(), *input, *segment,
output, summed_ids, pooltype);
}
template <typename DeviceContext, typename T>
class SegmentPoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* segment = context.Input<Tensor>("SegmentIds");
auto index_type = segment->type();
if (index_type == framework::proto::VarType::INT32) {
SegmentKernelLaunchHelper<DeviceContext, T, int>(context);
} else if (index_type == framework::proto::VarType::INT64) {
SegmentKernelLaunchHelper<DeviceContext, T, int64_t>(context);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported index type, Expected int, int64, but got %s.",
index_type));
}
}
};
template <typename DeviceContext, typename T>
class SegmentPoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Input<Tensor>("Out");
auto* segment = context.Input<Tensor>("SegmentIds");
auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<Tensor>(framework::GradVarName("X"));
std::string pooltype = context.Attr<std::string>("pooltype");
const Tensor* summed_ids = nullptr;
if (pooltype == "MEAN") {
summed_ids = context.Input<Tensor>("SummedIds");
}
in_g->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, in_g, static_cast<T>(0));
auto index_type = segment->type();
if (index_type == framework::proto::VarType::INT32) {
SegmentPoolGradFunctor<DeviceContext, T, int> pool;
pool(context.template device_context<DeviceContext>(), *input, *output,
*out_g, *segment, in_g, summed_ids, pooltype);
} else if (index_type == framework::proto::VarType::INT64) {
SegmentPoolGradFunctor<DeviceContext, T, int64_t> pool;
pool(context.template device_context<DeviceContext>(), *input, *output,
*out_g, *segment, in_g, summed_ids, pooltype);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported index type, Expected int, int64, but got %s.",
index_type));
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
from op_test import OpTest
def compute_segment_sum(x, segment_ids):
length = segment_ids[-1] + 1
target_shape = list(x.shape)
target_shape[0] = length
results = np.zeros(target_shape, dtype=x.dtype)
for index, ids in enumerate(segment_ids):
results[ids, :] += x[index, :]
return results
def compute_segment_mean(x, segment_ids):
length = segment_ids[-1] + 1
target_shape = list(x.shape)
target_shape[0] = length
results = np.zeros(target_shape, dtype=x.dtype)
count = np.zeros(length, dtype=x.dtype) + 1e-8
for index, ids in enumerate(segment_ids):
results[ids, :] += x[index, :]
count[ids] += 1
results = results / count.reshape([-1, 1])
return results
def compute_segment_min_max(x, segment_ids, pooltype="MAX"):
length = segment_ids[-1] + 1
target_shape = list(x.shape)
target_shape[0] = length
gradient = np.zeros_like(x)
results = np.zeros(target_shape, dtype=x.dtype)
last_idx = 0
current_id = segment_ids[0]
for idx in range(1, len(segment_ids) + 1):
if idx < len(segment_ids):
if segment_ids[idx] == current_id:
continue
sub_x = x[last_idx:idx, :]
if pooltype == "MAX":
results[current_id] = np.amax(sub_x, axis=0)
elif pooltype == "MIN":
results[current_id] = np.amin(sub_x, axis=0)
else:
raise ValueError("Invalid pooltype, only MAX, MIN supported!")
gradient[last_idx:idx, :][sub_x == results[current_id]] = 1
last_idx = idx
if idx < len(segment_ids):
current_id = segment_ids[idx]
return results, gradient / results.size
class TestSegmentOps(OpTest):
def set_data(self):
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
segment_ids = self.set_segment(len(x), len(x) // 5 + 1)
return x, segment_ids
def set_segment(self, origin_len, reduce_len):
segment = np.zeros(reduce_len, dtype='int64')
segment = np.random.randint(0, reduce_len, size=[origin_len])
segment = np.sort(segment)
return segment.astype('int64')
def compute(self, x, segment_ids):
return compute_segment_sum(x, segment_ids)
def prepare(self):
self.op_type = "segment_pool"
self.dtype = np.float64
self.shape = [30, 15]
self.attrs = {"pooltype": "SUM"}
def setUp(self):
self.prepare()
x, segment_ids = self.set_data()
result = self.compute(x, segment_ids)
self.inputs = {
'X': x.astype(self.dtype),
'SegmentIds': segment_ids.astype(np.int64)
}
self.outputs = {'Out': result.astype(self.dtype)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSegmentSum2(TestSegmentOps):
def prepare(self):
super(TestSegmentSum2, self).prepare()
self.shape = [40, 20]
self.dtype = np.float32
def setUp(self):
self.prepare()
x, segment_ids = self.set_data()
result = self.compute(x, segment_ids)
self.inputs = {
'X': x.astype(self.dtype),
'SegmentIds': segment_ids.astype(np.int32)
}
self.outputs = {'Out': result.astype(self.dtype)}
class TestSegmentMax(TestSegmentOps):
def compute(self, x, segment_ids):
return compute_segment_min_max(x, segment_ids, pooltype="MAX")
def prepare(self):
super(TestSegmentMax, self).prepare()
self.shape = [40, 20]
self.attrs = {'pooltype': "MAX"}
def setUp(self):
self.prepare()
x, segment_ids = self.set_data()
result, self.gradient = self.compute(x, segment_ids)
self.inputs = {
'X': x.astype(self.dtype),
'SegmentIds': segment_ids.astype(np.int32)
}
self.outputs = {'Out': result.astype(self.dtype)}
def test_check_grad(self):
self.check_grad(["X"], "Out", user_defined_grads=[self.gradient])
class TestSegmentMax2(TestSegmentMax):
def prepare(self):
super(TestSegmentMax2, self).prepare()
self.dtype = np.float32
class TestSegmentMin(TestSegmentMax):
def compute(self, x, segment_ids):
return compute_segment_min_max(x, segment_ids, pooltype="MIN")
def prepare(self):
super(TestSegmentMin, self).prepare()
self.attrs = {'pooltype': "MIN"}
class TestSegmentMin2(TestSegmentMin):
def prepare(self):
super(TestSegmentMin2, self).prepare()
self.dtype = np.float32
class TestSegmentMean(TestSegmentOps):
def compute(self, x, segment_ids):
return compute_segment_mean(x, segment_ids)
def prepare(self):
super(TestSegmentMean, self).prepare()
self.shape = [40, 20]
self.attrs = {'pooltype': "MEAN"}
def setUp(self):
self.prepare()
x, segment_ids = self.set_data()
result = self.compute(x, segment_ids)
self.inputs = {'X': x, 'SegmentIds': segment_ids}
self.outputs = {
'Out': result,
'SummedIds': compute_segment_sum(
np.ones([len(x), 1]).astype(self.dtype), segment_ids)
}
class TestSegmentMean2(TestSegmentMean):
def prepare(self):
super(TestSegmentMean2, self).prepare()
self.dtype = np.float32
self.shape = [30, 20]
self.attrs = {'pooltype': "MEAN"}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册