未验证 提交 e1366613 编写于 作者: S ShenLiang 提交者: GitHub

add partial_concat op in contrib (#22528)

* add partial_concat, test=develop

* fix the grids and blocks, test=develop

* fix the Paddle_Enforce, test=develop

* fix the doc of op, test=develop

* fix the doc, test=develop

* fix the doc of the op, test=develop

* replace -1 with None, test=develop
上级 dab5e5d8
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/partial_concat_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Forward operator definition for partial_concat.
//
// Validates that all inputs are 2-D tensors of identical shape and infers the
// output shape [batch_size, partial_len * number_of_inputs]. The copy itself
// is performed by the registered kernels.
class PartialConcatOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks the inputs/attributes and sets the shape and LoD of "Out".
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_GE(
        ctx->Inputs("X").size(), 1UL,
        platform::errors::InvalidArgument(
            "Inputs(X) of Partial ConcatOp should not be empty."));

    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of Partial ConcatOp should not be null."));

    auto inputs_dims = ctx->GetInputsDim("X");
    PADDLE_ENFORCE_EQ(inputs_dims[0].size(), 2,
                      platform::errors::InvalidArgument(
                          "Only supports 2-D array with batch size in the 1st "
                          "dimension and data in the 2nd."));

    const size_t inputs_num = inputs_dims.size();
    PADDLE_ENFORCE_GT(inputs_num, 0,
                      platform::errors::InvalidArgument(
                          "ShapeError: Input tensors count should > 0. But "
                          "recevied inputs' length is 0."));
    if (inputs_num == 1) {
      VLOG(3) << "Warning: concat op have only one input, may waste memory";
    }

    // Every input must be 2-D and share the shape of the first input.
    int64_t batch_size = -1;
    int64_t input_len = -1;
    for (size_t i = 0; i < inputs_num; ++i) {
      PADDLE_ENFORCE_EQ(inputs_dims[i].size(), 2,
                        platform::errors::InvalidArgument(
                            "It only supports two dimensions input now."));
      if (i == 0) {
        batch_size = inputs_dims[0][0];
        input_len = inputs_dims[0][1];
      } else {
        PADDLE_ENFORCE_EQ(inputs_dims[i][0], batch_size,
                          platform::errors::InvalidArgument(
                              "The batch size of all inputs must be same"));
        PADDLE_ENFORCE_EQ(inputs_dims[i][1], input_len,
                          platform::errors::InvalidArgument(
                              "The input length of all inputs must be same"));
      }
    }

    // A negative start_index counts from the end of the row.
    int start_index = ComputeStartIndex(
        static_cast<int64_t>(ctx->Attrs().Get<int>("start_index")),
        inputs_dims[0][1]);
    // A negative length means "everything from start_index to the row end".
    int partial_len = ctx->Attrs().Get<int>("length");
    if (partial_len < 0) {
      partial_len = inputs_dims[0][1] - start_index;
    }
    // Reject slices that run past the end of a row; without this check the
    // kernels would copy memory outside the input buffers.
    PADDLE_ENFORCE_LE(
        static_cast<int64_t>(start_index) + partial_len, inputs_dims[0][1],
        platform::errors::InvalidArgument(
            "The slice [start_index, start_index + length) must lie inside "
            "the second dimension of the input, but start_index is %d, "
            "length is %d and the second dimension is %d.",
            start_index, partial_len, inputs_dims[0][1]));

    ctx->SetOutputDim("Out", {inputs_dims[0][0],
                              static_cast<int64_t>(partial_len * inputs_num)});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel data type is taken from the first input that is initialized
  // and non-empty; it is an error if no such input exists.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto inputs = ctx.MultiInput<Tensor>("X");
    auto input_data_type = framework::proto::VarType::Type(0);
    bool found = false;
    for (auto *input : inputs) {
      if (input->IsInitialized() && input->numel() > 0) {
        input_data_type = input->type();
        found = true;
        break;
      }
    }

    PADDLE_ENFORCE_EQ(found, true,
                      platform::errors::InvalidArgument(
                          "All Inputs of PartialConcat OP are Empty!"));
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  // Keeps each input on its own place/layout; only the data type follows the
  // expected kernel type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
// Gradient operator definition for partial_concat: each X@GRAD output takes
// the dims (and, when the output is present, the LoD) of its forward input.
class PartialConcatGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto x_name = "X";
    auto x_grad_name = framework::GradVarName(x_name);

    // Gradients have exactly the shapes of the forward inputs.
    ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));

    auto forward_args = ctx->Inputs(x_name);
    auto grad_args = ctx->Outputs(x_grad_name);
    PADDLE_ENFORCE_EQ(
        forward_args.size(), grad_args.size(),
        platform::errors::InvalidArgument(
            "The number of arguments in %s[%d] and %s[%d] is not equal.",
            x_name, forward_args.size(), x_grad_name, grad_args.size()));

    // Share LoD only for gradient slots that are bound to a real variable.
    for (size_t idx = 0; idx < forward_args.size(); ++idx) {
      if (grad_args[idx] == framework::kEmptyVarName) {
        continue;
      }
      ctx->ShareLoD(x_name, x_grad_name, idx, idx);
    }
  }

 protected:
  // The gradient kernel is selected by the data type of Out@GRAD.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto grad_dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(grad_dtype, ctx.device_context());
  }
};
// Declares the proto of the partial_concat operator: its input list, its
// single output, the two slice attributes and the user-facing documentation.
class PartialConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Registers inputs, outputs, attributes and the DOC comment of the op.
void Make() override {
// "X" is duplicable, i.e. the op takes a list of input tensors.
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
// First column of the slice taken from every input; defaults to 0.
AddAttr<int>("start_index",
"The start index of each instance for concatenation.")
.SetDefault(0);
// Number of columns taken from every input; the default -1 (any negative
// value) selects all columns from start_index to the end of the row.
AddAttr<int>("length",
"The length of each instance for concatenation."
" Negative values for all elements after start_index")
.SetDefault(-1);
AddComment(R"DOC(
Partial Concat Operator.
Partial Concatenate the input tensors along the 2nd dimension.
Only 2-D Tensor or LodTensor input is supported.
Slice and concat can only be performed along the second dimension.
Examples:
Input[0] = [[1,2],[3,4]]
Input[1] = [[5,6],[7,8]]
start_index = 1
length = 1
Output = [[2,6],
[4,8]]
)DOC");
}
};
// Builds the description of the partial_concat_grad op from the forward op.
// T is the op-description type the maker is instantiated with.
template <typename T>
class PartialConcatGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    std::unique_ptr<T> grad_op(new T());
    grad_op->SetType("partial_concat_grad");

    // The backward pass reads the forward inputs and the output gradient,
    // and produces one gradient per forward input (empty slots are kept).
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"),
                       this->InputGrad("X", false));

    // The slice attributes are forwarded unchanged.
    for (const char *attr_name : {"start_index", "length"}) {
      grad_op->SetAttr(attr_name, this->GetAttr(attr_name));
    }
    return grad_op;
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Forward operator, its proto maker, and the gradient maker instantiated for
// both paddle::framework::OpDesc and paddle::imperative::OpBase.
REGISTER_OPERATOR(partial_concat, ops::PartialConcatOp,
ops::PartialConcatOpMaker,
ops::PartialConcatGradMaker<paddle::framework::OpDesc>,
ops::PartialConcatGradMaker<paddle::imperative::OpBase>);
// Backward operator (shape inference only; no maker of its own).
REGISTER_OPERATOR(partial_concat_grad, ops::PartialConcatGradOp);
// CPU forward kernels for double, float, int64_t and int.
REGISTER_OP_CPU_KERNEL(
partial_concat,
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, float>,
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, int>);
// CPU backward kernels for the same four element types.
REGISTER_OP_CPU_KERNEL(partial_concat_grad,
ops::PartialConcatGradientOpKernel<float>,
ops::PartialConcatGradientOpKernel<int>,
ops::PartialConcatGradientOpKernel<double>,
ops::PartialConcatGradientOpKernel<int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/partial_concat_op.h"
#include "paddle/fluid/platform/float16.h"
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
// Integer division of x by y rounded up (intended for non-negative x and
// positive y). Both arguments are evaluated more than once, so do not pass
// expressions with side effects.
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
// Forward device kernel of partial_concat.
//
// Output element `id` belongs to batch row id / out_batch_len; inside that row
// it is element `part_index` of the slice taken from input number `var_id`.
//
//   in            array of per-input base pointers (indexed by var_id)
//   out           output buffer holding all_length elements
//   all_length    total number of output elements
//   in_batch_len  row length of every input
//   start_index   first column of the slice inside an input row
//   out_batch_len row length of the output (part_length * number of inputs)
//   part_length   number of columns copied from each input row
template <class T>
__global__ void ConcatPartialCUDAKernel(T **in, T *out, int64_t all_length,
                                        int64_t in_batch_len,
                                        int64_t start_index,
                                        int64_t out_batch_len,
                                        int64_t part_length) {
  // Index and stride are 64-bit: all_length is int64_t, and a 32-bit counter
  // would wrap around for outputs with more than 2^31 - 1 elements.
  int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  while (id < all_length) {
    const int64_t bs_id = id / out_batch_len;
    const int64_t bs_index = id % out_batch_len;
    const int64_t var_id = bs_index / part_length;
    const int64_t part_index = bs_index % part_length;
    const int64_t in_id = start_index + part_index;
    const T *tmp = in[var_id];
    out[id] = tmp[bs_id * in_batch_len + in_id];
    id += stride;
  }
}
// Backward device kernel of partial_concat: scatters the output gradient back
// into the sliced columns of the per-input gradient buffers.
//
//   in            array of per-input gradient base pointers (written)
//   out           gradient of the forward output (read), all_length elements
//   all_length    total number of elements in `out`
//   in_batch_len  row length of every input gradient
//   start_index   first column of the slice inside an input row
//   out_batch_len row length of `out` (part_length * number of inputs)
//   part_length   number of columns each input contributed to a row
//
// Only the sliced columns are written; the caller is expected to have
// initialized the remaining elements.
template <class T>
__global__ void ConcatPartialGradCUDAKernel(
    T **in, const T *out, int64_t all_length, int64_t in_batch_len,
    int64_t start_index, int64_t out_batch_len, int64_t part_length) {
  // Index and stride are 64-bit: all_length is int64_t, and a 32-bit counter
  // would wrap around for gradients with more than 2^31 - 1 elements.
  int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  while (id < all_length) {
    const int64_t bs_id = id / out_batch_len;
    const int64_t bs_index = id % out_batch_len;
    const int64_t var_id = bs_index / part_length;
    const int64_t part_index = bs_index % part_length;
    const int64_t in_id = start_index + part_index;
    T *tmp = in[var_id];
    tmp[bs_id * in_batch_len + in_id] = out[id];
    id += stride;
  }
}
// CUDA forward kernel wrapper of partial_concat.
//
// Reads the list input "X" (2-D tensors), takes columns
// [start_index, start_index + length) of every input and writes them side by
// side into "Out". The output shape is expected to have been set by shape
// inference; this kernel only allocates the output buffer.
template <typename T>
class PartialConcatOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_vars = ctx.MultiInput<Tensor>("X");
    Tensor *out = ctx.Output<Tensor>("Out");
    PADDLE_ENFORCE_EQ(in_vars[0] != nullptr, true,
                      platform::errors::InvalidArgument(
                          "The input of partial concat should not be null."));

    auto input_dim = in_vars[0]->dims();
    PADDLE_ENFORCE_EQ(input_dim.size(), 2,
                      platform::errors::InvalidArgument(
                          "Only supports 2-D array with batch size in the 1st "
                          "dimension and data in the 2nd."));
    auto in_size = input_dim[1];
    // may be negative
    auto start_index = ctx.Attr<int>("start_index");
    start_index = ComputeStartIndex(start_index, in_size);

    // A negative length selects everything from start_index to the row end.
    auto partial_len = ctx.Attr<int>("length");
    if (partial_len < 0) {
      partial_len = in_size - start_index;
    }
    // The device kernel performs no bounds checks, so refuse slices that
    // would read past the end of an input row.
    PADDLE_ENFORCE_LE(
        static_cast<int64_t>(start_index) + partial_len, in_size,
        platform::errors::InvalidArgument(
            "The slice [start_index, start_index + length) must lie inside "
            "the second dimension of the input, but start_index is %d, "
            "length is %d and the second dimension is %d.",
            start_index, partial_len, in_size));

    // Element counts are kept in 64 bits: batch_size * out_batch_len can
    // exceed the range of int for large tensors.
    const int64_t in_num = static_cast<int64_t>(in_vars.size());
    const int64_t batch_size = input_dim[0];
    const int64_t out_batch_len = partial_len * in_num;
    const int64_t all_length = batch_size * out_batch_len;

    constexpr size_t theory_sm_threads = 1024;
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto stream = dev_ctx.stream();
    auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    auto sm_count = max_threads / theory_sm_threads;
    size_t tile_size = 0;
    int grids;
    int blocks;

    // Heuristic launch configuration: larger problems get larger blocks, and
    // the grid covers `length` elements with one element per thread.
    auto ComputeKernelParameter = [&](size_t length) {
      if (length >= max_threads)
        tile_size = 1024;
      else if (length < max_threads && length > sm_count * 128)
        tile_size = 512;
      else if (length <= sm_count * 128)
        tile_size = 256;
      grids = CEIL_DIV(length, tile_size);
      blocks = tile_size;
    };

    auto place = ctx.GetPlace();
    T *out_data = out->mutable_data<T>(place);

    // Nothing to copy for an empty output; launching a kernel with zero
    // blocks would be an invalid launch configuration.
    if (all_length == 0) {
      return;
    }

    // Collect the base pointer of every input on the host ...
    std::vector<const T *> in_data;
    in_data.reserve(in_vars.size());
    for (int64_t i = 0; i < in_num; ++i) {
      in_data.emplace_back(in_vars[i]->data<T>());
    }

    // ... and copy that pointer table to device memory on the kernel's
    // stream so the device kernel can index the inputs by their position.
    auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *));
    memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
                 tmp_in_array->ptr(), platform::CPUPlace(),
                 reinterpret_cast<void *>(in_data.data()),
                 in_data.size() * sizeof(T *), dev_ctx.stream());

    T **in_array_data = reinterpret_cast<T **>(tmp_in_array->ptr());
    ComputeKernelParameter(all_length);
    ConcatPartialCUDAKernel<T><<<grids, blocks, 0, stream>>>(
        in_array_data, out_data, all_length, in_size, start_index,
        out_batch_len, partial_len);
  }
};
// CUDA backward kernel wrapper of partial_concat.
//
// Zero-fills every X@GRAD tensor and then copies the matching part of
// Out@GRAD into columns [start_index, start_index + length) of each of them.
template <typename T>
class PartialConcatGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto outs = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));

    PADDLE_ENFORCE_EQ(ins[0] != nullptr, true,
                      platform::errors::InvalidArgument(
                          "The input of partial concat should not be null."));
    // all parameters
    auto batch_size = ins[0]->dims()[0];
    auto in_size = ins[0]->dims()[1];
    // may be negative
    auto start_index = ctx.Attr<int>("start_index");
    start_index = ComputeStartIndex(start_index, in_size);
    // A negative length selects everything from start_index to the row end.
    auto partial_len = ctx.Attr<int>("length");
    if (partial_len < 0) partial_len = in_size - start_index;
    // The device kernel performs no bounds checks, so refuse slices that
    // would write past the end of a gradient row.
    PADDLE_ENFORCE_LE(
        static_cast<int64_t>(start_index) + partial_len, in_size,
        platform::errors::InvalidArgument(
            "The slice [start_index, start_index + length) must lie inside "
            "the second dimension of the input, but start_index is %d, "
            "length is %d and the second dimension is %d.",
            start_index, partial_len, in_size));

    auto in_num = ins.size();
    auto grad_batch_len = partial_len * in_num;
    auto all_length = grad_batch_len * batch_size;

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto stream = dev_ctx.stream();

    // initialize: columns outside the slice receive no gradient, so every
    // X@GRAD tensor is zero-filled before the slice is written.
    // NOTE(review): outs[i] is dereferenced without a null check — confirm
    // that every gradient output is always bound to a variable.
    auto &place = *dev_ctx.eigen_device();
    for (size_t i = 0; i < outs.size(); ++i) {
      outs[i]->mutable_data<T>(ctx.GetPlace());
      auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
      dxt.device(place) = dxt.constant(static_cast<T>(0));
    }

    // Nothing to scatter for an empty gradient; launching a kernel with zero
    // blocks would be an invalid launch configuration.
    if (all_length == 0) {
      return;
    }

    constexpr size_t theory_sm_threads = 1024;
    auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    auto sm_count = max_threads / theory_sm_threads;
    size_t tile_size = 0;
    int grids;
    int blocks;

    // Heuristic launch configuration: larger problems get larger blocks, and
    // the grid covers `length` elements with one element per thread.
    auto ComputeKernelParameter = [&](size_t length) {
      if (length >= max_threads)
        tile_size = 1024;
      else if (length < max_threads && length > sm_count * 128)
        tile_size = 512;
      else if (length <= sm_count * 128)
        tile_size = 256;
      grids = CEIL_DIV(length, tile_size);
      blocks = tile_size;
    };

    // Host-side table with the base pointer of every X@GRAD buffer ...
    std::vector<T *> out_data;
    out_data.reserve(in_num);
    for (size_t i = 0; i < in_num; ++i) {
      out_data.emplace_back(outs[i]->data<T>());
    }

    // ... copied to device memory on the kernel's stream so the device
    // kernel can index the gradient buffers by input position.
    auto tmp_out_array = memory::Alloc(dev_ctx, out_data.size() * sizeof(T *));
    memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
                 tmp_out_array->ptr(), platform::CPUPlace(),
                 reinterpret_cast<void *>(out_data.data()),
                 out_data.size() * sizeof(T *), dev_ctx.stream());

    T **out_grad_data = reinterpret_cast<T **>(tmp_out_array->ptr());
    ComputeKernelParameter(all_length);
    ConcatPartialGradCUDAKernel<T><<<grids, blocks, 0, stream>>>(
        out_grad_data, out_grad->data<T>(), all_length, in_size, start_index,
        grad_batch_len, partial_len);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// CUDA forward kernels for float, double, int, int64_t and float16.
