From e136661304355f82bc6c4974ccf1f680c93ce00f Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 24 Feb 2020 13:40:46 +0800 Subject: [PATCH] add partial_concat op in contrib (#22528) * add partial_concat, test=develop * fix the grids and blocks, test=develop * fix the Paddle_Enforce, test=develop * fix the doc of op, test=develop * fix the doc, test=develop * fix the doc of the op, test=develop * replace -1 with None, test=develop --- paddle/fluid/operators/partial_concat_op.cc | 210 +++++++++++++++++ paddle/fluid/operators/partial_concat_op.cu | 219 ++++++++++++++++++ paddle/fluid/operators/partial_concat_op.h | 127 ++++++++++ python/paddle/fluid/contrib/layers/nn.py | 77 +++++- .../fluid/tests/unittests/test_layers.py | 10 + .../tests/unittests/test_partial_concat_op.py | 104 +++++++++ 6 files changed, 738 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/operators/partial_concat_op.cc create mode 100644 paddle/fluid/operators/partial_concat_op.cu create mode 100644 paddle/fluid/operators/partial_concat_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_partial_concat_op.py diff --git a/paddle/fluid/operators/partial_concat_op.cc b/paddle/fluid/operators/partial_concat_op.cc new file mode 100644 index 0000000000..d85d119d21 --- /dev/null +++ b/paddle/fluid/operators/partial_concat_op.cc @@ -0,0 +1,210 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/partial_concat_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +class PartialConcatOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GE( + ctx->Inputs("X").size(), 1UL, + platform::errors::InvalidArgument( + "Inputs(X) of Partial ConcatOp should not be empty.")); + + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Output(Out) of Partial ConcatOp should not be null.")); + + auto inputs_dims = ctx->GetInputsDim("X"); + PADDLE_ENFORCE_EQ(inputs_dims[0].size(), 2, + platform::errors::InvalidArgument( + "Only supports 2-D array with batch size in the 1st " + "dimension and data in the 2nd.")); + + const size_t inputs_num = inputs_dims.size(); + PADDLE_ENFORCE_GT(inputs_num, 0, + platform::errors::InvalidArgument( + "ShapeError: Input tensors count should > 0. But " + "recevied inputs' length is 0.")); + if (inputs_num == 1) { + VLOG(3) << "Warning: concat op have only one input, may waste memory"; + } + + int64_t batch_size = -1; + int64_t input_len = -1; + for (size_t i = 0; i < inputs_num; ++i) { + PADDLE_ENFORCE_EQ(inputs_dims[i].size(), 2, + platform::errors::InvalidArgument( + "It only supports two dimensions input now.")); + if (i == 0) { + batch_size = inputs_dims[0][0]; + input_len = inputs_dims[0][1]; + } else { + PADDLE_ENFORCE_EQ(inputs_dims[i][0], batch_size, + platform::errors::InvalidArgument( + "The batch size of all inputs must be same")); + PADDLE_ENFORCE_EQ(inputs_dims[i][1], input_len, + platform::errors::InvalidArgument( + "The input length of all inputs must be same")); + } + } + + int start_index = ComputeStartIndex( + static_cast(ctx->Attrs().Get("start_index")), + inputs_dims[0][1]); + int partial_len = ctx->Attrs().Get("length"); + if (partial_len < 0) { + partial_len = inputs_dims[0][1] - start_index; + } + + ctx->SetOutputDim("Out", {inputs_dims[0][0], + static_cast(partial_len * inputs_num)}); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto inputs = ctx.MultiInput("X"); + auto input_data_type = framework::proto::VarType::Type(0); + bool flag = 0; + for (auto *input : inputs) { + if (input->IsInitialized() && input->numel() > 0) { + input_data_type = input->type(); + flag = 1; + break; + } + } + PADDLE_ENFORCE_EQ(flag, 1, platform::errors::InvalidArgument( + "All Inputs of PartialSum OP are Empty!")); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +class PartialConcatGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto in_x = "X"; + auto out_x_g_n = framework::GradVarName(in_x); + ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x)); + + auto in_names = ctx->Inputs(in_x); + auto out_names = ctx->Outputs(out_x_g_n); + + PADDLE_ENFORCE_EQ( + in_names.size(), out_names.size(), + platform::errors::InvalidArgument( + "The number of arguments in %s[%d] and %s[%d] is not equal.", in_x, + in_names.size(), out_x_g_n, out_names.size())); + for (size_t i = 0; i < in_names.size(); ++i) { + if (out_names[i] != framework::kEmptyVarName) { + ctx->ShareLoD(in_x, out_x_g_n, i, i); + } + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +class PartialConcatOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input tensors of concat operator.").AsDuplicable(); + AddOutput("Out", "Output tensor of concat operator."); + AddAttr("start_index", + "The start index of each instance for concatenation.") + .SetDefault(0); + AddAttr("length", + "The length of each instance for concatenation." + " Negative values for all elements after start_index") + .SetDefault(-1); + AddComment(R"DOC( +Partial Concat Operator. +Partial Concatenate the input tensors along the 2nd dimension. +Only 2-D Tensor or LodTensor input is supported. +Slice and concat can only be performed along the second dimension. +Examples: + Input[0] = [[1,2],[3,4]] + Input[1] = [[5,6],[7,8]] + start_index = 1 + length = 1 + Output = [[2,6], + [4,8]] +)DOC"); + } +}; + +template +class PartialConcatGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new T()); + op->SetType("partial_concat_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false)); + op->SetAttr("start_index", this->GetAttr("start_index")); + op->SetAttr("length", this->GetAttr("length")); + return op; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(partial_concat, ops::PartialConcatOp, + ops::PartialConcatOpMaker, + ops::PartialConcatGradMaker, + ops::PartialConcatGradMaker); + +REGISTER_OPERATOR(partial_concat_grad, ops::PartialConcatGradOp); + +REGISTER_OP_CPU_KERNEL( + partial_concat, + ops::PartialConcatKernel, + ops::PartialConcatKernel, + ops::PartialConcatKernel, + ops::PartialConcatKernel); + +REGISTER_OP_CPU_KERNEL(partial_concat_grad, + ops::PartialConcatGradientOpKernel, + ops::PartialConcatGradientOpKernel, + ops::PartialConcatGradientOpKernel, + ops::PartialConcatGradientOpKernel); diff --git a/paddle/fluid/operators/partial_concat_op.cu b/paddle/fluid/operators/partial_concat_op.cu new file mode 100644 index 0000000000..a155db0355 --- /dev/null +++ b/paddle/fluid/operators/partial_concat_op.cu @@ -0,0 +1,219 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/partial_concat_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +__global__ void ConcatPartialCUDAKernel(T **in, T *out, int64_t all_length, + int64_t in_batch_len, + int64_t start_index, + int64_t out_batch_len, + int64_t part_length) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + while (id < all_length) { + int64_t bs_id = id / out_batch_len; + int64_t bs_index = id % out_batch_len; + int64_t var_id = bs_index / part_length; + int64_t part_index = bs_index % part_length; + int64_t in_id = start_index + part_index; + const T *tmp = in[var_id]; + out[id] = tmp[bs_id * in_batch_len + in_id]; + id += blockDim.x * gridDim.x; + } +} + +template +__global__ void ConcatPartialGradCUDAKernel( + T **in, const T *out, int64_t all_length, int64_t in_batch_len, + int64_t start_index, int64_t out_batch_len, int64_t part_length) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + while (id < all_length) { + int64_t bs_id = id / out_batch_len; + int64_t bs_index = id % out_batch_len; + int64_t var_id = bs_index / part_length; + int64_t part_index = bs_index % part_length; + int64_t in_id = start_index + part_index; + T *tmp = in[var_id]; + tmp[bs_id * in_batch_len + in_id] = out[id]; + id += blockDim.x * gridDim.x; + } +} + +template +class PartialConcatOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto in_vars = ctx.MultiInput("X"); + Tensor *out = ctx.Output("Out"); + PADDLE_ENFORCE_EQ(in_vars[0] != nullptr, true, + platform::errors::InvalidArgument( + "The input of partial concat should not be null.")); + + auto input_dim = in_vars[0]->dims(); + PADDLE_ENFORCE_EQ(input_dim.size(), 2, + platform::errors::InvalidArgument( + "Only supports 2-D array with batch size in the 1st " + "dimension and data in the 2nd.")); + auto in_size = input_dim[1]; + // may be negative + auto start_index = ctx.Attr("start_index"); + start_index = ComputeStartIndex(start_index, in_size); + + auto partial_len = ctx.Attr("length"); + if (partial_len < 0) { + partial_len = in_size - start_index; + } + + int in_num = in_vars.size(); + int batch_size = input_dim[0]; + int out_batch_len = partial_len * in_num; + int all_length = batch_size * out_batch_len; + + constexpr size_t theory_sm_threads = 1024; + auto &dev_ctx = ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + auto sm_count = max_threads / theory_sm_threads; + size_t tile_size = 0; + int grids; + int blocks; + auto ComputeKernelParameter = [&](size_t length) { + if (length >= max_threads) + tile_size = 1024; + else if (length < max_threads && length > sm_count * 128) + tile_size = 512; + else if (length <= sm_count * 128) + tile_size = 256; + grids = CEIL_DIV(length, tile_size); + blocks = tile_size; + }; + + auto place = ctx.GetPlace(); + T *out_data = out->mutable_data(place); + + std::vector in_data; + for (int i = 0; i < in_num; ++i) + in_data.emplace_back(in_vars[i]->data()); + + auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *)); + memory::Copy(boost::get(dev_ctx.GetPlace()), + tmp_in_array->ptr(), platform::CPUPlace(), + reinterpret_cast(in_data.data()), + in_data.size() * sizeof(T *), dev_ctx.stream()); + + T **in_array_data = reinterpret_cast(tmp_in_array->ptr()); + ComputeKernelParameter(all_length); + ConcatPartialCUDAKernel<<>>( + in_array_data, out->data(), all_length, in_size, start_index, + out_batch_len, partial_len); + } +}; + +template +class PartialConcatGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *out_grad = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput(framework::GradVarName("X")); + + PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, + platform::errors::InvalidArgument( + "The input of partial concat should not be null.")); + // all parameters + auto batch_size = ins[0]->dims()[0]; + auto in_size = ins[0]->dims()[1]; + // may be negative + auto start_index = ctx.Attr("start_index"); + start_index = ComputeStartIndex(start_index, in_size); + auto partial_len = ctx.Attr("length"); + if (partial_len < 0) partial_len = in_size - start_index; + + auto in_num = ins.size(); + auto grad_batch_len = partial_len * in_num; + auto all_length = grad_batch_len * batch_size; + // initialize + auto &place = *ctx.template device_context() + .eigen_device(); + for (size_t i = 0; i < outs.size(); ++i) { + outs[i]->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*outs[i]); + dxt.device(place) = dxt.constant(static_cast(0)); + } + + constexpr size_t theory_sm_threads = 1024; + auto &dev_ctx = ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + auto sm_count = max_threads / theory_sm_threads; + size_t tile_size = 0; + int grids; + int blocks; + auto ComputeKernelParameter = [&](size_t length) { + if (length >= max_threads) + tile_size = 1024; + else if (length < max_threads && length > sm_count * 128) + tile_size = 512; + else if (length <= sm_count * 128) + tile_size = 256; + grids = CEIL_DIV(length, tile_size); + blocks = tile_size; + }; + + std::vector out_data; + for (size_t i = 0; i < in_num; ++i) { + out_data.emplace_back(outs[i]->data()); + } + auto tmp_out_array = memory::Alloc(dev_ctx, out_data.size() * sizeof(T *)); + + memory::Copy(boost::get(dev_ctx.GetPlace()), + tmp_out_array->ptr(), platform::CPUPlace(), + reinterpret_cast(out_data.data()), + out_data.size() * sizeof(T *), dev_ctx.stream()); + + T **out_grad_data = reinterpret_cast(tmp_out_array->ptr()); + ComputeKernelParameter(all_length); + ConcatPartialGradCUDAKernel<<>>( + out_grad_data, out_grad->data(), all_length, in_size, start_index, + grad_batch_len, partial_len); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(partial_concat, ops::PartialConcatOpCUDAKernel, + ops::PartialConcatOpCUDAKernel, + ops::PartialConcatOpCUDAKernel, + ops::PartialConcatOpCUDAKernel, + ops::PartialConcatOpCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(partial_concat_grad, + ops::PartialConcatGradOpCUDAKernel, + ops::PartialConcatGradOpCUDAKernel, + ops::PartialConcatGradOpCUDAKernel, + ops::PartialConcatGradOpCUDAKernel, + ops::PartialConcatGradOpCUDAKernel); diff --git a/paddle/fluid/operators/partial_concat_op.h b/paddle/fluid/operators/partial_concat_op.h new file mode 100644 index 0000000000..20a6639e23 --- /dev/null +++ b/paddle/fluid/operators/partial_concat_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/fluid/operators/utils.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +static inline int64_t ComputeStartIndex(int64_t start_index, int64_t size) { + PADDLE_ENFORCE_EQ( + start_index >= -size && start_index < size, true, + platform::errors::InvalidArgument( + "The start_index is expected to be in range of [%d, %d), but got %d", + -size, size, start_index)); + if (start_index < 0) { + start_index += size; + } + return start_index; +} + +template +class PartialConcatKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + framework::Tensor* out = ctx.Output("Out"); + PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, + platform::errors::InvalidArgument( + "The input of partial concat should not be null.")); + + auto input_dim = ins[0]->dims(); + PADDLE_ENFORCE_EQ(input_dim.size(), 2, + platform::errors::InvalidArgument( + "Only supports 2-D array with batch size in the 1st " + "dimension and data in the 2nd.")); + auto in_size = input_dim[1]; + + // may be negative + auto start_index = ctx.Attr("start_index"); + start_index = ComputeStartIndex(start_index, in_size); + + auto partial_len = ctx.Attr("length"); + if (partial_len < 0) { + partial_len = in_size - start_index; + } + + int batch = input_dim[0]; + int out_size = partial_len * ins.size(); + out->Resize({batch, out_size}); + auto place = ctx.GetPlace(); + T* out_data = out->mutable_data(place); + + for (size_t i = 0; i < ins.size(); ++i) { + for (int j = 0; j < batch; ++j) { + const T* in_data = ins[i]->data(); + memcpy(out_data + out_size * j + partial_len * i, + in_data + in_size * j + start_index, partial_len * sizeof(T)); + } + } + } +}; + +template +class PartialConcatGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto outs = + ctx.MultiOutput(framework::GradVarName("X")); + + PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, + platform::errors::InvalidArgument( + "The input of partial concat should not be null.")); + // all parameters + auto batch_size = ins[0]->dims()[0]; + auto in_size = ins[0]->dims()[1]; + // may be negative + auto start_index = ctx.Attr("start_index"); + start_index = ComputeStartIndex(start_index, in_size); + auto partial_len = ctx.Attr("length"); + if (partial_len < 0) partial_len = in_size - start_index; + + auto in_num = ins.size(); + auto grad_batch_len = partial_len * in_num; + auto all_length = grad_batch_len * batch_size; + + // initialize + auto& place = *ctx.template device_context() + .eigen_device(); + for (size_t i = 0; i < outs.size(); ++i) { + outs[i]->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*outs[i]); + dxt.device(place) = dxt.constant(static_cast(0)); + } + + auto* out_grad_t = out_grad->data(); + for (size_t id = 0; id < all_length; id += partial_len) { + int bs_id = id / grad_batch_len; + int bs_index = id % grad_batch_len; + int var_id = bs_index / partial_len; + auto* out_t = outs[var_id]->data(); + memcpy(out_t + bs_id * in_size + start_index, out_grad_t + id, + partial_len * sizeof(T)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index d89b1cb41d..ee96969e1b 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -24,17 +24,14 @@ import inspect from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers import utils from ... import unique_name +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype +from paddle.fluid.framework import Variable +import warnings __all__ = [ - 'fused_elemwise_activation', - 'sequence_topk_avg_pooling', - 'var_conv_2d', - 'match_matrix_tensor', - 'tree_conv', - 'fused_embedding_seq_pool', - 'multiclass_nms2', - 'search_pyramid_hash', - 'shuffle_batch', + 'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d', + 'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool', + 'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat' ] @@ -808,3 +805,65 @@ def shuffle_batch(x, seed=None): 'SeedOut': seed}, attrs=op_attrs) return out + + +def partial_concat(input, start_index=0, length=-1): + """ + **Partial Concat** + This OP concatenates the inputs according to the start index and length. This + OP exists in contrib, which means that it is not shown to the public. + Only 2-D Tensor or LodTensor input is supported. Slice and concat can only be + performed along the second dimension. + + .. code-block:: text + + Given: + x = [[0, 1, 2], + [3, 4, 5]] + y = [[6, 7 ,8], + [9, 10, 11]] + output = partial_concat([x, y], start_index=0, length=2) + + we get: + + output = [[0, 1, 6, 7], + [3, 4, 9, 10]] + + Args: + input(list): List of input Tensors with data type float32, float64, int32, + int64. + start_index(int32): The start index of each instance for partial concatenation. + Default is 0. + length(int32): The length of each instance for partial concatenation. Default is -1. + Negative values for all elements after start_index. + Returns: + Variable: A Tensor with the same data type as input's. + Examples: + .. code-block:: python + import paddle.fluid as fluid + x = fluid.data(name="x", shape=[None,3], dtype="float32") + y = fluid.data(name="y", shape=[None,3], dtype="float32") + concat = fluid.contrib.layers.partial_concat([x, y], start_index=0, length=2) + """ + if not isinstance(input, list): + warnings.warn( + "The type of input in partial_concat should be list, but received %s." + % (type(input))) + input = [input] + for id, x in enumerate(input): + check_variable_and_dtype( + x, 'input[' + str(id) + ']', + ['float16', 'float32', 'float64', 'int32', 'int64'], + 'partial_concat') + check_type(start_index, 'start_index', (int), 'partial_concat') + check_type(length, 'length', (int), 'partial_concat') + inputs = {'X': input} + attrs = {'start_index': start_index, 'length': length} + helper = LayerHelper('partial_concat', **locals()) + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + helper.append_op( + type='partial_concat', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8e2928f93d..f1ae7a5942 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2912,6 +2912,16 @@ class TestBook(LayerTest): out = layers.unfold(x, [3, 3], 1, 1, 1) return (out) + def test_partial_concat(self): + with self.static_graph(): + x = fluid.data(name="x", shape=[None, 3], dtype="float32") + y = fluid.data(name="y", shape=[None, 3], dtype="float32") + concat1 = fluid.contrib.layers.partial_concat( + [x, y], start_index=0, length=2) + concat2 = fluid.contrib.layers.partial_concat( + x, start_index=0, length=-1) + return concat1, concat2 + def test_deform_roi_pooling(self): with program_guard(fluid.default_main_program(), fluid.default_startup_program()): diff --git a/python/paddle/fluid/tests/unittests/test_partial_concat_op.py b/python/paddle/fluid/tests/unittests/test_partial_concat_op.py new file mode 100644 index 0000000000..a83ca3f81a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_partial_concat_op.py @@ -0,0 +1,104 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import random +import six + + +def np_partial_concat(inputs, start, length): + assert (len(inputs[0].shape) == 2) + size = inputs[0].shape[1] + assert (start >= -size and start < size) + + if start < 0: + start += size + if length < 0: + length = size - start + assert (size >= start + length) + + elems = [] + for elem in inputs: + assert (elem.shape == inputs[0].shape) + elems.append(elem[:, start:start + length]) + res = np.concatenate(elems, axis=1) + return np.concatenate(elems, axis=1) + + +class TestPartialConcatOp(OpTest): + def setUp(self): + self.op_type = "partial_concat" + self.init_kernel_type() + self.init_para() + self.var_names = [ + 'x' + str(num) for num in six.moves.range(self.var_num) + ] + self.vars = [np.random.random((self.batch_size, self.column)).astype(self.dtype)\ + for num in six.moves.range(self.var_num) ] + self.inputs = {'X': list(zip(self.var_names, self.vars))} + self.attrs = {'start_index': self.start_index, 'length': self.length} + y = np_partial_concat(self.vars[:], self.start_index, self.length) + self.outputs = {'Out': y} + + def init_kernel_type(self): + self.dtype = np.float64 + + def init_para(self): + self.batch_size = random.randint(10, 20) + self.column = random.randint(101, 200) + self.start_index = random.randint(0, self.column - 1) + self.length = -1 + self.var_num = random.randint(1, 3) + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + for var_name in self.var_names: + self.check_grad([var_name], 'Out') + + +class TestPartialConcatOp2(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = -5 + self.length = -1 + self.var_num = 3 + + +class TestPartialConcatOp3(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = 10 + self.length = 20 + self.var_num = 2 + + +class TestPartialConcatOp4(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = -1 + self.length = -1 + self.var_num = 1 + + +if __name__ == '__main__': + unittest.main() -- GitLab