From 3132681e8a5d37a07c841dda36428d3ef10bbef0 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 24 Feb 2020 20:59:03 +0800 Subject: [PATCH] add partial_sum op in contrib (#22292) * add partial_sum_op, test=develop * modify the Paddle Error Message, test=develop * modify the Paddle Error Message, test=develop * modify the bug for python3, test=develop * modify the ut for ci, test=develop * mv to contrib, test=develop * use check_variable_and_dtype, test=develop * fix ci, test=develop * fix conflict, test=dvelop * add partial concat, test=develop * fix the conflict, test=develop * fix the error, test=develop * rm SSE4, test=develop --- paddle/fluid/operators/partial_sum_op.cc | 207 ++++++++++++++++ paddle/fluid/operators/partial_sum_op.cu | 223 ++++++++++++++++++ paddle/fluid/operators/partial_sum_op.h | 102 ++++++++ python/paddle/fluid/contrib/layers/nn.py | 57 ++++- .../fluid/tests/unittests/test_layers.py | 8 + .../tests/unittests/test_partial_sum_op.py | 96 ++++++++ 6 files changed, 692 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/partial_sum_op.cc create mode 100644 paddle/fluid/operators/partial_sum_op.cu create mode 100644 paddle/fluid/operators/partial_sum_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_partial_sum_op.py diff --git a/paddle/fluid/operators/partial_sum_op.cc b/paddle/fluid/operators/partial_sum_op.cc new file mode 100644 index 0000000000..f2767e5858 --- /dev/null +++ b/paddle/fluid/operators/partial_sum_op.cc @@ -0,0 +1,207 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/partial_sum_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +class PartialSumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, + platform::errors::InvalidArgument( + "Inputs(X) of PartialSumOp should not be empty.")); + + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Output(Out) of PartialSumOp should not be null.")); + + auto inputs_dims = ctx->GetInputsDim("X"); + + const size_t inputs_num = inputs_dims.size(); + PADDLE_ENFORCE_GT(inputs_num, 0, + platform::errors::InvalidArgument( + "ShapeError: Input tensors count should > 0. 
But " + "recevied inputs' length is 0.")); + if (inputs_num == 1) { + VLOG(3) << "Warning: partial_sum op have only one input, may be useless"; + } + + int start_index = ctx->Attrs().Get("start_index"); + int length = ctx->Attrs().Get("length"); + + // Only suppert two dimensions now, should be extended later + // when length is -1, need make sure all dimensions to be added are the same + int64_t batch_size = -1; + int64_t input_len = -1; + for (size_t i = 0; i < inputs_num; ++i) { + PADDLE_ENFORCE_EQ(inputs_dims[i].size(), 2, + platform::errors::InvalidArgument( + "Only suppert two dimensions input now.")); + if (i == 0) { + batch_size = inputs_dims[0][0]; + input_len = inputs_dims[0][1]; + } else { + PADDLE_ENFORCE_EQ(inputs_dims[i][0], batch_size, + platform::errors::InvalidArgument( + "The batch size of all inputs must be same")); + PADDLE_ENFORCE_EQ(inputs_dims[i][1], input_len, + platform::errors::InvalidArgument( + "The input len of all inputs must be same")); + } + } + PADDLE_ENFORCE_GT(input_len, start_index, + platform::errors::OutOfRange( + "start_index must be less than input len")); + if (length > 0) { + PADDLE_ENFORCE_GE( + input_len, start_index + length, + platform::errors::OutOfRange( + "start_index + length is larger than input length")); + } + + std::vector out_dims(2); + out_dims[0] = batch_size; + out_dims[1] = (length == -1) ? input_len - start_index : length; + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto inputs = ctx.MultiInput("X"); + auto input_data_type = framework::proto::VarType::Type(0); + bool flag = 0; + for (auto *input : inputs) { + if (input->IsInitialized() && input->numel() > 0) { + input_data_type = input->type(); + flag = 1; + break; + } + } + + PADDLE_ENFORCE_EQ(flag, 1, platform::errors::InvalidArgument( + "All Inputs of PartialSum OP are Empty!")); + return framework::OpKernelType(input_data_type, platform::CPUPlace()); + } +}; + +class PartialSumGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto in_x = "X"; + auto out_x_g_n = framework::GradVarName(in_x); + ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x)); + + auto in_names = ctx->Inputs(in_x); + auto out_names = ctx->Outputs(out_x_g_n); + + PADDLE_ENFORCE_EQ( + in_names.size(), out_names.size(), + platform::errors::InvalidArgument( + "The number of arguments in %s[%d] and %s[%d] is not equal.", in_x, + in_names.size(), out_x_g_n, out_names.size())); + for (size_t i = 0; i < in_names.size(); ++i) { + if (out_names[i] != framework::kEmptyVarName) { + ctx->ShareLoD(in_x, out_x_g_n, i, i); + } + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +class PartialSumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input tensors of partial_sum operator.").AsDuplicable(); + AddOutput("Out", "Output tensor of partial_sum operator."); + AddAttr( + "use_mkldnn", + "(bool, default false) Indicates if MKL-DNN kernel will be used") + .SetDefault(false); + AddAttr("start_index", "The start index of tensor 
wanted to be added.") + .SetDefault(0); + AddAttr("length", "The length of tensor wanted to be added.") + .SetDefault(-1); + AddComment(R"DOC( +PartialSum Operator. +This Op can sum the vars by specifying the initial position(start_index) and length(length). +This OP exists in contrib, which means that it is not shown to the public. +Only 2-D Tensor or LodTensor input is supported. Slice and concat can only be +performed along the second dimension. + +Examples: + Input[0] = [[1,2,3],[3,4,5]] + Input[1] = [[5,6,7],[7,8,9]] + start_index = 0 + length = 2 + Output = [[6,8], + [10,12]] +)DOC"); + } +}; + +template +class PartialSumGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new T()); + op->SetType("partial_sum_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false)); + op->SetAttr("start_index", this->GetAttr("start_index")); + op->SetAttr("length", this->GetAttr("length")); + return op; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(partial_sum, ops::PartialSumOp, ops::PartialSumOpMaker, + ops::PartialSumGradMaker, + ops::PartialSumGradMaker); + +REGISTER_OPERATOR(partial_sum_grad, ops::PartialSumGradOp); + +REGISTER_OP_CPU_KERNEL( + partial_sum, + ops::PartialSumKernel, + ops::PartialSumKernel, + ops::PartialSumKernel, + ops::PartialSumKernel); + +REGISTER_OP_CPU_KERNEL(partial_sum_grad, ops::PartialSumGradientOpKernel, + ops::PartialSumGradientOpKernel, + ops::PartialSumGradientOpKernel, + ops::PartialSumGradientOpKernel); diff --git a/paddle/fluid/operators/partial_sum_op.cu b/paddle/fluid/operators/partial_sum_op.cu new file mode 100644 index 0000000000..27b06e227f --- /dev/null +++ b/paddle/fluid/operators/partial_sum_op.cu @@ -0,0 +1,223 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/partial_sum_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +__global__ void SumArrayPartialCUDAKernel(T **in, T *out, int64_t lod_length, + size_t in_size, int64_t start_index, + int64_t length, int64_t row_length) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + while (id < lod_length) { + T total = static_cast(0); + int b_id = id / length; + int b_offset = id % length; + + for (int i = 0; i < in_size; ++i) { + const T *tmp = in[i]; + if (tmp) { + total += tmp[start_index + b_id * row_length + b_offset]; + } + } + out[id] = total; + id += blockDim.x * gridDim.x; + } +} + +template +__global__ void PartialSumGradCUDAKernel(T **res_grad, const T *out_grad, + int64_t lod_length, size_t in_size, + int64_t start_index, int64_t length, + int64_t row_length) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + while (id < lod_length) { + T total = static_cast(0); + int b_id = id / length; + int b_offset = id % length; + + for (int i = 0; i < in_size; ++i) { + T *tmp = res_grad[i]; + tmp[start_index + b_id * row_length + b_offset] = out_grad[i]; + } + id += blockDim.x * gridDim.x; + } +} + +template +class PartialSumOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto in_vars = ctx.MultiInput("X"); + Tensor *out = ctx.Output("Out"); + + PADDLE_ENFORCE_EQ( + in_vars[0] != nullptr, true, + platform::errors::InvalidArgument("The input should not be null.")); + + auto place = ctx.GetPlace(); // GPUPlace only now + auto start_index = ctx.Attr("start_index"); + auto length = ctx.Attr("length"); + auto batch_size = in_vars[0]->dims()[0]; + if (length == -1) { + length = in_vars[0]->dims()[1] - start_index; + } + + constexpr size_t theory_sm_threads = 1024; + auto &dev_ctx = ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + auto sm_count = max_threads / theory_sm_threads; + size_t tile_size = 0; + dim3 grids; + dim3 blocks; + auto ComputeKernelParameter = [&](size_t length) { + if (length >= max_threads) + tile_size = 1024; + else if (length < max_threads && length > sm_count * 128) + tile_size = 512; + else if (length <= sm_count * 128) + tile_size = 256; + grids = dim3(CEIL_DIV(length, tile_size), 1, 1); + blocks = dim3(tile_size, 1, 1); + }; + + auto lod_length = length * batch_size; + auto row_length = in_vars[0]->dims()[1]; + auto in_num = in_vars.size(); + + std::vector in_data; + for (int i = 0; i < in_num; ++i) { + in_data.emplace_back(in_vars[i]->data()); + } + + if (!in_data.empty()) { + auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *)); + + memory::Copy(boost::get(dev_ctx.GetPlace()), + tmp_in_array->ptr(), platform::CPUPlace(), + reinterpret_cast(in_data.data()), + in_data.size() * sizeof(T *), dev_ctx.stream()); + + T **in_array_data = reinterpret_cast(tmp_in_array->ptr()); + ComputeKernelParameter(lod_length); + SumArrayPartialCUDAKernel<<>>( + in_array_data, out->data(), lod_length, in_data.size(), + start_index, length, row_length); + } + } +}; + +template +class PartialSumGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const 
framework::ExecutionContext &ctx) const override { + const Tensor *out_grad = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput(framework::GradVarName("X")); + + PADDLE_ENFORCE_EQ( + ins[0] != nullptr, true, + platform::errors::InvalidArgument("The input should not be null.")); + auto start_index = ctx.Attr("start_index"); + auto length = ctx.Attr("length"); + if (length == -1) { + length = ins[0]->dims()[1] - start_index; + } + + // initialize + auto &place = *ctx.template device_context() + .eigen_device(); + for (size_t i = 0; i < outs.size(); ++i) { + outs[i]->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*outs[i]); + dxt.device(place) = dxt.constant(static_cast(0)); + } + + auto batch_size = ins[0]->dims()[0]; + if (length == -1) { + length = ins[0]->dims()[1] - start_index; + } + auto lod_length = length * batch_size; + auto row_length = ins[0]->dims()[1]; + auto out_num = outs.size(); + + constexpr size_t theory_sm_threads = 1024; + auto &dev_ctx = ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + auto sm_count = max_threads / theory_sm_threads; + size_t tile_size = 0; + dim3 grids; + dim3 blocks; + auto ComputeKernelParameter = [&](size_t length) { + if (length >= max_threads) + tile_size = 1024; + else if (length < max_threads && length > sm_count * 128) + tile_size = 512; + else if (length <= sm_count * 128) + tile_size = 256; + grids = dim3(CEIL_DIV(length, tile_size), 1, 1); + blocks = dim3(tile_size, 1, 1); + }; + + std::vector out_data; + for (int i = 0; i < out_num; ++i) { + out_data.emplace_back(outs[i]->data()); + } + + if (!out_data.empty()) { + auto tmp_out_array = + memory::Alloc(dev_ctx, out_data.size() * sizeof(T *)); + + memory::Copy(boost::get(dev_ctx.GetPlace()), + tmp_out_array->ptr(), platform::CPUPlace(), + reinterpret_cast(out_data.data()), + out_data.size() * sizeof(T *), dev_ctx.stream()); + + T **out_grad_data = reinterpret_cast(tmp_out_array->ptr()); + ComputeKernelParameter(lod_length); + PartialSumGradCUDAKernel<<>>( + out_grad_data, out_grad->data(), lod_length, out_data.size(), + start_index, length, row_length); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(partial_sum, ops::PartialSumOpCUDAKernel, + ops::PartialSumOpCUDAKernel, + ops::PartialSumOpCUDAKernel, + ops::PartialSumOpCUDAKernel, + ops::PartialSumOpCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(partial_sum_grad, + ops::PartialSumGradOpCUDAKernel, + ops::PartialSumGradOpCUDAKernel, + ops::PartialSumGradOpCUDAKernel, + ops::PartialSumGradOpCUDAKernel, + ops::PartialSumGradOpCUDAKernel); diff --git a/paddle/fluid/operators/partial_sum_op.h b/paddle/fluid/operators/partial_sum_op.h new file mode 100644 index 0000000000..d9c6fd758f --- /dev/null +++ b/paddle/fluid/operators/partial_sum_op.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class PartialSumKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + Tensor* out = ctx.Output("Out"); + PADDLE_ENFORCE_EQ( + ins[0] != nullptr, true, + platform::errors::InvalidArgument("The input should not be null.")); + + auto place = ctx.GetPlace(); // CPUPlace only now + + auto* out_t = out->mutable_data(place); + auto start_index = ctx.Attr("start_index"); + auto length = ctx.Attr("length"); + auto batch_size = ins[0]->dims()[0]; + if (length == -1) { + length = ins[0]->dims()[1] - start_index; + } + + memset(out_t, 0, sizeof(T) * batch_size * length); + + for (size_t i = 0; i < ins.size(); ++i) { + auto* in_t = ins[i]->data(); + auto total_len = ins[i]->dims()[1]; + for (auto bs_id = 0; bs_id < batch_size; ++bs_id) { + for (auto k = 0; k < length; ++k) { + out_t[bs_id * length + k] += + in_t[bs_id * total_len + start_index + k]; + } + } + } + } +}; + +template +class PartialSumGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto outs = + ctx.MultiOutput(framework::GradVarName("X")); + + PADDLE_ENFORCE_EQ( + ins[0] != nullptr, true, + platform::errors::InvalidArgument("The input should not be null.")); + auto start_index = ctx.Attr("start_index"); + auto length = ctx.Attr("length"); + auto batch_size = ins[0]->dims()[0]; + if (length == -1) { + length = ins[0]->dims()[1] - start_index; + } + + // initialize + auto& place = *ctx.template device_context() + .eigen_device(); + for (size_t i = 0; i < outs.size(); ++i) { + outs[i]->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*outs[i]); + dxt.device(place) = dxt.constant(static_cast(0)); + } + + auto* out_grad_t = out_grad->data(); + for (size_t i = 0; i < outs.size(); ++i) { + auto* out_t = outs[i]->data(); + auto total_len = ins[i]->dims()[1]; + for (auto bs_id = 0; bs_id < batch_size; ++bs_id) { + for (int len = 0; len < length; ++len) { + out_t[start_index + bs_id * total_len + len] = + out_grad_t[bs_id * length + len] * static_cast(1); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index ee96969e1b..b783b1a66b 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -31,7 +31,8 @@ import warnings __all__ = [ 'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d', 'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool', - 'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat' + 'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat', + 'partial_sum' ] @@ -867,3 +868,57 @@ def partial_concat(input, start_index=0, length=-1): outputs={'Out': [out]}, attrs=attrs) return out + + +def partial_sum(input, start_index=0, length=-1): + """ + **PartialSum** + This Op can sum the vars by specifying the initial position(start_index) and length(length). 
+ This Op exists in contrib, which means that it is not shown to the public. + Only 2-D Tensor or LodTensor input is supported. Slice and concat can only be + performed along the second dimension. + .. code-block:: text + + Given: + x = [[0, 1, 2], + [3, 4, 5]] + y = [[6, 7 ,8], + [9, 10, 11]] + output = partial_sum([x, y], start_index=0, length=2) + we get: + + output = [[6, 8], + [12, 14]] + Args: + input(list): List of input Tensors with data type float32, float64, int32, + int64. + Returns: + Variable: A Tensor with the same data type as input's. + Examples: + .. code-block:: python + import paddle.fluid.layers as layers + import paddle.fluid as fluid + import numpy as np + x = fluid.data(name="x", shape=[None, 3], dtype="float32") + y = fluid.data(name="y", shape=[None, 3], dtype="float32") + sum = layers.partial_sum([x,y], start_index=0, length=2) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + xx = np.array([1,2,3,4,5,6]).reshape((2,3)).astype("float32") + yy = np.array([6,5,4,4,5,6]).reshape((2,3)).astype("float32") + out = exe.run(feed={"x":xx, "y":yy}, fetch_list=[sum]) + """ + for id, x in enumerate(input): + check_variable_and_dtype(x, 'input[' + str(id) + ']', + ['float32', 'float64', 'int32', 'int64'], + 'partial_sum') + + inputs = {'X': input} + attrs = {} + attrs['start_index'] = start_index + attrs['length'] = length + helper = LayerHelper('partial_sum', **locals()) + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + helper.append_op( + type='partial_sum', inputs=inputs, outputs={'Out': [out]}, attrs=attrs) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index f1ae7a5942..967aa4789f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2790,6 +2790,14 @@ class TestBook(LayerTest): self.assertIsNotNone(out2) return (out1) + def test_partial_sum(self): + with self.static_graph(): + x = fluid.data(name="x", shape=[None, 3], dtype="float32") + y = fluid.data(name="y", shape=[None, 3], dtype="float32") + sum = fluid.contrib.layers.partial_sum( + [x, y], start_index=0, length=2) + return (sum) + def test_roi_pool(self): # TODO(minqiyang): dygraph do not support lod now with self.static_graph(): diff --git a/python/paddle/fluid/tests/unittests/test_partial_sum_op.py b/python/paddle/fluid/tests/unittests/test_partial_sum_op.py new file mode 100644 index 0000000000..eb51664301 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_partial_sum_op.py @@ -0,0 +1,96 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid.layers as layers +import paddle.fluid as fluid +import random +import six + + +class TestPartialSumOp(OpTest): + def setUp(self): + self.op_type = "partial_sum" + self.init_kernel_type() + self.init_para() + if self.length is -1: + end_index = self.column + else: + end_index = self.start_index + self.length + self.var_names = [ + 'x' + str(num) for num in six.moves.range(self.var_num) + ] + self.vars = [np.random.random((self.batch_size, self.column)).astype(self.dtype)\ + for num in six.moves.range(self.var_num) ] + self.inputs = {'X': list(zip(self.var_names, self.vars))} + self.attrs = {'start_index': self.start_index, 'length': self.length} + y = self.vars[0][:, self.start_index:end_index] + for i in six.moves.range(1, self.var_num): + y = y + self.vars[i][:, self.start_index:end_index] + + self.outputs = {'Out': y} + + def init_kernel_type(self): + self.dtype = np.float64 + + def init_para(self): + self.batch_size = random.randint(10, 20) + self.column = random.randint(101, 200) + self.start_index = random.randint(0, self.column - 1) + self.length = random.randint(0, self.column - self.start_index) + self.var_num = random.randint(1, 3) + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + for var_name in self.var_names: + self.check_grad([var_name], 'Out') + + +class TestPartialSumOp2(TestPartialSumOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = random.randint(0, self.column - 1) + self.length = -1 + self.var_num = 3 + + +class TestPartialSumOp3(TestPartialSumOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = self.column - 1 + self.length = 1 + self.var_num = 2 + + +class TestPartialSumOp4(TestPartialSumOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = self.column - 1 + self.length = 1 + self.var_num = 1 + + +if __name__ == "__main__": + unittest.main() -- GitLab
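
For reference, the forward semantics added by this patch can be summarized with a small NumPy sketch (not part of the commit; the helper name `partial_sum_ref` is illustrative). It mirrors the CPU kernel in partial_sum_op.h: every 2-D input contributes the column slice [start_index, start_index + length), and length == -1 means "through the last column".

    import numpy as np

    def partial_sum_ref(inputs, start_index=0, length=-1):
        """Reference forward pass: sum a column slice of each 2-D input."""
        batch_size, width = inputs[0].shape
        end = width if length == -1 else start_index + length
        out = np.zeros((batch_size, end - start_index), dtype=inputs[0].dtype)
        for x in inputs:
            # The same slice is taken from every input and accumulated.
            out += x[:, start_index:end]
        return out

    # Reproduces the example from the operator comment above.
    x = np.array([[1, 2, 3], [3, 4, 5]], dtype=np.float32)
    y = np.array([[5, 6, 7], [7, 8, 9]], dtype=np.float32)
    print(partial_sum_ref([x, y], start_index=0, length=2))
    # [[ 6.  8.]
    #  [10. 12.]]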
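
Likewise, a minimal sketch of the gradient rule, mirroring the CPU backward kernel in partial_sum_op.h (again not part of the commit; `partial_sum_grad_ref` is an illustrative name): each input receives a gradient of its own shape, zero everywhere except the sliced columns, which receive a copy of the output gradient.

    import numpy as np

    def partial_sum_grad_ref(out_grad, input_shapes, start_index=0, length=-1):
        """Reference backward pass: scatter out_grad into the sliced columns."""
        grads = []
        for shape in input_shapes:
            width = shape[1]
            end = width if length == -1 else start_index + length
            g = np.zeros(shape, dtype=out_grad.dtype)
            # Each input contributes to the output with coefficient 1,
            # so its gradient is out_grad placed at the slice position.
            g[:, start_index:end] = out_grad
            grads.append(g)
        return grads

    out_grad = np.ones((2, 2), dtype=np.float32)
    dx, dy = partial_sum_grad_ref(out_grad, [(2, 3), (2, 3)], start_index=0, length=2)
    print(dx)
    # [[1. 1. 0.]
    #  [1. 1. 0.]]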