REGISTER_OP_CUDA_KERNEL(partial_concat, ops::PartialConcatOpCUDAKernel<float>,
ops::PartialConcatOpCUDAKernel<double>,
ops::PartialConcatOpCUDAKernel<int>,
ops::PartialConcatOpCUDAKernel<int64_t>,
ops::PartialConcatOpCUDAKernel<plat::float16>);
// CUDA backward kernels for the same five element types.
REGISTER_OP_CUDA_KERNEL(partial_concat_grad,
ops::PartialConcatGradOpCUDAKernel<float>,
ops::PartialConcatGradOpCUDAKernel<double>,
ops::PartialConcatGradOpCUDAKernel<int>,
ops::PartialConcatGradOpCUDAKernel<int64_t>,
ops::PartialConcatGradOpCUDAKernel<plat::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Normalizes a possibly negative start index against a row of `size`
// elements. Valid values lie in [-size, size); a negative value counts from
// the end of the row. Raises InvalidArgument when the value is out of range.
static inline int64_t ComputeStartIndex(int64_t start_index, int64_t size) {
  const bool within_range = (-size <= start_index) && (start_index < size);
  PADDLE_ENFORCE_EQ(
      within_range, true,
      platform::errors::InvalidArgument(
          "The start_index is expected to be in range of [%d, %d), but got %d",
          -size, size, start_index));
  return start_index < 0 ? start_index + size : start_index;
}
// CPU forward kernel of partial_concat.
//
// Takes columns [start_index, start_index + length) of every 2-D input in "X"
// and writes them side by side into "Out", which is resized to
// [batch, length * number_of_inputs].
template <typename DeviceContext, typename T>
class PartialConcatKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    PADDLE_ENFORCE_EQ(ins[0] != nullptr, true,
                      platform::errors::InvalidArgument(
                          "The input of partial concat should not be null."));

    auto input_dim = ins[0]->dims();
    PADDLE_ENFORCE_EQ(input_dim.size(), 2,
                      platform::errors::InvalidArgument(
                          "Only supports 2-D array with batch size in the 1st "
                          "dimension and data in the 2nd."));
    auto in_size = input_dim[1];

    // may be negative
    auto start_index = ctx.Attr<int>("start_index");
    start_index = ComputeStartIndex(start_index, in_size);

    // A negative length selects everything from start_index to the row end.
    auto partial_len = ctx.Attr<int>("length");
    if (partial_len < 0) {
      partial_len = in_size - start_index;
    }
    // memcpy below does no bounds checking, so refuse slices that would read
    // past the end of an input row.
    PADDLE_ENFORCE_LE(
        static_cast<int64_t>(start_index) + partial_len, in_size,
        platform::errors::InvalidArgument(
            "The slice [start_index, start_index + length) must lie inside "
            "the second dimension of the input, but start_index is %d, "
            "length is %d and the second dimension is %d.",
            start_index, partial_len, in_size));

    // Sizes and offsets are 64-bit so that out_size * j cannot overflow int
    // for large tensors.
    const int64_t batch = input_dim[0];
    const int64_t in_num = static_cast<int64_t>(ins.size());
    const int64_t out_size = partial_len * in_num;
    out->Resize({batch, out_size});
    auto place = ctx.GetPlace();
    T* out_data = out->mutable_data<T>(place);

    // Nothing to copy when the output has no elements.
    if (batch == 0 || partial_len == 0) {
      return;
    }

    const size_t copy_bytes = static_cast<size_t>(partial_len) * sizeof(T);
    for (int64_t i = 0; i < in_num; ++i) {
      // The input base pointer does not depend on the row, fetch it once.
      const T* in_data = ins[i]->data<T>();
      for (int64_t j = 0; j < batch; ++j) {
        // Row j of input i goes to block i of output row j.
        memcpy(out_data + out_size * j + partial_len * i,
               in_data + in_size * j + start_index, copy_bytes);
      }
    }
  }
};
// CPU backward kernel of partial_concat.
//
// Zero-fills every X@GRAD tensor and then copies the matching block of
// Out@GRAD into columns [start_index, start_index + length) of each of them.
template <typename T>
class PartialConcatGradientOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
    auto outs =
        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));

    PADDLE_ENFORCE_EQ(ins[0] != nullptr, true,
                      platform::errors::InvalidArgument(
                          "The input of partial concat should not be null."));
    // all parameters
    const int64_t batch_size = ins[0]->dims()[0];
    const int64_t in_size = ins[0]->dims()[1];
    // may be negative
    auto start_index = ctx.Attr<int>("start_index");
    start_index = ComputeStartIndex(start_index, in_size);
    // A negative length selects everything from start_index to the row end.
    auto partial_len = ctx.Attr<int>("length");
    if (partial_len < 0) partial_len = in_size - start_index;
    // memcpy below does no bounds checking, so refuse slices that would write
    // past the end of a gradient row.
    PADDLE_ENFORCE_LE(
        static_cast<int64_t>(start_index) + partial_len, in_size,
        platform::errors::InvalidArgument(
            "The slice [start_index, start_index + length) must lie inside "
            "the second dimension of the input, but start_index is %d, "
            "length is %d and the second dimension is %d.",
            start_index, partial_len, in_size));

    const int64_t in_num = static_cast<int64_t>(ins.size());

    // initialize: columns outside the slice receive no gradient, so every
    // X@GRAD tensor is zero-filled before the slice is written.
    // NOTE(review): outs[i] is dereferenced without a null check — confirm
    // that every gradient output is always bound to a variable.
    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    for (size_t i = 0; i < outs.size(); ++i) {
      outs[i]->mutable_data<T>(ctx.GetPlace());
      auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
      dxt.device(place) = dxt.constant(static_cast<T>(0));
    }

    // Out@GRAD is laid out row by row; inside a row, block `var` (of
    // partial_len elements) is the gradient of input number `var`.
    const T* out_grad_t = out_grad->data<T>();
    const size_t copy_bytes = static_cast<size_t>(partial_len) * sizeof(T);
    for (int64_t bs = 0; bs < batch_size; ++bs) {
      for (int64_t var = 0; var < in_num; ++var) {
        T* dx = outs[var]->data<T>();
        const T* src = out_grad_t + (bs * in_num + var) * partial_len;
        memcpy(dx + bs * in_size + start_index, src, copy_bytes);
      }
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -24,17 +24,14 @@ import inspect
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import utils
from ... import unique_name
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from paddle.fluid.framework import Variable
import warnings
__all__ = [
'fused_elemwise_activation',
'sequence_topk_avg_pooling',
'var_conv_2d',
'match_matrix_tensor',
'tree_conv',
'fused_embedding_seq_pool',
'multiclass_nms2',
'search_pyramid_hash',
'shuffle_batch',
'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat'
]
......@@ -808,3 +805,65 @@ def shuffle_batch(x, seed=None):
'SeedOut': seed},
attrs=op_attrs)
return out
def partial_concat(input, start_index=0, length=-1):
    """
    **Partial Concat**

    This OP concatenates the inputs according to the start index and length. This
    OP exists in contrib, which means that it is not shown to the public.
    Only 2-D Tensor or LodTensor input is supported. Slice and concat can only be
    performed along the second dimension.

    .. code-block:: text

        Given:
            x = [[0, 1, 2],
                 [3, 4, 5]]
            y = [[6, 7 ,8],
                 [9, 10, 11]]
            output = partial_concat([x, y], start_index=0, length=2)

        we get:
            output = [[0, 1, 6, 7],
                      [3, 4, 9, 10]]

    Args:
        input(list|tuple): Input Tensors. Accepted data types are float16,
            float32, float64, int32 and int64. A single Tensor is also
            accepted (with a warning) and is treated as a one-element list.
        start_index(int32): The start index of each instance for partial
            concatenation. Default is 0.
        length(int32): The length of each instance for partial concatenation.
            Default is -1. A negative value selects all elements after
            start_index.

    Returns:
        Variable: A Tensor with the same data type as input's.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            x = fluid.data(name="x", shape=[None,3], dtype="float32")
            y = fluid.data(name="y", shape=[None,3], dtype="float32")
            concat = fluid.contrib.layers.partial_concat([x, y], start_index=0, length=2)
    """
    if isinstance(input, tuple):
        # A tuple of tensors is a sequence of inputs, not a single input.
        input = list(input)
    elif not isinstance(input, list):
        warnings.warn(
            "The type of input in partial_concat should be list, but received %s."
            % (type(input)))
        input = [input]

    # `idx` rather than `id`: avoids shadowing the builtin and keeps it out of
    # the locals() forwarded to LayerHelper below.
    for idx, x in enumerate(input):
        check_variable_and_dtype(
            x, 'input[' + str(idx) + ']',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            'partial_concat')
    check_type(start_index, 'start_index', int, 'partial_concat')
    check_type(length, 'length', int, 'partial_concat')

    inputs = {'X': input}
    attrs = {'start_index': start_index, 'length': length}
    # LayerHelper reads the `input` entry of locals() to infer the dtype.
    helper = LayerHelper('partial_concat', **locals())
    out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
    helper.append_op(
        type='partial_concat',
        inputs=inputs,
        outputs={'Out': [out]},
        attrs=attrs)
    return out
