未验证 提交 3132681e 编写于 作者: S ShenLiang 提交者: GitHub

add partial_sum op in contrib (#22292)

* add partial_sum_op, test=develop

* modify the Paddle Error Message, test=develop

* modify the Paddle Error Message, test=develop

* modify the bug for python3, test=develop

* modify the ut for ci, test=develop

* mv to contrib, test=develop

* use check_variable_and_dtype, test=develop

* fix ci, test=develop

* fix conflict, test=dvelop

* add partial concat, test=develop

* fix the conflict, test=develop

* fix the error, test=develop

* rm SSE4, test=develop
上级 611411b9
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/partial_sum_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class PartialSumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
platform::errors::InvalidArgument(
"Inputs(X) of PartialSumOp should not be empty."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of PartialSumOp should not be null."));
auto inputs_dims = ctx->GetInputsDim("X");
const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(inputs_num, 0,
platform::errors::InvalidArgument(
"ShapeError: Input tensors count should > 0. But "
"recevied inputs' length is 0."));
if (inputs_num == 1) {
VLOG(3) << "Warning: partial_sum op have only one input, may be useless";
}
int start_index = ctx->Attrs().Get<int>("start_index");
int length = ctx->Attrs().Get<int>("length");
// Only suppert two dimensions now, should be extended later
// when length is -1, need make sure all dimensions to be added are the same
int64_t batch_size = -1;
int64_t input_len = -1;
for (size_t i = 0; i < inputs_num; ++i) {
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), 2,
platform::errors::InvalidArgument(
"Only suppert two dimensions input now."));
if (i == 0) {
batch_size = inputs_dims[0][0];
input_len = inputs_dims[0][1];
} else {
PADDLE_ENFORCE_EQ(inputs_dims[i][0], batch_size,
platform::errors::InvalidArgument(
"The batch size of all inputs must be same"));
PADDLE_ENFORCE_EQ(inputs_dims[i][1], input_len,
platform::errors::InvalidArgument(
"The input len of all inputs must be same"));
}
}
PADDLE_ENFORCE_GT(input_len, start_index,
platform::errors::OutOfRange(
"start_index must be less than input len"));
if (length > 0) {
PADDLE_ENFORCE_GE(
input_len, start_index + length,
platform::errors::OutOfRange(
"start_index + length is larger than input length"));
}
std::vector<int64_t> out_dims(2);
out_dims[0] = batch_size;
out_dims[1] = (length == -1) ? input_len - start_index : length;
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto inputs = ctx.MultiInput<Tensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
input_data_type = input->type();
flag = 1;
break;
}
}
PADDLE_ENFORCE_EQ(flag, 1, platform::errors::InvalidArgument(
"All Inputs of PartialSum OP are Empty!"));
return framework::OpKernelType(input_data_type, platform::CPUPlace());
}
};
class PartialSumGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_x = "X";
auto out_x_g_n = framework::GradVarName(in_x);
ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
auto in_names = ctx->Inputs(in_x);
auto out_names = ctx->Outputs(out_x_g_n);
PADDLE_ENFORCE_EQ(
in_names.size(), out_names.size(),
platform::errors::InvalidArgument(
"The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
in_names.size(), out_x_g_n, out_names.size()));
for (size_t i = 0; i < in_names.size(); ++i) {
if (out_names[i] != framework::kEmptyVarName) {
ctx->ShareLoD(in_x, out_x_g_n, i, i);
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class PartialSumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input tensors of partial_sum operator.").AsDuplicable();
AddOutput("Out", "Output tensor of partial_sum operator.");
AddAttr<bool>(
"use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddAttr<int>("start_index", "The start index of tensor wanted to be added.")
.SetDefault(0);
AddAttr<int>("length", "The length of tensor wanted to be added.")
.SetDefault(-1);
AddComment(R"DOC(
PartialSum Operator.
This Op can sum the vars by specifying the initial position(start_index) and length(length).
This OP exists in contrib, which means that it is not shown to the public.
Only 2-D Tensor or LodTensor input is supported. Slice and concat can only be
performed along the second dimension.
Examples:
Input[0] = [[1,2,3],[3,4,5]]
Input[1] = [[5,6,7],[7,8,9]]
start_index = 0
length = 2
Output = [[6,8],
[10,12]]
)DOC");
}
};
template <typename T>
class PartialSumGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("partial_sum_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
op->SetAttr("start_index", this->GetAttr("start_index"));
op->SetAttr("length", this->GetAttr("length"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(partial_sum, ops::PartialSumOp, ops::PartialSumOpMaker,
ops::PartialSumGradMaker<paddle::framework::OpDesc>,
ops::PartialSumGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(partial_sum_grad, ops::PartialSumGradOp);
REGISTER_OP_CPU_KERNEL(
partial_sum,
ops::PartialSumKernel<paddle::platform::CPUDeviceContext, float>,
ops::PartialSumKernel<paddle::platform::CPUDeviceContext, int>,
ops::PartialSumKernel<paddle::platform::CPUDeviceContext, double>,
ops::PartialSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(partial_sum_grad, ops::PartialSumGradientOpKernel<float>,
ops::PartialSumGradientOpKernel<int>,
ops::PartialSumGradientOpKernel<double>,
ops::PartialSumGradientOpKernel<int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/partial_sum_op.h"
#include "paddle/fluid/platform/float16.h"
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <class T>
__global__ void SumArrayPartialCUDAKernel(T **in, T *out, int64_t lod_length,
size_t in_size, int64_t start_index,
int64_t length, int64_t row_length) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < lod_length) {
T total = static_cast<T>(0);
int b_id = id / length;
int b_offset = id % length;
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[start_index + b_id * row_length + b_offset];
}
}
out[id] = total;
id += blockDim.x * gridDim.x;
}
}
template <class T>
__global__ void PartialSumGradCUDAKernel(T **res_grad, const T *out_grad,
int64_t lod_length, size_t in_size,
int64_t start_index, int64_t length,
int64_t row_length) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < lod_length) {
T total = static_cast<T>(0);
int b_id = id / length;
int b_offset = id % length;
for (int i = 0; i < in_size; ++i) {
T *tmp = res_grad[i];
tmp[start_index + b_id * row_length + b_offset] = out_grad[i];
}
id += blockDim.x * gridDim.x;
}
}
template <typename T>
class PartialSumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto in_vars = ctx.MultiInput<Tensor>("X");
Tensor *out = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE_EQ(
in_vars[0] != nullptr, true,
platform::errors::InvalidArgument("The input should not be null."));
auto place = ctx.GetPlace(); // GPUPlace only now
auto start_index = ctx.Attr<int>("start_index");
auto length = ctx.Attr<int>("length");
auto batch_size = in_vars[0]->dims()[0];
if (length == -1) {
length = in_vars[0]->dims()[1] - start_index;
}
constexpr size_t theory_sm_threads = 1024;
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto stream = dev_ctx.stream();
auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
auto sm_count = max_threads / theory_sm_threads;
size_t tile_size = 0;
dim3 grids;
dim3 blocks;
auto ComputeKernelParameter = [&](size_t length) {
if (length >= max_threads)
tile_size = 1024;
else if (length < max_threads && length > sm_count * 128)
tile_size = 512;
else if (length <= sm_count * 128)
tile_size = 256;
grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
blocks = dim3(tile_size, 1, 1);
};
auto lod_length = length * batch_size;
auto row_length = in_vars[0]->dims()[1];
auto in_num = in_vars.size();
std::vector<const T *> in_data;
for (int i = 0; i < in_num; ++i) {
in_data.emplace_back(in_vars[i]->data<T>());
}
if (!in_data.empty()) {
auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *));
memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
tmp_in_array->ptr(), platform::CPUPlace(),
reinterpret_cast<void *>(in_data.data()),
in_data.size() * sizeof(T *), dev_ctx.stream());
T **in_array_data = reinterpret_cast<T **>(tmp_in_array->ptr());
ComputeKernelParameter(lod_length);
SumArrayPartialCUDAKernel<T><<<grids, blocks, 0, stream>>>(
in_array_data, out->data<T>(), lod_length, in_data.size(),
start_index, length, row_length);
}
}
};
template <typename T>
class PartialSumGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(
ins[0] != nullptr, true,
platform::errors::InvalidArgument("The input should not be null."));
auto start_index = ctx.Attr<int>("start_index");
auto length = ctx.Attr<int>("length");
if (length == -1) {
length = ins[0]->dims()[1] - start_index;
}
// initialize
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
for (size_t i = 0; i < outs.size(); ++i) {
outs[i]->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
dxt.device(place) = dxt.constant(static_cast<T>(0));
}
auto batch_size = ins[0]->dims()[0];
if (length == -1) {
length = ins[0]->dims()[1] - start_index;
}
auto lod_length = length * batch_size;
auto row_length = ins[0]->dims()[1];
auto out_num = outs.size();
constexpr size_t theory_sm_threads = 1024;
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto stream = dev_ctx.stream();
auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
auto sm_count = max_threads / theory_sm_threads;
size_t tile_size = 0;
dim3 grids;
dim3 blocks;
auto ComputeKernelParameter = [&](size_t length) {
if (length >= max_threads)
tile_size = 1024;
else if (length < max_threads && length > sm_count * 128)
tile_size = 512;
else if (length <= sm_count * 128)
tile_size = 256;
grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
blocks = dim3(tile_size, 1, 1);
};
std::vector<const T *> out_data;
for (int i = 0; i < out_num; ++i) {
out_data.emplace_back(outs[i]->data<T>());
}
if (!out_data.empty()) {
auto tmp_out_array =
memory::Alloc(dev_ctx, out_data.size() * sizeof(T *));
memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
tmp_out_array->ptr(), platform::CPUPlace(),
reinterpret_cast<void *>(out_data.data()),
out_data.size() * sizeof(T *), dev_ctx.stream());
T **out_grad_data = reinterpret_cast<T **>(tmp_out_array->ptr());
ComputeKernelParameter(lod_length);
PartialSumGradCUDAKernel<T><<<grids, blocks, 0, stream>>>(
out_grad_data, out_grad->data<T>(), lod_length, out_data.size(),
start_index, length, row_length);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(partial_sum, ops::PartialSumOpCUDAKernel<float>,
ops::PartialSumOpCUDAKernel<double>,
ops::PartialSumOpCUDAKernel<int>,
ops::PartialSumOpCUDAKernel<int64_t>,
ops::PartialSumOpCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(partial_sum_grad,
ops::PartialSumGradOpCUDAKernel<float>,
ops::PartialSumGradOpCUDAKernel<double>,
ops::PartialSumGradOpCUDAKernel<int>,
ops::PartialSumGradOpCUDAKernel<int64_t>,
ops::PartialSumGradOpCUDAKernel<plat::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class PartialSumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE_EQ(
ins[0] != nullptr, true,
platform::errors::InvalidArgument("The input should not be null."));
auto place = ctx.GetPlace(); // CPUPlace only now
auto* out_t = out->mutable_data<T>(place);
auto start_index = ctx.Attr<int>("start_index");
auto length = ctx.Attr<int>("length");
auto batch_size = ins[0]->dims()[0];
if (length == -1) {
length = ins[0]->dims()[1] - start_index;
}
memset(out_t, 0, sizeof(T) * batch_size * length);
for (size_t i = 0; i < ins.size(); ++i) {
auto* in_t = ins[i]->data<T>();
auto total_len = ins[i]->dims()[1];
for (auto bs_id = 0; bs_id < batch_size; ++bs_id) {
for (auto k = 0; k < length; ++k) {
out_t[bs_id * length + k] +=
in_t[bs_id * total_len + start_index + k];
}
}
}
}
};
template <typename T>
class PartialSumGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
auto outs =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(
ins[0] != nullptr, true,
platform::errors::InvalidArgument("The input should not be null."));
auto start_index = ctx.Attr<int>("start_index");
auto length = ctx.Attr<int>("length");
auto batch_size = ins[0]->dims()[0];
if (length == -1) {
length = ins[0]->dims()[1] - start_index;
}
// initialize
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
for (size_t i = 0; i < outs.size(); ++i) {
outs[i]->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
dxt.device(place) = dxt.constant(static_cast<T>(0));
}
auto* out_grad_t = out_grad->data<T>();
for (size_t i = 0; i < outs.size(); ++i) {
auto* out_t = outs[i]->data<T>();
auto total_len = ins[i]->dims()[1];
for (auto bs_id = 0; bs_id < batch_size; ++bs_id) {
for (int len = 0; len < length; ++len) {
out_t[start_index + bs_id * total_len + len] =
out_grad_t[bs_id * length + len] * static_cast<T>(1);
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -31,7 +31,8 @@ import warnings
__all__ = [
'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat'
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
'partial_sum'
]
......@@ -867,3 +868,57 @@ def partial_concat(input, start_index=0, length=-1):
outputs={'Out': [out]},
attrs=attrs)
return out
def partial_sum(input, start_index=0, length=-1):
"""
**PartialSum**
This Op can sum the vars by specifying the initial position(start_index) and length(length).
This Op exists in contrib, which means that it is not shown to the public.
Only 2-D Tensor or LodTensor input is supported. Slice and concat can only be
performed along the second dimension.
.. code-block:: text
Given:
x = [[0, 1, 2],
[3, 4, 5]]
y = [[6, 7 ,8],
[9, 10, 11]]
output = partial_sum([x, y], start_index=0, length=2)
we get:
output = [[6, 8],
[12, 14]]
Args:
input(list): List of input Tensors with data type float32, float64, int32,
int64.
Returns:
Variable: A Tensor with the same data type as input's.
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
x = fluid.data(name="x", shape=[None, 3], dtype="float32")
y = fluid.data(name="y", shape=[None, 3], dtype="float32")
sum = layers.partial_sum([x,y], start_index=0, length=2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
xx = np.array([1,2,3,4,5,6]).reshape((2,3)).astype("float32")
yy = np.array([6,5,4,4,5,6]).reshape((2,3)).astype("float32")
out = exe.run(feed={"x":xx, "y":yy}, fetch_list=[sum])
"""
for id, x in enumerate(input):
check_variable_and_dtype(x, 'input[' + str(id) + ']',
['float32', 'float64', 'int32', 'int64'],
'partial_sum')
inputs = {'X': input}
attrs = {}
attrs['start_index'] = start_index
attrs['length'] = length
helper = LayerHelper('partial_sum', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='partial_sum', inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
......@@ -2790,6 +2790,14 @@ class TestBook(LayerTest):
self.assertIsNotNone(out2)
return (out1)
def test_partial_sum(self):
with self.static_graph():
x = fluid.data(name="x", shape=[None, 3], dtype="float32")
y = fluid.data(name="y", shape=[None, 3], dtype="float32")
sum = fluid.contrib.layers.partial_sum(
[x, y], start_index=0, length=2)
return (sum)
def test_roi_pool(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import random
import six
class TestPartialSumOp(OpTest):
def setUp(self):
self.op_type = "partial_sum"
self.init_kernel_type()
self.init_para()
if self.length is -1:
end_index = self.column
else:
end_index = self.start_index + self.length
self.var_names = [
'x' + str(num) for num in six.moves.range(self.var_num)
]
self.vars = [np.random.random((self.batch_size, self.column)).astype(self.dtype)\
for num in six.moves.range(self.var_num) ]
self.inputs = {'X': list(zip(self.var_names, self.vars))}
self.attrs = {'start_index': self.start_index, 'length': self.length}
y = self.vars[0][:, self.start_index:end_index]
for i in six.moves.range(1, self.var_num):
y = y + self.vars[i][:, self.start_index:end_index]
self.outputs = {'Out': y}
def init_kernel_type(self):
self.dtype = np.float64
def init_para(self):
self.batch_size = random.randint(10, 20)
self.column = random.randint(101, 200)
self.start_index = random.randint(0, self.column - 1)
self.length = random.randint(0, self.column - self.start_index)
self.var_num = random.randint(1, 3)
def test_check_output(self):
self.check_output()
def test_check_grad(self):
for var_name in self.var_names:
self.check_grad([var_name], 'Out')
class TestPartialSumOp2(TestPartialSumOp):
def init_para(self):
self.batch_size = random.randint(1, 10)
self.column = random.randint(101, 200)
self.start_index = random.randint(0, self.column - 1)
self.length = -1
self.var_num = 3
class TestPartialSumOp3(TestPartialSumOp):
def init_para(self):
self.batch_size = random.randint(1, 10)
self.column = random.randint(101, 200)
self.start_index = self.column - 1
self.length = 1
self.var_num = 2
class TestPartialSumOp4(TestPartialSumOp):
def init_para(self):
self.batch_size = random.randint(1, 10)
self.column = random.randint(101, 200)
self.start_index = self.column - 1
self.length = 1
self.var_num = 1
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册