未验证 提交 48410b9b 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #15237 from tensor-tang/fuse/seqpool_concat_2

Fuse/seqpool concat 2
......@@ -42,6 +42,7 @@ pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#define MAX_CONCAT_INPUTS 200
namespace paddle {
namespace framework {
namespace ir {
PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
const std::string& name_scope,
int num_inputs) {
auto is_concat_op_with_inputs = [](Node* x, int num) -> bool {
return x && x->IsOp() && x->Op()->Type() == "concat" &&
x->Op()->Input("X").size() == static_cast<size_t>(num);
};
auto is_nth_input_var_of_concat = [=](Node* x, int idx) -> bool {
return x && x->IsVar() && VarLinksToOp(x, "concat") &&
x->outputs.size() == 1 && IsNthInput(x, x->outputs[0], "X", idx) &&
is_concat_op_with_inputs(x->outputs[0], num_inputs);
};
auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=](
Node* x, const std::string& type, int idx) -> bool {
bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
x->Op()->HasAttr("pooltype") &&
boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
x->outputs.size() == 2; // seqpool should only have 2 outputs
if (ok) {
// only one output of seqpool_op is nth_input_var of concat
// the other one should be unused empty var
if (is_nth_input_var_of_concat(x->outputs[0], idx)) {
ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0;
} else {
ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) &&
x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
}
}
return ok;
};
auto* concat_op = pattern->NewNode(
[=](Node* x) { return is_concat_op_with_inputs(x, num_inputs); },
name_scope + "/concat_op");
concat_op->assert_op_attr<int>("axis", 1);
auto* concat_out_var = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && VarLinksFromOp(x, "concat") &&
x->inputs.size() == 1 &&
is_concat_op_with_inputs(x->inputs[0], num_inputs);
},
name_scope + "/concat_out_var");
concat_out_var->assert_is_only_output_of_op("concat");
std::vector<PDNode*> seqpool_ops_input_var(num_inputs);
std::vector<PDNode*> seqpool_ops_output_var(num_inputs);
std::vector<PDNode*> seqpool_ops(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
seqpool_ops_output_var[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && is_nth_input_var_of_concat(x, i) &&
x->inputs.size() == 1 &&
is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0],
"SUM", i);
},
name_scope + "/sequence_pool_out_" + std::to_string(i));
seqpool_ops[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() &&
is_seqpool_op_with_pootype_of_nth_input_of_concat(x, "SUM", i);
},
name_scope + "/sequence_pool_op_" + std::to_string(i));
seqpool_ops_input_var[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && x->outputs.size() >= 1 &&
is_seqpool_op_with_pootype_of_nth_input_of_concat(
x->outputs[0], "SUM", i);
},
name_scope + "/sequence_pool_in_" + std::to_string(i));
// Links
seqpool_ops[i]
->LinksFrom({seqpool_ops_input_var[i]})
.LinksTo({seqpool_ops_output_var[i]});
}
concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var});
return concat_out_var;
}
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
int num_inputs) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs);
auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
"pattern has no Node called %s", name.c_str());
Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
return p;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle SeqPool Concat fuse";
std::vector<std::string> input_names(num_inputs);
std::vector<Node*> input_vars(num_inputs);
auto& fused_pattern = gpd.pattern();
for (int i = 0; i < num_inputs; ++i) {
input_vars[i] =
retrieve_node(name_scope + "/sequence_pool_in_" + std::to_string(i),
subgraph, fused_pattern);
input_names[i] = input_vars[i]->Name();
}
auto* concat_op =
retrieve_node(name_scope + "/concat_op", subgraph, fused_pattern);
auto* concat_out_var =
retrieve_node(name_scope + "/concat_out_var", subgraph, fused_pattern);
auto* seqpool_op0 = retrieve_node(name_scope + "/sequence_pool_op_0",
subgraph, fused_pattern);
// Create New OpDesc
OpDesc op_desc;
op_desc.SetType("fusion_seqpool_concat");
op_desc.SetInput("X", input_names);
op_desc.SetAttr("pooltype", seqpool_op0->Op()->GetAttr("pooltype"));
op_desc.SetAttr("axis", concat_op->Op()->GetAttr("axis"));
op_desc.SetOutput("Out", {concat_out_var->Name()});
auto* op = graph->CreateOpNode(&op_desc);
for (size_t i = 0; i < input_vars.size(); ++i) {
IR_NODE_LINK_TO(input_vars[i], op);
}
IR_NODE_LINK_TO(op, concat_out_var);
std::unordered_set<const Node*> marked_nodes;
for (auto& item : subgraph) {
marked_nodes.insert(item.second);
}
for (size_t i = 0; i < input_vars.size(); ++i) {
marked_nodes.erase(input_vars[i]);
}
marked_nodes.erase(concat_out_var);
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> SeqPoolConcatFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = 0;
for (int i = MAX_CONCAT_INPUTS; i > 0; --i) {
fusion_count += BuildFusion(
graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i);
}
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seqpool_concat_fuse_pass,
paddle::framework::ir::SeqPoolConcatFusePass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqPoolConcatFusePass : public FusePassBase {
public:
virtual ~SeqPoolConcatFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"seqpool_concat_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -89,6 +89,7 @@ class CpuPassStrategy : public PassStrategy {
passes_.assign({
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"seqpool_concat_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
// "embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
......
......@@ -177,8 +177,12 @@ TEST(Analyzer_seq_pool1, fuse_statis) {
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 349);
EXPECT_EQ(num_ops, 195);
}
} // namespace analysis
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_seqpool_concat_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
void FusionSeqPoolConcatOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqPoolConcatOp should be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqPoolConcatOp should not be null.");
int axis = ctx->Attrs().Get<int>("axis");
PADDLE_ENFORCE_EQ(axis, 1,
"FusionSeqPoolConcatOp only supports concat axis=1 yet.");
auto ins_dims = ctx->GetInputsDim("X");
const size_t n = ins_dims.size();
PADDLE_ENFORCE_GT(n, 0UL, "Input tensors count should > 0.");
if (n == 1) {
LOG(WARNING) << "Only have one input, may waste memory";
}
// The output height should be confirmed in Compute,
// since input lod is not accessible here.
PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2UL,
"The dims size of first input should be 2.");
ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
}
framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
}
void FusionSeqPoolConcatOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
AddAttr<std::string>("pooltype",
"(string, default 'AVERAGE') some of the pooling "
"pooltype of SequencePoolOp.")
.SetDefault("SUM")
.InEnum({"AVERAGE", "SUM", "SQRT"});
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
.SetDefault(1);
AddComment(R"DOC(
Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
)DOC");
}
template <typename T>
class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
std::string pooltype = ctx.Attr<std::string>("pooltype");
auto x0_lod = ins[0]->lod();
auto x0_dims = ins[0]->dims();
auto y_dims = out->dims();
size_t bs = x0_lod[0].size() - 1;
out->Resize({static_cast<int64_t>(bs), y_dims[1]});
framework::LoD y_lod(1);
y_lod[0].resize(bs + 1);
for (size_t i = 0; i <= bs; ++i) {
y_lod[0][i] = i;
}
out->set_lod(y_lod);
auto place = ctx.GetPlace();
T* y_data = out->mutable_data<T>(place);
int w = ins[0]->numel() / x0_dims[0];
PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
"The output of dims[1] should be dividable of w");
jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
if (pooltype == "AVERAGE") {
attr.type = jit::SeqPoolType::kAvg;
} else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt;
}
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
size_t n = ins.size();
for (size_t i = 0; i < n; ++i) {
auto x_dims = ins[i]->dims();
auto x_lod = ins[i]->lod()[0];
const T* src = ins[i]->data<T>();
T* dst = y_data + i * w;
PADDLE_ENFORCE_EQ(static_cast<int>(ins[i]->numel() / x_dims[0]), w,
"Width of all inputs should be equal.");
PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1,
"Batchsize of all inputs should be equal.");
for (size_t j = 0; j < bs; ++j) {
attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
seqpool(src, dst, &attr);
dst += n * w;
src += attr.h * attr.w;
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqpool_concat, ops::FusionSeqPoolConcatOp,
ops::FusionSeqPoolConcatOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seqpool_concat,
ops::FusionSeqPoolConcatKernel<float>,
ops::FusionSeqPoolConcatKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqPoolConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqPoolConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_reorder_lod_tensor import convert_to_offset
from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt
class TestFusionSeqPoolConcatOp(OpTest):
def setUp(self):
self.w = 11
self.lods = [[[2, 3, 5]], [[1, 5, 2]]]
self.set_conf()
self.set_pooltype()
self.op_type = 'fusion_seqpool_concat'
self.axis = 1
bs = len(self.lods[0][0])
inputs = []
outs = []
i = 0
for lod in self.lods:
assert bs == len(lod[0]), 'All lod size should be equal'
x = np.random.uniform(0.1, 1,
[sum(lod[0]), self.w]).astype('float32')
offset = convert_to_offset(lod)
out = np.zeros((bs, self.w)).astype('float32')
if self.pooltype == "SUM":
compute_seqpool_sum(x, offset, out)
elif self.pooltype == "AVERAGE":
compute_seqpool_avg(x, offset, out)
elif self.pooltype == "SQRT":
compute_seqpool_sqrt(x, offset, out)
else:
raise Exception("Unsupported pool type!")
inputs.append(('x_{0}'.format(i), (x, lod)))
outs.append(out)
i = i + 1
self.inputs = {'X': inputs}
self.outputs = {'Out': np.concatenate(outs, axis=self.axis)}
self.attrs = {
'pooltype': self.pooltype,
'axis': self.axis,
}
def set_pooltype(self):
self.pooltype = "SUM"
def set_conf(self):
pass
def test_check_output(self):
self.check_output()
class TestFusionSeqPoolConcatOpCase1(TestFusionSeqPoolConcatOp):
def set_conf(self):
self.lods = [[[1]]]
class TestFusionSeqPoolConcatOpCase2(TestFusionSeqPoolConcatOp):
def set_conf(self):
self.lods = [[[1]], [[1]], [[1]]]
class TestFusionSeqPoolConcatOpCase3(TestFusionSeqPoolConcatOp):
def set_conf(self):
self.lods = [[[1, 3, 4, 6]]]
self.w = 10
class TestFusionSeqPoolConcatOpCase4(TestFusionSeqPoolConcatOp):
def set_conf(self):
self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]]
self.w = 3
## test avg pool and sqrt
def create_test_avg_sqrt_class(parent):
class TestSeqPoolAvgCase(parent):
def set_pooltype(self):
self.pooltype = "AVERAGE"
class TestSeqPoolSqrtCase(parent):
def set_pooltype(self):
self.pooltype = "SQRT"
cls_name_avg = "{0}_{1}".format(parent.__name__, "avg")
cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt")
TestSeqPoolAvgCase.__name__ = cls_name_avg
TestSeqPoolSqrtCase.__name__ = cls_name_sqrt
globals()[cls_name_avg] = TestSeqPoolAvgCase
globals()[cls_name_sqrt] = TestSeqPoolSqrtCase
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOp)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase1)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase2)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase3)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase4)
if __name__ == '__main__':
unittest.main()
......@@ -22,6 +22,14 @@ import numpy
import functools
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestReorderLoDTensor(unittest.TestCase):
num_seq = 5
# [name, shape, lod_level] pair indicating data info of source and target
......@@ -91,13 +99,6 @@ class TestReorderLoDTensor(unittest.TestCase):
self.inputs[desc[0]] = tensor
def reorder(self):
def convert_to_offset(lod):
offset_lod = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset_lod[i].append(offset_lod[i][-1] + seq_len)
return offset_lod
level = 0
# compute the rank_table according to ref_lod
ref_lod = self.data[self.data_desc[1][0]][1][level]
......
......@@ -17,33 +17,43 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_reorder_lod_tensor import convert_to_offset
class TestSeqAvgPool(OpTest):
def convert_to_offset(self, lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
def compute_seqpool_sum(x, offset, out):
for i in range(len(offset[0]) - 1):
sub_x = x[offset[0][i]:offset[0][i + 1], :]
out[i] = sub_x.sum(axis=0)
def compute_seqpool_avg(x, offset, out):
for i in range(len(offset[0]) - 1):
sub_x = x[offset[0][i]:offset[0][i + 1], :]
out[i] = sub_x.mean(axis=0)
def compute_seqpool_sqrt(x, offset, out):
for i in range(len(offset[0]) - 1):
sub_x = x[offset[0][i]:offset[0][i + 1], :]
seq_len = offset[0][i + 1] - offset[0][i]
out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len)
class TestSeqAvgPool(OpTest):
def set_data(self):
self.op_type = 'sequence_pool'
# one level, batch size is 4
x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
lod = [[11]]
self.inputs = {'X': (x, lod)}
offset = self.convert_to_offset(lod)
offset = convert_to_offset(lod)
out = np.zeros((len(lod[0]), 23)).astype('float32')
self.outputs = {'Out': out}
return x, offset, out
def compute(self, x, offset, out):
self.attrs = {'pooltype': "AVERAGE"}
for i in range(len(offset[0]) - 1):
sub_x = x[offset[0][i]:offset[0][i + 1], :]
out[i] = sub_x.mean(axis=0)
compute_seqpool_avg(x, offset, out)
def setUp(self):
x, offset, out = self.set_data()
......@@ -62,9 +72,7 @@ class TestSeqAvgPool(OpTest):
class TestSeqSumPool(TestSeqAvgPool):
def compute(self, x, offset, out):
self.attrs = {'pooltype': "SUM"}
for i in range(len(offset[0]) - 1):
sub_x = x[offset[0][i]:offset[0][i + 1], :]
out[i] = sub_x.sum(axis=0)
compute_seqpool_sum(x, offset, out)
class TestSeqMaxPool(TestSeqAvgPool):
......@@ -72,7 +80,7 @@ class TestSeqMaxPool(TestSeqAvgPool):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
lod = [[13]]
offset = self.convert_to_offset(lod)
offset = convert_to_offset(lod)
for i in range(len(offset[0]) - 1):
l = offset[0][i + 1] - offset[0][i]
x[offset[0][i] + np.random.randint(l), :] += 2.0
......@@ -93,10 +101,7 @@ class TestSeqMaxPool(TestSeqAvgPool):
class TestSeqSqrtPool(TestSeqAvgPool):
def compute(self, x, offset, out):
self.attrs = {'pooltype': "SQRT"}
for i in range(len(offset[0]) - 1):
sub_x = x[offset[0][i]:offset[0][i + 1], :]
seq_len = offset[0][i + 1] - offset[0][i]
out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len)
compute_seqpool_sqrt(x, offset, out)
class TestSeqLastPool(TestSeqAvgPool):
......@@ -122,7 +127,7 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
lod = [[4, 1, 3, 5]]
self.inputs = {'X': (x, lod)}
offset = self.convert_to_offset(lod)
offset = convert_to_offset(lod)
out = np.zeros((4, 3, 17)).astype('float32')
self.outputs = {'Out': out}
......@@ -167,7 +172,7 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32')
lod = [[4, 1, 3, 5]]
self.inputs = {'X': (x, lod)}
offset = self.convert_to_offset(lod)
offset = convert_to_offset(lod)
for i in range(len(offset[0]) - 1):
l = offset[0][i + 1] - offset[0][i]
x[offset[0][i] + np.random.randint(l), :] += 1.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册