......@@ -2912,6 +2912,16 @@ class TestBook(LayerTest):
out = layers.unfold(x, [3, 3], 1, 1, 1)
return (out)
def test_partial_concat(self):
    """Builds partial_concat in a static graph, once from a list of two
    tensors and once from a bare tensor, and returns both outputs."""
    with self.static_graph():
        feeds = [
            fluid.data(name=var_name, shape=[None, 3], dtype="float32")
            for var_name in ("x", "y")
        ]
        # Two inputs, first two columns of each.
        from_list = fluid.contrib.layers.partial_concat(
            feeds, start_index=0, length=2)
        # Single (non-list) input, whole rows.
        from_single = fluid.contrib.layers.partial_concat(
            feeds[0], start_index=0, length=-1)
        return from_list, from_single
def test_deform_roi_pooling(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import random
import six
def np_partial_concat(inputs, start, length):
    """NumPy reference implementation of the partial_concat operator.

    Args:
        inputs: non-empty sequence of 2-D arrays that all share one shape.
        start: first column of the slice; a negative value counts from the
            end of the row and must lie in [-columns, columns).
        length: number of columns to take from every input; a negative value
            selects everything from `start` to the end of the row.

    Returns:
        A 2-D array in which the selected column range of every input is
        concatenated along axis 1, in input order.
    """
    assert (len(inputs[0].shape) == 2)
    size = inputs[0].shape[1]
    assert (start >= -size and start < size)
    if start < 0:
        start += size
    if length < 0:
        length = size - start
    # The slice must not run past the end of a row.
    assert (size >= start + length)

    elems = []
    for elem in inputs:
        assert (elem.shape == inputs[0].shape)
        elems.append(elem[:, start:start + length])
    # Concatenate once; the original computed this twice and discarded the
    # first result.
    return np.concatenate(elems, axis=1)
class TestPartialConcatOp(OpTest):
    """Checks the partial_concat operator's output and gradients against the
    NumPy reference, using randomly sized inputs."""

    def setUp(self):
        self.op_type = "partial_concat"
        self.init_kernel_type()
        self.init_para()

        # One name and one random (batch_size, column) array per input.
        indices = list(six.moves.range(self.var_num))
        self.var_names = ['x' + str(idx) for idx in indices]
        shape = (self.batch_size, self.column)
        self.vars = [
            np.random.random(shape).astype(self.dtype) for _ in indices
        ]

        self.inputs = {'X': list(zip(self.var_names, self.vars))}
        self.attrs = {'start_index': self.start_index, 'length': self.length}
        expected = np_partial_concat(
            list(self.vars), self.start_index, self.length)
        self.outputs = {'Out': expected}

    def init_kernel_type(self):
        # Element type of the generated inputs.
        self.dtype = np.float64

    def init_para(self):
        # Random shape and start; length -1 takes the rest of each row.
        self.batch_size = random.randint(10, 20)
        self.column = random.randint(101, 200)
        self.start_index = random.randint(0, self.column - 1)
        self.length = -1
        self.var_num = random.randint(1, 3)

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Check the gradient with respect to each input separately.
        for name in self.var_names:
            self.check_grad([name], 'Out')
class TestPartialConcatOp2(TestPartialConcatOp):
    """Three inputs; negative start index (-5) with the rest-of-row length."""

    def init_para(self):
        # Random shape first (same draw order as before), fixed slice after.
        self.batch_size = random.randint(1, 10)
        self.column = random.randint(101, 200)
        self.var_num = 3
        self.length = -1
        self.start_index = -5
class TestPartialConcatOp3(TestPartialConcatOp):
    """Two inputs; explicit slice of 20 columns starting at column 10."""

    def init_para(self):
        # Random shape first (same draw order as before), fixed slice after.
        self.batch_size = random.randint(1, 10)
        self.column = random.randint(101, 200)
        self.var_num = 2
        self.length = 20
        self.start_index = 10
class TestPartialConcatOp4(TestPartialConcatOp):
    """Single input; start index -1 selects only the last column."""

    def init_para(self):
        # Random shape first (same draw order as before), fixed slice after.
        self.batch_size = random.randint(1, 10)
        self.column = random.randint(101, 200)
        self.var_num = 1
        self.length = -1
        self.start_index = -1
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册